dataeval 0.85.0__py3-none-any.whl → 0.86.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/_log.py +1 -1
- dataeval/config.py +21 -4
- dataeval/data/_embeddings.py +2 -2
- dataeval/data/_images.py +2 -3
- dataeval/data/_metadata.py +65 -42
- dataeval/data/_selection.py +2 -3
- dataeval/data/_split.py +2 -3
- dataeval/data/_targets.py +17 -13
- dataeval/data/selections/_classfilter.py +6 -8
- dataeval/data/selections/_prioritize.py +6 -9
- dataeval/data/selections/_shuffle.py +3 -1
- dataeval/detectors/drift/__init__.py +4 -1
- dataeval/detectors/drift/_base.py +4 -5
- dataeval/detectors/drift/_mmd.py +3 -6
- dataeval/detectors/drift/_mvdc.py +92 -0
- dataeval/detectors/drift/_nml/__init__.py +6 -0
- dataeval/detectors/drift/_nml/_base.py +70 -0
- dataeval/detectors/drift/_nml/_chunk.py +396 -0
- dataeval/detectors/drift/_nml/_domainclassifier.py +181 -0
- dataeval/detectors/drift/_nml/_result.py +97 -0
- dataeval/detectors/drift/_nml/_thresholds.py +269 -0
- dataeval/detectors/linters/outliers.py +7 -7
- dataeval/metrics/bias/_parity.py +10 -13
- dataeval/metrics/estimators/_divergence.py +2 -4
- dataeval/metrics/stats/_base.py +103 -42
- dataeval/metrics/stats/_boxratiostats.py +21 -19
- dataeval/metrics/stats/_dimensionstats.py +14 -10
- dataeval/metrics/stats/_hashstats.py +1 -1
- dataeval/metrics/stats/_pixelstats.py +6 -6
- dataeval/metrics/stats/_visualstats.py +3 -3
- dataeval/outputs/__init__.py +2 -1
- dataeval/outputs/_base.py +22 -7
- dataeval/outputs/_bias.py +27 -31
- dataeval/outputs/_drift.py +60 -0
- dataeval/outputs/_linters.py +12 -17
- dataeval/outputs/_stats.py +83 -29
- dataeval/outputs/_workflows.py +2 -2
- dataeval/utils/_array.py +6 -9
- dataeval/utils/_bin.py +1 -2
- dataeval/utils/_clusterer.py +7 -4
- dataeval/utils/_fast_mst.py +27 -13
- dataeval/utils/_image.py +65 -11
- dataeval/utils/_mst.py +1 -3
- dataeval/utils/_plot.py +15 -10
- dataeval/utils/data/_dataset.py +32 -20
- dataeval/utils/data/metadata.py +104 -82
- dataeval/utils/datasets/__init__.py +2 -0
- dataeval/utils/datasets/_antiuav.py +189 -0
- dataeval/utils/datasets/_base.py +11 -8
- dataeval/utils/datasets/_cifar10.py +104 -45
- dataeval/utils/datasets/_fileio.py +21 -47
- dataeval/utils/datasets/_milco.py +19 -11
- dataeval/utils/datasets/_mixin.py +2 -4
- dataeval/utils/datasets/_mnist.py +3 -4
- dataeval/utils/datasets/_ships.py +14 -7
- dataeval/utils/datasets/_voc.py +229 -42
- dataeval/utils/torch/models.py +5 -10
- dataeval/utils/torch/trainer.py +3 -3
- dataeval/workflows/sufficiency.py +2 -2
- {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/METADATA +3 -2
- dataeval-0.86.1.dist-info/RECORD +114 -0
- dataeval/detectors/ood/vae.py +0 -74
- dataeval-0.85.0.dist-info/RECORD +0 -107
- {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.85.0.dist-info → dataeval-0.86.1.dist-info}/WHEEL +0 -0
dataeval/outputs/_drift.py
CHANGED
@@ -2,11 +2,17 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
import contextlib
|
5
6
|
from dataclasses import dataclass
|
6
7
|
|
7
8
|
import numpy as np
|
9
|
+
import pandas as pd
|
8
10
|
from numpy.typing import NDArray
|
9
11
|
|
12
|
+
with contextlib.suppress(ImportError):
|
13
|
+
from matplotlib.figure import Figure
|
14
|
+
|
15
|
+
from dataeval.detectors.drift._nml._result import Metric, PerMetricResult
|
10
16
|
from dataeval.outputs._base import Output
|
11
17
|
|
12
18
|
|
@@ -81,3 +87,57 @@ class DriftOutput(DriftBaseOutput):
|
|
81
87
|
feature_threshold: float
|
82
88
|
p_vals: NDArray[np.float32]
|
83
89
|
distances: NDArray[np.float32]
|
90
|
+
|
91
|
+
|
92
|
+
class DriftMVDCOutput(PerMetricResult):
|
93
|
+
"""Class wrapping the results of the classifier for drift detection and providing plotting functionality."""
|
94
|
+
|
95
|
+
def __init__(self, results_data: pd.DataFrame) -> None:
|
96
|
+
"""Initialize a DomainClassifierCalculator results object.
|
97
|
+
|
98
|
+
Parameters
|
99
|
+
----------
|
100
|
+
results_data : pd.DataFrame
|
101
|
+
Results data returned by a DomainClassifierCalculator.
|
102
|
+
"""
|
103
|
+
metric = Metric(display_name="Domain Classifier", column_name="domain_classifier_auroc")
|
104
|
+
super().__init__(results_data, [metric])
|
105
|
+
|
106
|
+
def plot(self) -> Figure:
|
107
|
+
"""
|
108
|
+
Render the roc_auc metric over the train/test data in relation to the threshold.
|
109
|
+
|
110
|
+
Returns
|
111
|
+
-------
|
112
|
+
matplotlib.figure.Figure
|
113
|
+
"""
|
114
|
+
import matplotlib.pyplot as plt
|
115
|
+
|
116
|
+
fig, ax = plt.subplots(dpi=300)
|
117
|
+
resdf = self.to_df()
|
118
|
+
xticks = np.arange(resdf.shape[0])
|
119
|
+
trndf = resdf[resdf["chunk"]["period"] == "reference"]
|
120
|
+
tstdf = resdf[resdf["chunk"]["period"] == "analysis"]
|
121
|
+
# Get local indices for drift markers
|
122
|
+
driftx = np.where(resdf["domain_classifier_auroc"]["alert"].values) # type: ignore | dataframe
|
123
|
+
if np.size(driftx) > 2:
|
124
|
+
ax.plot(resdf.index, resdf["domain_classifier_auroc"]["upper_threshold"], "r--", label="thr_up")
|
125
|
+
ax.plot(resdf.index, resdf["domain_classifier_auroc"]["lower_threshold"], "r--", label="thr_low")
|
126
|
+
ax.plot(trndf.index, trndf["domain_classifier_auroc"]["value"], "b", label="train")
|
127
|
+
ax.plot(tstdf.index, tstdf["domain_classifier_auroc"]["value"], "g", label="test")
|
128
|
+
ax.plot(
|
129
|
+
resdf.index.values[driftx], # type: ignore | dataframe
|
130
|
+
resdf["domain_classifier_auroc"]["value"].values[driftx], # type: ignore | dataframe
|
131
|
+
"dm",
|
132
|
+
markersize=3,
|
133
|
+
label="drift",
|
134
|
+
)
|
135
|
+
ax.set_xticks(xticks)
|
136
|
+
ax.tick_params(axis="x", labelsize=6)
|
137
|
+
ax.tick_params(axis="y", labelsize=6)
|
138
|
+
ax.legend(loc="lower left", fontsize=6)
|
139
|
+
ax.set_title("Domain Classifier, Drift Detection", fontsize=8)
|
140
|
+
ax.set_ylabel("ROC AUC", fontsize=7)
|
141
|
+
ax.set_xlabel("Chunk Index", fontsize=7)
|
142
|
+
ax.set_ylim((0.0, 1.1))
|
143
|
+
return fig
|
dataeval/outputs/_linters.py
CHANGED
@@ -2,15 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
import contextlib
|
6
5
|
from dataclasses import dataclass
|
7
6
|
from typing import Generic, TypeVar, Union
|
8
7
|
|
8
|
+
import pandas as pd
|
9
9
|
from typing_extensions import TypeAlias
|
10
10
|
|
11
|
-
with contextlib.suppress(ImportError):
|
12
|
-
import pandas as pd
|
13
|
-
|
14
11
|
from dataeval.outputs._base import Output
|
15
12
|
from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
|
16
13
|
|
@@ -46,10 +43,12 @@ class DuplicatesOutput(Output, Generic[TIndexCollection]):
|
|
46
43
|
near: list[TIndexCollection]
|
47
44
|
|
48
45
|
|
49
|
-
def _reorganize_by_class_and_metric(
|
46
|
+
def _reorganize_by_class_and_metric(
|
47
|
+
result: IndexIssueMap, lstats: LabelStatsOutput
|
48
|
+
) -> tuple[dict[str, list[int]], dict[str, dict[str, int]]]:
|
50
49
|
"""Flip result from grouping by image to grouping by class and metric"""
|
51
|
-
metrics = {}
|
52
|
-
class_wise = {label: {} for label in lstats.class_names}
|
50
|
+
metrics: dict[str, list[int]] = {}
|
51
|
+
class_wise: dict[str, dict[str, int]] = {label: {} for label in lstats.class_names}
|
53
52
|
|
54
53
|
# Group metrics and calculate class-wise counts
|
55
54
|
for img, group in result.items():
|
@@ -62,7 +61,7 @@ def _reorganize_by_class_and_metric(result: IndexIssueMap, lstats: LabelStatsOut
|
|
62
61
|
return metrics, class_wise
|
63
62
|
|
64
63
|
|
65
|
-
def _create_table(metrics, class_wise):
|
64
|
+
def _create_table(metrics: dict[str, list[int]], class_wise: dict[str, dict[str, int]]) -> list[str]:
|
66
65
|
"""Create table for displaying the results"""
|
67
66
|
max_class_length = max(len(str(label)) for label in class_wise) + 2
|
68
67
|
max_total = max(len(metrics[group]) for group in metrics) + 2
|
@@ -72,7 +71,7 @@ def _create_table(metrics, class_wise):
|
|
72
71
|
+ [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
|
73
72
|
+ [f"{'Total':<{max_total}}"]
|
74
73
|
)
|
75
|
-
table_rows = []
|
74
|
+
table_rows: list[str] = []
|
76
75
|
|
77
76
|
for class_cat, results in class_wise.items():
|
78
77
|
table_value = [f"{class_cat:>{max_class_length}}"]
|
@@ -84,15 +83,14 @@ def _create_table(metrics, class_wise):
|
|
84
83
|
table_value.append(f"{total:^{max_total}}")
|
85
84
|
table_rows.append(" | ".join(table_value))
|
86
85
|
|
87
|
-
|
88
|
-
return table
|
86
|
+
return [table_header] + table_rows
|
89
87
|
|
90
88
|
|
91
|
-
def _create_pandas_dataframe(class_wise):
|
89
|
+
def _create_pandas_dataframe(class_wise: dict[str, dict[str, int]]) -> list[dict[str, str | int]]:
|
92
90
|
"""Create data for pandas dataframe"""
|
93
91
|
data = []
|
94
92
|
for label, metrics_dict in class_wise.items():
|
95
|
-
row = {"Class": label}
|
93
|
+
row: dict[str, str | int] = {"Class": label}
|
96
94
|
total = sum(metrics_dict.values())
|
97
95
|
row.update(metrics_dict) # Add metric counts
|
98
96
|
row["Total"] = total
|
@@ -121,8 +119,7 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
|
|
121
119
|
def __len__(self) -> int:
|
122
120
|
if isinstance(self.issues, dict):
|
123
121
|
return len(self.issues)
|
124
|
-
|
125
|
-
return sum(len(d) for d in self.issues)
|
122
|
+
return sum(len(d) for d in self.issues)
|
126
123
|
|
127
124
|
def to_table(self, labelstats: LabelStatsOutput) -> str:
|
128
125
|
"""
|
@@ -168,8 +165,6 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
|
|
168
165
|
-----
|
169
166
|
This method requires `pandas <https://pandas.pydata.org/>`_ to be installed.
|
170
167
|
"""
|
171
|
-
import pandas as pd
|
172
|
-
|
173
168
|
if isinstance(self.issues, dict):
|
174
169
|
_, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
|
175
170
|
data = _create_pandas_dataframe(classwise)
|
dataeval/outputs/_stats.py
CHANGED
@@ -2,24 +2,27 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
import contextlib
|
6
5
|
from dataclasses import dataclass
|
7
|
-
from typing import Any, Iterable, NamedTuple, Optional, Union
|
6
|
+
from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional, Sequence, Union
|
8
7
|
|
9
8
|
import numpy as np
|
9
|
+
import pandas as pd
|
10
10
|
from numpy.typing import NDArray
|
11
11
|
from typing_extensions import TypeAlias
|
12
12
|
|
13
|
-
with contextlib.suppress(ImportError):
|
14
|
-
import pandas as pd
|
15
|
-
|
16
13
|
from dataeval.outputs._base import Output
|
17
14
|
from dataeval.utils._plot import channel_histogram_plot, histogram_plot
|
18
15
|
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from matplotlib.figure import Figure
|
18
|
+
|
19
19
|
OptionalRange: TypeAlias = Optional[Union[int, Iterable[int]]]
|
20
20
|
|
21
21
|
SOURCE_INDEX = "source_index"
|
22
|
-
|
22
|
+
OBJECT_COUNT = "object_count"
|
23
|
+
IMAGE_COUNT = "image_count"
|
24
|
+
|
25
|
+
BASE_ATTRS = (SOURCE_INDEX, OBJECT_COUNT, IMAGE_COUNT)
|
23
26
|
|
24
27
|
|
25
28
|
class SourceIndex(NamedTuple):
|
@@ -54,17 +57,24 @@ class BaseStatsOutput(Output):
|
|
54
57
|
----------
|
55
58
|
source_index : List[SourceIndex]
|
56
59
|
Mapping from statistic to source image, box and channel index
|
57
|
-
|
60
|
+
object_count : NDArray[np.uint16]
|
61
|
+
The number of detected objects in each image
|
58
62
|
"""
|
59
63
|
|
60
64
|
source_index: list[SourceIndex]
|
61
|
-
|
65
|
+
object_count: NDArray[np.uint16]
|
66
|
+
image_count: int
|
62
67
|
|
63
68
|
def __post_init__(self) -> None:
|
64
|
-
|
65
|
-
|
66
|
-
if
|
67
|
-
raise ValueError(f"All values must have the same length as source_index. Bad values: {str(
|
69
|
+
si_length = len(self.source_index)
|
70
|
+
mismatch = {k: len(v) for k, v in self.data().items() if k not in BASE_ATTRS and len(v) != si_length}
|
71
|
+
if mismatch:
|
72
|
+
raise ValueError(f"All values must have the same length as source_index. Bad values: {str(mismatch)}.")
|
73
|
+
oc_length = len(self.object_count)
|
74
|
+
if oc_length != self.image_count:
|
75
|
+
raise ValueError(
|
76
|
+
f"Total object counts per image does not match image count. {oc_length} != {self.image_count}."
|
77
|
+
)
|
68
78
|
|
69
79
|
def get_channel_mask(
|
70
80
|
self,
|
@@ -126,21 +136,64 @@ class BaseStatsOutput(Output):
|
|
126
136
|
|
127
137
|
return max_channels, ch_mask
|
128
138
|
|
129
|
-
def factors(
|
139
|
+
def factors(
|
140
|
+
self,
|
141
|
+
filter: str | Sequence[str] | None = None, # noqa: A002
|
142
|
+
exclude_constant: bool = False,
|
143
|
+
) -> dict[str, NDArray[Any]]:
|
144
|
+
"""
|
145
|
+
Returns all 1-dimensional data as a dictionary of numpy arrays.
|
146
|
+
|
147
|
+
Parameters
|
148
|
+
----------
|
149
|
+
filter : str, Sequence[str] or None, default None:
|
150
|
+
If provided, only returns keys that match the filter.
|
151
|
+
exclude_constant : bool, default False
|
152
|
+
If True, exclude arrays that contain only a single unique value.
|
153
|
+
|
154
|
+
Returns
|
155
|
+
-------
|
156
|
+
dict[str, NDArray[Any]]
|
157
|
+
"""
|
158
|
+
filter_ = [filter] if isinstance(filter, str) else filter
|
130
159
|
return {
|
131
160
|
k: v
|
132
161
|
for k, v in self.data().items()
|
133
|
-
if k not in
|
162
|
+
if k not in BASE_ATTRS
|
163
|
+
and (filter_ is None or k in filter_)
|
164
|
+
and isinstance(v, np.ndarray)
|
165
|
+
and v.ndim == 1
|
166
|
+
and (not exclude_constant or len(np.unique(v)) > 1)
|
134
167
|
}
|
135
168
|
|
136
169
|
def plot(
|
137
170
|
self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
|
138
|
-
) ->
|
171
|
+
) -> Figure:
|
172
|
+
"""
|
173
|
+
Plots the statistics as a set of histograms.
|
174
|
+
|
175
|
+
Parameters
|
176
|
+
----------
|
177
|
+
log : bool
|
178
|
+
If True, plots the histograms on a logarithmic scale.
|
179
|
+
channel_limit : int or None
|
180
|
+
The maximum number of channels to plot. If None, all channels are plotted.
|
181
|
+
channel_index : int, Iterable[int] or None
|
182
|
+
The index or indices of the channels to plot. If None, all channels are plotted.
|
183
|
+
|
184
|
+
Returns
|
185
|
+
-------
|
186
|
+
matplotlib.Figure
|
187
|
+
"""
|
188
|
+
from matplotlib.figure import Figure
|
189
|
+
|
139
190
|
max_channels, ch_mask = self._get_channels(channel_limit, channel_index)
|
191
|
+
factors = self.factors(exclude_constant=True)
|
192
|
+
if not factors:
|
193
|
+
return Figure()
|
140
194
|
if max_channels == 1:
|
141
|
-
histogram_plot(
|
142
|
-
|
143
|
-
channel_histogram_plot(self.factors(), log, max_channels, ch_mask)
|
195
|
+
return histogram_plot(factors, log)
|
196
|
+
return channel_histogram_plot(factors, log, max_channels, ch_mask)
|
144
197
|
|
145
198
|
|
146
199
|
@dataclass(frozen=True)
|
@@ -150,9 +203,9 @@ class DimensionStatsOutput(BaseStatsOutput):
|
|
150
203
|
|
151
204
|
Attributes
|
152
205
|
----------
|
153
|
-
|
206
|
+
offset_x : NDArray[np.int32]
|
154
207
|
Offsets from the left edge of images in pixels
|
155
|
-
|
208
|
+
offset_y : NDArray[np.int32]
|
156
209
|
Offsets from the top edge of images in pixels
|
157
210
|
width : NDArray[np.uint32]
|
158
211
|
Width of the images in pixels
|
@@ -163,25 +216,28 @@ class DimensionStatsOutput(BaseStatsOutput):
|
|
163
216
|
size : NDArray[np.uint32]
|
164
217
|
Size of the images in pixels
|
165
218
|
aspect_ratio : NDArray[np.float16]
|
166
|
-
:term:`
|
219
|
+
:term:`Aspect Ratio<Aspect Ratio>` of the images (width/height)
|
167
220
|
depth : NDArray[np.uint8]
|
168
221
|
Color depth of the images in bits
|
169
|
-
center : NDArray[np.
|
222
|
+
center : NDArray[np.uint32]
|
170
223
|
Offset from center in [x,y] coordinates of the images in pixels
|
171
|
-
|
224
|
+
distance_center : NDArray[np.float32]
|
172
225
|
Distance in pixels from center
|
226
|
+
distance_edge : NDArray[np.uint32]
|
227
|
+
Distance in pixels from nearest edge
|
173
228
|
"""
|
174
229
|
|
175
|
-
|
176
|
-
|
230
|
+
offset_x: NDArray[np.int32]
|
231
|
+
offset_y: NDArray[np.int32]
|
177
232
|
width: NDArray[np.uint32]
|
178
233
|
height: NDArray[np.uint32]
|
179
234
|
channels: NDArray[np.uint8]
|
180
235
|
size: NDArray[np.uint32]
|
181
236
|
aspect_ratio: NDArray[np.float16]
|
182
237
|
depth: NDArray[np.uint8]
|
183
|
-
center: NDArray[np.
|
184
|
-
|
238
|
+
center: NDArray[np.int32]
|
239
|
+
distance_center: NDArray[np.float32]
|
240
|
+
distance_edge: NDArray[np.uint32]
|
185
241
|
|
186
242
|
|
187
243
|
@dataclass(frozen=True)
|
@@ -281,8 +337,6 @@ class LabelStatsOutput(Output):
|
|
281
337
|
-------
|
282
338
|
pd.DataFrame
|
283
339
|
"""
|
284
|
-
import pandas as pd
|
285
|
-
|
286
340
|
total_count = []
|
287
341
|
image_count = []
|
288
342
|
for cls in range(len(self.class_names)):
|
dataeval/outputs/_workflows.py
CHANGED
@@ -154,10 +154,10 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
|
|
154
154
|
Array of parameters to recreate line of best fit
|
155
155
|
"""
|
156
156
|
|
157
|
-
def is_valid(f_new, x_new, f_old, x_old):
|
157
|
+
def is_valid(f_new, x_new, f_old, x_old) -> bool: # noqa: ANN001
|
158
158
|
return f_new != np.nan
|
159
159
|
|
160
|
-
def f(x):
|
160
|
+
def f(x) -> float: # noqa: ANN001
|
161
161
|
try:
|
162
162
|
return np.sum(np.square(p_i - f_out(n_i, x)))
|
163
163
|
except RuntimeWarning:
|
dataeval/utils/_array.py
CHANGED
@@ -23,7 +23,7 @@ T = TypeVar("T", ArrayLike, np.ndarray, torch.Tensor)
|
|
23
23
|
_np_dtype = TypeVar("_np_dtype", bound=np.generic)
|
24
24
|
|
25
25
|
|
26
|
-
def _try_import(module_name) -> ModuleType | None:
|
26
|
+
def _try_import(module_name: str) -> ModuleType | None:
|
27
27
|
if module_name in _MODULE_CACHE:
|
28
28
|
return _MODULE_CACHE[module_name]
|
29
29
|
|
@@ -148,8 +148,7 @@ def ensure_embeddings(
|
|
148
148
|
|
149
149
|
if dtype is None:
|
150
150
|
return embeddings
|
151
|
-
|
152
|
-
return arr
|
151
|
+
return arr
|
153
152
|
|
154
153
|
|
155
154
|
@overload
|
@@ -174,10 +173,9 @@ def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
|
|
174
173
|
if isinstance(array, np.ndarray):
|
175
174
|
nparr = as_numpy(array)
|
176
175
|
return nparr.reshape((nparr.shape[0], -1))
|
177
|
-
|
176
|
+
if isinstance(array, torch.Tensor):
|
178
177
|
return torch.flatten(array, start_dim=1)
|
179
|
-
|
180
|
-
raise TypeError(f"Unsupported array type {type(array)}.")
|
178
|
+
raise TypeError(f"Unsupported array type {type(array)}.")
|
181
179
|
|
182
180
|
|
183
181
|
_TArray = TypeVar("_TArray", bound=Array)
|
@@ -199,7 +197,6 @@ def channels_first_to_last(array: _TArray) -> _TArray:
|
|
199
197
|
"""
|
200
198
|
if isinstance(array, np.ndarray):
|
201
199
|
return np.transpose(array, (1, 2, 0))
|
202
|
-
|
200
|
+
if isinstance(array, torch.Tensor):
|
203
201
|
return torch.permute(array, (1, 2, 0))
|
204
|
-
|
205
|
-
raise TypeError(f"Unsupported array type {type(array)}.")
|
202
|
+
raise TypeError(f"Unsupported array type {type(array)}.")
|
dataeval/utils/_bin.py
CHANGED
dataeval/utils/_clusterer.py
CHANGED
@@ -4,6 +4,7 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import warnings
|
6
6
|
from dataclasses import dataclass
|
7
|
+
from typing import Any
|
7
8
|
|
8
9
|
import numba
|
9
10
|
import numpy as np
|
@@ -30,7 +31,9 @@ from dataeval.utils._fast_mst import calculate_neighbor_distances, minimum_spann
|
|
30
31
|
|
31
32
|
|
32
33
|
@numba.njit(parallel=True, locals={"i": numba.types.int32})
|
33
|
-
def compare_links_to_cluster_std(
|
34
|
+
def compare_links_to_cluster_std(
|
35
|
+
mst: NDArray[np.float32], clusters: NDArray[np.intp]
|
36
|
+
) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
|
34
37
|
cluster_ids = np.unique(clusters)
|
35
38
|
cluster_grouping = np.full(mst.shape[0], -1, dtype=np.int16)
|
36
39
|
|
@@ -79,7 +82,7 @@ def cluster(data: ArrayLike) -> ClusterData:
|
|
79
82
|
cluster_selection_epsilon = 0.0
|
80
83
|
# cluster_selection_method = "eom"
|
81
84
|
|
82
|
-
x = flatten(to_numpy(data))
|
85
|
+
x: NDArray[Any] = flatten(to_numpy(data))
|
83
86
|
samples, features = x.shape # Due to flatten(), we know shape has a length of 2
|
84
87
|
if samples < 2:
|
85
88
|
raise ValueError(f"Data should have at least 2 samples; got {samples}")
|
@@ -125,9 +128,9 @@ def cluster(data: ArrayLike) -> ClusterData:
|
|
125
128
|
return ClusterData(clusters, mst, linkage_tree, condensed_tree, membership_strengths, kneighbors, kdistances)
|
126
129
|
|
127
130
|
|
128
|
-
def sorted_union_find(index_groups):
|
131
|
+
def sorted_union_find(index_groups: NDArray[np.int32]) -> list[list[np.int32]]:
|
129
132
|
"""Merges and sorts groups of indices that share any common index"""
|
130
|
-
groups = [[np.int32(x) for x in range(0)] for y in range(0)]
|
133
|
+
groups: list[list[np.int32]] = [[np.int32(x) for x in range(0)] for y in range(0)]
|
131
134
|
uniques, inverse = np.unique(index_groups, return_inverse=True)
|
132
135
|
inverse = inverse.flatten()
|
133
136
|
disjoint_set = ds_rank_create(uniques.size)
|
dataeval/utils/_fast_mst.py
CHANGED
@@ -6,9 +6,11 @@
|
|
6
6
|
__all__ = []
|
7
7
|
|
8
8
|
import warnings
|
9
|
+
from typing import Any
|
9
10
|
|
10
11
|
import numba
|
11
12
|
import numpy as np
|
13
|
+
from numpy.typing import NDArray
|
12
14
|
from sklearn.neighbors import NearestNeighbors
|
13
15
|
|
14
16
|
with warnings.catch_warnings():
|
@@ -17,24 +19,26 @@ with warnings.catch_warnings():
|
|
17
19
|
|
18
20
|
|
19
21
|
@numba.njit()
|
20
|
-
def _ds_union_by_rank(disjoint_set, point, nbr):
|
22
|
+
def _ds_union_by_rank(disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]], point: int, nbr: int) -> int:
|
21
23
|
y = ds_find(disjoint_set, point)
|
22
24
|
x = ds_find(disjoint_set, nbr)
|
23
25
|
|
24
26
|
if x == y:
|
25
27
|
return 0
|
26
28
|
|
27
|
-
if disjoint_set
|
29
|
+
if disjoint_set[1][x] < disjoint_set[1][y]:
|
28
30
|
x, y = y, x
|
29
31
|
|
30
|
-
disjoint_set
|
31
|
-
if disjoint_set
|
32
|
-
disjoint_set
|
32
|
+
disjoint_set[0][y] = x
|
33
|
+
if disjoint_set[1][x] == disjoint_set[1][y]:
|
34
|
+
disjoint_set[1][x] += 1
|
33
35
|
return 1
|
34
36
|
|
35
37
|
|
36
38
|
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32, "dist": numba.types.float32})
|
37
|
-
def _init_tree(
|
39
|
+
def _init_tree(
|
40
|
+
n_neighbors: NDArray[np.intp], n_distance: NDArray[np.float32]
|
41
|
+
) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
|
38
42
|
# Initial graph to hold tree connections
|
39
43
|
tree = np.zeros((n_neighbors.size - 1, 3), dtype=np.float32)
|
40
44
|
disjoint_set = ds_rank_create(n_neighbors.size)
|
@@ -56,7 +60,13 @@ def _init_tree(n_neighbors, n_distance):
|
|
56
60
|
|
57
61
|
|
58
62
|
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32})
|
59
|
-
def _update_tree_by_distance(
|
63
|
+
def _update_tree_by_distance(
|
64
|
+
tree: NDArray[np.float32],
|
65
|
+
int_tree: int,
|
66
|
+
disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]],
|
67
|
+
n_neighbors: NDArray[np.uint32],
|
68
|
+
n_distance: NDArray[np.float32],
|
69
|
+
) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
|
60
70
|
cluster_points = np.empty(n_neighbors.size, dtype=np.uint32)
|
61
71
|
sort_dist = np.argsort(n_distance)
|
62
72
|
dist_sorted = n_distance[sort_dist]
|
@@ -80,9 +90,9 @@ def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distan
|
|
80
90
|
|
81
91
|
|
82
92
|
@numba.njit(locals={"i": numba.types.uint32})
|
83
|
-
def _cluster_edges(tracker, last_idx, cluster_distances):
|
93
|
+
def _cluster_edges(tracker: NDArray[Any], last_idx: int, cluster_distances: NDArray[Any]) -> list[NDArray[np.intp]]:
|
84
94
|
cluster_ids = np.unique(tracker)
|
85
|
-
edge_points = []
|
95
|
+
edge_points: list[NDArray[np.intp]] = []
|
86
96
|
for idx in range(cluster_ids.size):
|
87
97
|
cluster_points = np.nonzero(tracker == cluster_ids[idx])[0]
|
88
98
|
cluster_size = cluster_points.size
|
@@ -102,14 +112,16 @@ def _cluster_edges(tracker, last_idx, cluster_distances):
|
|
102
112
|
return edge_points
|
103
113
|
|
104
114
|
|
105
|
-
def _compute_nn(dataA, dataB, k):
|
115
|
+
def _compute_nn(dataA: NDArray[Any], dataB: NDArray[Any], k: int) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
|
106
116
|
distances, neighbors = NearestNeighbors(n_neighbors=k + 1, algorithm="brute").fit(dataA).kneighbors(dataB)
|
107
117
|
neighbors = np.array(neighbors[:, 1 : k + 1], dtype=np.int32)
|
108
118
|
distances = np.array(distances[:, 1 : k + 1], dtype=np.float32)
|
109
119
|
return neighbors, distances
|
110
120
|
|
111
121
|
|
112
|
-
def _calculate_cluster_neighbors(
|
122
|
+
def _calculate_cluster_neighbors(
|
123
|
+
data: NDArray[Any], groups: list[NDArray[np.intp]], point_array: NDArray[Any]
|
124
|
+
) -> tuple[NDArray[np.uint32], NDArray[np.float32]]:
|
113
125
|
"""Rerun nearest neighbor based on clusters"""
|
114
126
|
cluster_neighbors = np.zeros(point_array.size, dtype=np.uint32)
|
115
127
|
cluster_nbr_distances = np.full(point_array.size, np.inf, dtype=np.float32)
|
@@ -126,7 +138,9 @@ def _calculate_cluster_neighbors(data, groups, point_array):
|
|
126
138
|
return cluster_neighbors, cluster_nbr_distances
|
127
139
|
|
128
140
|
|
129
|
-
def minimum_spanning_tree(
|
141
|
+
def minimum_spanning_tree(
|
142
|
+
data: NDArray[Any], neighbors: NDArray[np.int32], distances: NDArray[np.float32]
|
143
|
+
) -> NDArray[np.float32]:
|
130
144
|
# Transpose arrays to get number of samples along a row
|
131
145
|
k_neighbors = neighbors.T.astype(np.uint32).copy()
|
132
146
|
k_distances = distances.T.astype(np.float32).copy()
|
@@ -168,7 +182,7 @@ def minimum_spanning_tree(data, neighbors, distances):
|
|
168
182
|
return tree
|
169
183
|
|
170
184
|
|
171
|
-
def calculate_neighbor_distances(data: np.ndarray, k: int = 10):
|
185
|
+
def calculate_neighbor_distances(data: np.ndarray, k: int = 10) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
|
172
186
|
# Have the potential to add in other distance calculations - supported calculations:
|
173
187
|
# https://github.com/lmcinnes/pynndescent/blob/master/pynndescent/pynndescent_.py#L524
|
174
188
|
try:
|
dataeval/utils/_image.py
CHANGED
@@ -12,6 +12,9 @@ from scipy.signal import convolve2d
|
|
12
12
|
EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
|
13
13
|
BIT_DEPTH = (1, 8, 12, 16, 32)
|
14
14
|
|
15
|
+
Box = tuple[int, int, int, int]
|
16
|
+
"""Bounding box as tuple of integers in x0, y0, x1, y1 format."""
|
17
|
+
|
15
18
|
|
16
19
|
@dataclass
|
17
20
|
class BitDepth:
|
@@ -25,12 +28,11 @@ def get_bitdepth(image: NDArray[Any]) -> BitDepth:
|
|
25
28
|
Approximates the bit depth of the image using the
|
26
29
|
min and max pixel values.
|
27
30
|
"""
|
28
|
-
pmin, pmax = np.
|
31
|
+
pmin, pmax = np.nanmin(image), np.nanmax(image)
|
29
32
|
if pmin < 0:
|
30
33
|
return BitDepth(0, pmin, pmax)
|
31
|
-
|
32
|
-
|
33
|
-
return BitDepth(depth, 0, 2**depth - 1)
|
34
|
+
depth = ([x for x in BIT_DEPTH if 2**x > pmax] or [max(BIT_DEPTH)])[0]
|
35
|
+
return BitDepth(depth, 0, 2**depth - 1)
|
34
36
|
|
35
37
|
|
36
38
|
def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
|
@@ -40,9 +42,8 @@ def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
|
|
40
42
|
bitdepth = get_bitdepth(image)
|
41
43
|
if bitdepth.depth == depth:
|
42
44
|
return image
|
43
|
-
|
44
|
-
|
45
|
-
return normalized * (2**depth - 1)
|
45
|
+
normalized = (image + bitdepth.pmin) / (bitdepth.pmax - bitdepth.pmin)
|
46
|
+
return normalized * (2**depth - 1)
|
46
47
|
|
47
48
|
|
48
49
|
def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
|
@@ -52,13 +53,12 @@ def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
|
|
52
53
|
ndim = image.ndim
|
53
54
|
if ndim == 2:
|
54
55
|
return np.expand_dims(image, axis=0)
|
55
|
-
|
56
|
+
if ndim == 3:
|
56
57
|
return image
|
57
|
-
|
58
|
+
if ndim > 3:
|
58
59
|
# Slice all but the last 3 dimensions
|
59
60
|
return image[(0,) * (ndim - 3)]
|
60
|
-
|
61
|
-
raise ValueError("Images must have 2 or more dimensions.")
|
61
|
+
raise ValueError("Images must have 2 or more dimensions.")
|
62
62
|
|
63
63
|
|
64
64
|
def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
|
@@ -71,3 +71,57 @@ def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
|
|
71
71
|
edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm") + offset
|
72
72
|
np.clip(edges, 0, 255, edges)
|
73
73
|
return edges
|
74
|
+
|
75
|
+
|
76
|
+
def clip_box(image: NDArray[Any], box: Box) -> Box:
|
77
|
+
"""
|
78
|
+
Clip the box to inside the provided image dimensions.
|
79
|
+
"""
|
80
|
+
x0, y0, x1, y1 = box
|
81
|
+
h, w = image.shape[-2:]
|
82
|
+
|
83
|
+
return max(0, x0), max(0, y0), min(w, x1), min(h, y1)
|
84
|
+
|
85
|
+
|
86
|
+
def is_valid_box(box: Box) -> bool:
|
87
|
+
"""
|
88
|
+
Check if the box dimensions provided are a valid image.
|
89
|
+
"""
|
90
|
+
return box[2] > box[0] and box[3] > box[1]
|
91
|
+
|
92
|
+
|
93
|
+
def clip_and_pad(image: NDArray[Any], box: Box) -> NDArray[Any]:
|
94
|
+
"""
|
95
|
+
Extract a region from an image based on a bounding box, clipping to image boundaries
|
96
|
+
and padding out-of-bounds areas with np.nan.
|
97
|
+
|
98
|
+
Parameters:
|
99
|
+
-----------
|
100
|
+
image : NDArray[Any]
|
101
|
+
Input image array in format C, H, W (channels first)
|
102
|
+
box : Box
|
103
|
+
Bounding box coordinates as (x0, y0, x1, y1) where (x0, y0) is top-left and (x1, y1) is bottom-right
|
104
|
+
|
105
|
+
Returns:
|
106
|
+
--------
|
107
|
+
NDArray[Any]
|
108
|
+
The extracted region with out-of-bounds areas padded with np.nan
|
109
|
+
"""
|
110
|
+
|
111
|
+
# Create output array filled with NaN with a minimum size of 1x1
|
112
|
+
bw, bh = max(1, box[2] - box[0]), max(1, box[3] - box[1])
|
113
|
+
|
114
|
+
output = np.full((image.shape[-3] if image.ndim > 2 else 1, bh, bw), np.nan)
|
115
|
+
|
116
|
+
# Calculate source box
|
117
|
+
sbox = clip_box(image, box)
|
118
|
+
|
119
|
+
# Calculate destination box
|
120
|
+
x0, y0 = sbox[0] - box[0], sbox[1] - box[1]
|
121
|
+
x1, y1 = x0 + (sbox[2] - sbox[0]), y0 + (sbox[3] - sbox[1])
|
122
|
+
|
123
|
+
# Copy the source if valid from the image to the output
|
124
|
+
if is_valid_box(sbox):
|
125
|
+
output[:, y0:y1, x0:x1] = image[:, sbox[1] : sbox[3], sbox[0] : sbox[2]]
|
126
|
+
|
127
|
+
return output
|