dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/config.py +68 -11
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +8 -64
- dataeval/detectors/drift/_mmd.py +12 -38
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +6 -5
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -2
- dataeval/detectors/linters/duplicates.py +14 -46
- dataeval/detectors/linters/outliers.py +25 -159
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +6 -5
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/metadata_ood_mi.py +4 -6
- dataeval/detectors/ood/mixin.py +3 -4
- dataeval/detectors/ood/vae.py +3 -2
- dataeval/metadata/__init__.py +2 -1
- dataeval/metadata/_distance.py +134 -0
- dataeval/metadata/_ood.py +30 -49
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/_balance.py +17 -149
- dataeval/metrics/bias/_coverage.py +4 -106
- dataeval/metrics/bias/_diversity.py +12 -107
- dataeval/metrics/bias/_parity.py +7 -71
- dataeval/metrics/estimators/__init__.py +5 -4
- dataeval/metrics/estimators/_ber.py +2 -20
- dataeval/metrics/estimators/_clusterer.py +1 -61
- dataeval/metrics/estimators/_divergence.py +2 -19
- dataeval/metrics/estimators/_uap.py +2 -16
- dataeval/metrics/stats/__init__.py +15 -12
- dataeval/metrics/stats/_base.py +41 -128
- dataeval/metrics/stats/_boxratiostats.py +13 -13
- dataeval/metrics/stats/_dimensionstats.py +17 -58
- dataeval/metrics/stats/_hashstats.py +19 -35
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +42 -121
- dataeval/metrics/stats/_pixelstats.py +19 -51
- dataeval/metrics/stats/_visualstats.py +19 -51
- dataeval/outputs/__init__.py +57 -0
- dataeval/outputs/_base.py +182 -0
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +186 -0
- dataeval/outputs/_metadata.py +54 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +393 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +187 -7
- dataeval/utils/_method.py +1 -5
- dataeval/utils/_plot.py +2 -2
- dataeval/utils/data/__init__.py +5 -1
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +12 -14
- dataeval/utils/data/_images.py +30 -27
- dataeval/utils/data/_metadata.py +28 -11
- dataeval/utils/data/_selection.py +25 -22
- dataeval/utils/data/_split.py +5 -29
- dataeval/utils/data/_targets.py +14 -2
- dataeval/utils/data/datasets/_base.py +5 -5
- dataeval/utils/data/datasets/_cifar10.py +1 -1
- dataeval/utils/data/datasets/_milco.py +1 -1
- dataeval/utils/data/datasets/_mnist.py +1 -1
- dataeval/utils/data/datasets/_ships.py +1 -1
- dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
- dataeval/utils/data/datasets/_voc.py +1 -1
- dataeval/utils/data/selections/_classfilter.py +4 -5
- dataeval/utils/data/selections/_indices.py +2 -2
- dataeval/utils/data/selections/_limit.py +2 -2
- dataeval/utils/data/selections/_reverse.py +2 -2
- dataeval/utils/data/selections/_shuffle.py +2 -2
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +6 -342
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
- dataeval-0.82.1.dist-info/RECORD +105 -0
- dataeval/_output.py +0 -137
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/metrics/stats/_datasetstats.py +0 -198
- dataeval-0.81.0.dist-info/RECORD +0 -94
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
dataeval/metadata/_ood.py
CHANGED
@@ -7,51 +7,12 @@ import warnings
|
|
7
7
|
import numpy as np
|
8
8
|
from numpy.typing import NDArray
|
9
9
|
|
10
|
-
from dataeval.
|
10
|
+
from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
|
11
|
+
from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput
|
12
|
+
from dataeval.outputs._base import set_metadata
|
11
13
|
from dataeval.utils.data import Metadata
|
12
14
|
|
13
15
|
|
14
|
-
def _validate_keys(keys1: list[str], keys2: list[str]) -> None:
|
15
|
-
"""
|
16
|
-
Raises error when two lists are not equivalent including ordering
|
17
|
-
|
18
|
-
Parameters
|
19
|
-
----------
|
20
|
-
keys1 : list of strings
|
21
|
-
List of strings to compare
|
22
|
-
keys2 : list of strings
|
23
|
-
List of strings to compare
|
24
|
-
|
25
|
-
Raises
|
26
|
-
------
|
27
|
-
ValueError
|
28
|
-
If lists do not have the same values, value counts, or ordering
|
29
|
-
"""
|
30
|
-
|
31
|
-
if keys1 != keys2:
|
32
|
-
raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")
|
33
|
-
|
34
|
-
|
35
|
-
def _validate_factors_and_data(factors: list[str], data: NDArray) -> None:
|
36
|
-
"""
|
37
|
-
Raises error when the number of factors and number of rows do not match
|
38
|
-
|
39
|
-
Parameters
|
40
|
-
----------
|
41
|
-
factors : list of strings
|
42
|
-
List of factor names of size N
|
43
|
-
data : NDArray
|
44
|
-
Array of values with shape (M, N)
|
45
|
-
|
46
|
-
Raises
|
47
|
-
------
|
48
|
-
ValueError
|
49
|
-
If the length of factors does not equal the length of the transposed data
|
50
|
-
"""
|
51
|
-
if len(factors) != len(data.T):
|
52
|
-
raise ValueError(f"Factors and data have mismatched lengths. Got {len(factors)} and {len(data.T)}")
|
53
|
-
|
54
|
-
|
55
16
|
def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[str], list[NDArray], list[NDArray]]:
|
56
17
|
"""
|
57
18
|
Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
|
@@ -93,7 +54,7 @@ def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[
|
|
93
54
|
|
94
55
|
# Validate and attach discrete data
|
95
56
|
if metadata_1.discrete_factor_names:
|
96
|
-
|
57
|
+
_compare_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
|
97
58
|
_validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)
|
98
59
|
|
99
60
|
factor_names.extend(metadata_1.discrete_factor_names)
|
@@ -102,7 +63,7 @@ def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[
|
|
102
63
|
|
103
64
|
# Validate and attach continuous data
|
104
65
|
if metadata_1.continuous_factor_names:
|
105
|
-
|
66
|
+
_compare_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
|
106
67
|
_validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)
|
107
68
|
|
108
69
|
factor_names.extend(metadata_1.continuous_factor_names)
|
@@ -159,11 +120,12 @@ def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
|
|
159
120
|
return np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale)) # (S_t, F)
|
160
121
|
|
161
122
|
|
123
|
+
@set_metadata
|
162
124
|
def most_deviated_factors(
|
163
125
|
metadata_1: Metadata,
|
164
126
|
metadata_2: Metadata,
|
165
127
|
ood: OODOutput,
|
166
|
-
) ->
|
128
|
+
) -> MostDeviatedFactorsOutput:
|
167
129
|
"""
|
168
130
|
Determines greatest deviation in metadata features per out of distribution sample in metadata_2.
|
169
131
|
|
@@ -189,13 +151,30 @@ def most_deviated_factors(
|
|
189
151
|
and have equivalent factor names and lengths
|
190
152
|
2. The flag at index `i` in :attr:`.OODOutput.is_ood` must correspond
|
191
153
|
directly to sample `i` of `metadata_2` being out-of-distribution from `metadata_1`
|
154
|
+
|
155
|
+
Examples
|
156
|
+
--------
|
157
|
+
|
158
|
+
>>> from dataeval.detectors.ood import OODOutput
|
159
|
+
|
160
|
+
All samples are out-of-distribution
|
161
|
+
|
162
|
+
>>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
|
163
|
+
>>> most_deviated_factors(metadata1, metadata2, is_ood)
|
164
|
+
MostDeviatedFactorsOutput([('time', 2.0), ('time', 2.592), ('time', 3.51)])
|
165
|
+
|
166
|
+
If there are no out-of-distribution samples, a list is returned
|
167
|
+
|
168
|
+
>>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
|
169
|
+
>>> most_deviated_factors(metadata1, metadata2, is_ood)
|
170
|
+
MostDeviatedFactorsOutput([])
|
192
171
|
"""
|
193
172
|
|
194
173
|
ood_mask: NDArray[np.bool] = ood.is_ood
|
195
174
|
|
196
175
|
# No metadata correlated with out of distribution data
|
197
176
|
if not any(ood_mask):
|
198
|
-
return []
|
177
|
+
return MostDeviatedFactorsOutput([])
|
199
178
|
|
200
179
|
# Combines reference and test factor names and data if exists and match exactly
|
201
180
|
# shape -> (samples, factors)
|
@@ -204,6 +183,7 @@ def most_deviated_factors(
|
|
204
183
|
metadata_2=metadata_2,
|
205
184
|
)
|
206
185
|
|
186
|
+
# Stack discrete and continuous factors as separate factors. Must have equal sample counts
|
207
187
|
metadata_ref = np.hstack(md_1) if md_1 else np.array([])
|
208
188
|
metadata_tst = np.hstack(md_2) if md_2 else np.array([])
|
209
189
|
|
@@ -212,7 +192,7 @@ def most_deviated_factors(
|
|
212
192
|
f"At least 3 reference metadata samples are needed, got {len(metadata_ref)}",
|
213
193
|
UserWarning,
|
214
194
|
)
|
215
|
-
return []
|
195
|
+
return MostDeviatedFactorsOutput([])
|
216
196
|
|
217
197
|
if len(metadata_tst) != len(ood_mask):
|
218
198
|
raise ValueError(
|
@@ -226,7 +206,7 @@ def most_deviated_factors(
|
|
226
206
|
deviations = _calc_median_deviations(metadata_ref, metadata_tst)
|
227
207
|
|
228
208
|
# Get most impactful factor deviation of each sample for ood samples only
|
229
|
-
deviation = np.max(deviations, axis=1)[ood_mask]
|
209
|
+
deviation = np.max(deviations, axis=1)[ood_mask].astype(np.float16)
|
230
210
|
|
231
211
|
# Get indices of most impactful factors for ood samples only
|
232
212
|
max_factors = np.argmax(deviations, axis=1)[ood_mask]
|
@@ -235,4 +215,5 @@ def most_deviated_factors(
|
|
235
215
|
most_ood_factors = np.array(factor_names)[max_factors].tolist()
|
236
216
|
|
237
217
|
# List of tuples matching the factor name with its deviation
|
238
|
-
|
218
|
+
|
219
|
+
return MostDeviatedFactorsOutput([(factor, dev) for factor, dev in zip(most_ood_factors, deviation)])
|
@@ -0,0 +1,44 @@
|
|
1
|
+
__all__ = []
|
2
|
+
|
3
|
+
from numpy.typing import NDArray
|
4
|
+
|
5
|
+
|
6
|
+
def _compare_keys(keys1: list[str], keys2: list[str]) -> None:
|
7
|
+
"""
|
8
|
+
Raises error when two lists are not equivalent including ordering
|
9
|
+
|
10
|
+
Parameters
|
11
|
+
----------
|
12
|
+
keys1 : list of strings
|
13
|
+
List of strings to compare
|
14
|
+
keys2 : list of strings
|
15
|
+
List of strings to compare
|
16
|
+
|
17
|
+
Raises
|
18
|
+
------
|
19
|
+
ValueError
|
20
|
+
If lists do not have the same values, value counts, or ordering
|
21
|
+
"""
|
22
|
+
|
23
|
+
if keys1 != keys2:
|
24
|
+
raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")
|
25
|
+
|
26
|
+
|
27
|
+
def _validate_factors_and_data(factors: list[str], data: NDArray) -> None:
|
28
|
+
"""
|
29
|
+
Raises error when the number of factors and number of rows do not match
|
30
|
+
|
31
|
+
Parameters
|
32
|
+
----------
|
33
|
+
factors : list of strings
|
34
|
+
List of factor names of size N
|
35
|
+
data : NDArray
|
36
|
+
Array of values with shape (M, N)
|
37
|
+
|
38
|
+
Raises
|
39
|
+
------
|
40
|
+
ValueError
|
41
|
+
If the length of factors does not equal the length of the transposed data
|
42
|
+
"""
|
43
|
+
if len(factors) != len(data.T):
|
44
|
+
raise ValueError(f"Factors and data have mismatched lengths. Got {len(factors)} and {len(data.T)}")
|
@@ -16,7 +16,8 @@ __all__ = [
|
|
16
16
|
"parity",
|
17
17
|
]
|
18
18
|
|
19
|
-
from dataeval.metrics.bias._balance import
|
20
|
-
from dataeval.metrics.bias._coverage import
|
21
|
-
from dataeval.metrics.bias._diversity import
|
22
|
-
from dataeval.metrics.bias._parity import
|
19
|
+
from dataeval.metrics.bias._balance import balance
|
20
|
+
from dataeval.metrics.bias._coverage import coverage
|
21
|
+
from dataeval.metrics.bias._diversity import diversity
|
22
|
+
from dataeval.metrics.bias._parity import label_parity, parity
|
23
|
+
from dataeval.outputs._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
|
@@ -2,150 +2,18 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
import contextlib
|
6
5
|
import warnings
|
7
|
-
from dataclasses import dataclass
|
8
|
-
from typing import Any, Literal, overload
|
9
6
|
|
10
7
|
import numpy as np
|
11
8
|
import scipy as sp
|
12
|
-
from numpy.typing import NDArray
|
13
9
|
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
|
14
10
|
|
15
|
-
from dataeval.
|
11
|
+
from dataeval.config import get_seed
|
12
|
+
from dataeval.outputs import BalanceOutput
|
13
|
+
from dataeval.outputs._base import set_metadata
|
16
14
|
from dataeval.utils._bin import get_counts
|
17
|
-
from dataeval.utils._plot import heatmap
|
18
15
|
from dataeval.utils.data import Metadata
|
19
16
|
|
20
|
-
with contextlib.suppress(ImportError):
|
21
|
-
from matplotlib.figure import Figure
|
22
|
-
|
23
|
-
|
24
|
-
@dataclass(frozen=True)
|
25
|
-
class BalanceOutput(Output):
|
26
|
-
"""
|
27
|
-
Output class for :func:`.balance` :term:`bias<Bias>` metric.
|
28
|
-
|
29
|
-
Attributes
|
30
|
-
----------
|
31
|
-
balance : NDArray[np.float64]
|
32
|
-
Estimate of mutual information between metadata factors and class label
|
33
|
-
factors : NDArray[np.float64]
|
34
|
-
Estimate of inter/intra-factor mutual information
|
35
|
-
classwise : NDArray[np.float64]
|
36
|
-
Estimate of mutual information between metadata factors and individual class labels
|
37
|
-
factor_names : list[str]
|
38
|
-
Names of each metadata factor
|
39
|
-
class_names : list[str]
|
40
|
-
List of the class labels present in the dataset
|
41
|
-
"""
|
42
|
-
|
43
|
-
balance: NDArray[np.float64]
|
44
|
-
factors: NDArray[np.float64]
|
45
|
-
classwise: NDArray[np.float64]
|
46
|
-
factor_names: list[str]
|
47
|
-
class_names: list[str]
|
48
|
-
|
49
|
-
@overload
|
50
|
-
def _by_factor_type(
|
51
|
-
self,
|
52
|
-
attr: Literal["factor_names"],
|
53
|
-
factor_type: Literal["discrete", "continuous", "both"],
|
54
|
-
) -> list[str]: ...
|
55
|
-
|
56
|
-
@overload
|
57
|
-
def _by_factor_type(
|
58
|
-
self,
|
59
|
-
attr: Literal["balance", "factors", "classwise"],
|
60
|
-
factor_type: Literal["discrete", "continuous", "both"],
|
61
|
-
) -> NDArray[np.float64]: ...
|
62
|
-
|
63
|
-
def _by_factor_type(
|
64
|
-
self,
|
65
|
-
attr: Literal["balance", "factors", "classwise", "factor_names"],
|
66
|
-
factor_type: Literal["discrete", "continuous", "both"],
|
67
|
-
) -> NDArray[np.float64] | list[str]:
|
68
|
-
# if not filtering by factor_type then just return the requested attribute without mask
|
69
|
-
if factor_type == "both":
|
70
|
-
return getattr(self, attr)
|
71
|
-
|
72
|
-
# create the mask for the selected factor_type
|
73
|
-
mask_lambda = (
|
74
|
-
(lambda x: "-continuous" not in x) if factor_type == "discrete" else (lambda x: "-discrete" not in x)
|
75
|
-
)
|
76
|
-
|
77
|
-
# return the masked attribute
|
78
|
-
if attr == "factor_names":
|
79
|
-
return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
|
80
|
-
else:
|
81
|
-
factor_type_mask = [mask_lambda(x) for x in self.factor_names]
|
82
|
-
if attr == "factors":
|
83
|
-
return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
|
84
|
-
elif attr == "balance":
|
85
|
-
return self.balance[factor_type_mask]
|
86
|
-
elif attr == "classwise":
|
87
|
-
return self.classwise[:, factor_type_mask]
|
88
|
-
|
89
|
-
def plot(
|
90
|
-
self,
|
91
|
-
row_labels: list[Any] | NDArray[Any] | None = None,
|
92
|
-
col_labels: list[Any] | NDArray[Any] | None = None,
|
93
|
-
plot_classwise: bool = False,
|
94
|
-
factor_type: Literal["discrete", "continuous", "both"] = "discrete",
|
95
|
-
) -> Figure:
|
96
|
-
"""
|
97
|
-
Plot a heatmap of balance information
|
98
|
-
|
99
|
-
Parameters
|
100
|
-
----------
|
101
|
-
row_labels : ArrayLike or None, default None
|
102
|
-
List/Array containing the labels for rows in the histogram
|
103
|
-
col_labels : ArrayLike or None, default None
|
104
|
-
List/Array containing the labels for columns in the histogram
|
105
|
-
plot_classwise : bool, default False
|
106
|
-
Whether to plot per-class balance instead of global balance
|
107
|
-
factor_type : "discrete", "continuous", or "both", default "discrete"
|
108
|
-
Whether to plot discretized values, continuous values, or to include both
|
109
|
-
"""
|
110
|
-
if plot_classwise:
|
111
|
-
if row_labels is None:
|
112
|
-
row_labels = self.class_names
|
113
|
-
if col_labels is None:
|
114
|
-
col_labels = self._by_factor_type("factor_names", factor_type)
|
115
|
-
|
116
|
-
fig = heatmap(
|
117
|
-
self._by_factor_type("classwise", factor_type),
|
118
|
-
row_labels,
|
119
|
-
col_labels,
|
120
|
-
xlabel="Factors",
|
121
|
-
ylabel="Class",
|
122
|
-
cbarlabel="Normalized Mutual Information",
|
123
|
-
)
|
124
|
-
else:
|
125
|
-
# Combine balance and factors results
|
126
|
-
data = np.concatenate(
|
127
|
-
[
|
128
|
-
self._by_factor_type("balance", factor_type)[np.newaxis, 1:],
|
129
|
-
self._by_factor_type("factors", factor_type),
|
130
|
-
],
|
131
|
-
axis=0,
|
132
|
-
)
|
133
|
-
# Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
|
134
|
-
mask = np.triu(data + 1, k=0) < 1
|
135
|
-
# Finalize the data for the plot, last row is last factor x last factor so it gets dropped
|
136
|
-
heat_data = np.where(mask, np.nan, data)[:-1]
|
137
|
-
# Creating label array for heat map axes
|
138
|
-
heat_labels = self._by_factor_type("factor_names", factor_type)
|
139
|
-
|
140
|
-
if row_labels is None:
|
141
|
-
row_labels = heat_labels[:-1]
|
142
|
-
if col_labels is None:
|
143
|
-
col_labels = heat_labels[1:]
|
144
|
-
|
145
|
-
fig = heatmap(heat_data, row_labels, col_labels, cbarlabel="Normalized Mutual Information")
|
146
|
-
|
147
|
-
return fig
|
148
|
-
|
149
17
|
|
150
18
|
def _validate_num_neighbors(num_neighbors: int) -> int:
|
151
19
|
if not isinstance(num_neighbors, (int, float)):
|
@@ -200,25 +68,22 @@ def balance(
|
|
200
68
|
|
201
69
|
>>> bal = balance(metadata)
|
202
70
|
>>> bal.balance
|
203
|
-
array([0.
|
204
|
-
0. ])
|
71
|
+
array([1. , 0.249, 0.03 , 0.134, 0. , 0. ])
|
205
72
|
|
206
73
|
Return intra/interfactor balance (mutual information)
|
207
74
|
|
208
75
|
>>> bal.factors
|
209
|
-
array([[
|
210
|
-
[0.
|
211
|
-
[0.
|
212
|
-
[0.
|
213
|
-
[0.
|
76
|
+
array([[1. , 0.314, 0.269, 0.852, 0.367],
|
77
|
+
[0.314, 1. , 0.097, 0.158, 1.98 ],
|
78
|
+
[0.269, 0.097, 1. , 0.037, 0.015],
|
79
|
+
[0.852, 0.158, 0.037, 0.475, 0.255],
|
80
|
+
[0.367, 1.98 , 0.015, 0.255, 1.063]])
|
214
81
|
|
215
82
|
Return classwise balance (mutual information) of factors with individual class_labels
|
216
83
|
|
217
84
|
>>> bal.classwise
|
218
|
-
array([[0.
|
219
|
-
|
220
|
-
[0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0. ,
|
221
|
-
0. ]])
|
85
|
+
array([[1. , 0.249, 0.03 , 0.134, 0. , 0. ],
|
86
|
+
[1. , 0.249, 0.03 , 0.134, 0. , 0. ]])
|
222
87
|
|
223
88
|
|
224
89
|
See Also
|
@@ -227,6 +92,9 @@ def balance(
|
|
227
92
|
sklearn.feature_selection.mutual_info_regression
|
228
93
|
sklearn.metrics.mutual_info_score
|
229
94
|
"""
|
95
|
+
if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
|
96
|
+
raise ValueError("No factors found in provided metadata.")
|
97
|
+
|
230
98
|
num_neighbors = _validate_num_neighbors(num_neighbors)
|
231
99
|
|
232
100
|
num_factors = metadata.total_num_factors
|
@@ -246,7 +114,7 @@ def balance(
|
|
246
114
|
data[:, idx],
|
247
115
|
discrete_features=is_discrete, # type: ignore
|
248
116
|
n_neighbors=num_neighbors,
|
249
|
-
random_state=
|
117
|
+
random_state=get_seed(),
|
250
118
|
)
|
251
119
|
else:
|
252
120
|
mi[idx, :] = mutual_info_classif(
|
@@ -254,7 +122,7 @@ def balance(
|
|
254
122
|
data[:, idx],
|
255
123
|
discrete_features=is_discrete, # type: ignore
|
256
124
|
n_neighbors=num_neighbors,
|
257
|
-
random_state=
|
125
|
+
random_state=get_seed(),
|
258
126
|
)
|
259
127
|
|
260
128
|
# Normalization via entropy
|
@@ -283,7 +151,7 @@ def balance(
|
|
283
151
|
tgt_bin[:, idx],
|
284
152
|
discrete_features=is_discrete, # type: ignore
|
285
153
|
n_neighbors=num_neighbors,
|
286
|
-
random_state=
|
154
|
+
random_state=get_seed(),
|
287
155
|
)
|
288
156
|
|
289
157
|
# Classwise normalization via entropy
|
@@ -2,118 +2,16 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
import contextlib
|
6
5
|
import math
|
7
|
-
from
|
8
|
-
from typing import Any, Literal
|
6
|
+
from typing import Literal
|
9
7
|
|
10
8
|
import numpy as np
|
11
|
-
from numpy.typing import NDArray
|
12
9
|
from scipy.spatial.distance import pdist, squareform
|
13
10
|
|
14
|
-
from dataeval.
|
11
|
+
from dataeval.outputs import CoverageOutput
|
12
|
+
from dataeval.outputs._base import set_metadata
|
15
13
|
from dataeval.typing import ArrayLike
|
16
|
-
from dataeval.utils._array import ensure_embeddings, flatten
|
17
|
-
|
18
|
-
with contextlib.suppress(ImportError):
|
19
|
-
from matplotlib.figure import Figure
|
20
|
-
|
21
|
-
|
22
|
-
def _plot(images: NDArray[Any], num_images: int) -> Figure:
|
23
|
-
"""
|
24
|
-
Creates a single plot of all of the provided images
|
25
|
-
|
26
|
-
Parameters
|
27
|
-
----------
|
28
|
-
images : NDArray
|
29
|
-
Array containing only the desired images to plot
|
30
|
-
|
31
|
-
Returns
|
32
|
-
-------
|
33
|
-
matplotlib.figure.Figure
|
34
|
-
Plot of all provided images
|
35
|
-
"""
|
36
|
-
import matplotlib.pyplot as plt
|
37
|
-
|
38
|
-
num_images = min(num_images, len(images))
|
39
|
-
|
40
|
-
if images.ndim == 4:
|
41
|
-
images = np.moveaxis(images, 1, -1)
|
42
|
-
elif images.ndim == 3:
|
43
|
-
images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
|
44
|
-
else:
|
45
|
-
raise ValueError(
|
46
|
-
f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
|
47
|
-
)
|
48
|
-
|
49
|
-
rows = int(np.ceil(num_images / 3))
|
50
|
-
fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
|
51
|
-
|
52
|
-
if rows == 1:
|
53
|
-
for j in range(3):
|
54
|
-
if j >= len(images):
|
55
|
-
continue
|
56
|
-
axs[j].imshow(images[j])
|
57
|
-
axs[j].axis("off")
|
58
|
-
else:
|
59
|
-
for i in range(rows):
|
60
|
-
for j in range(3):
|
61
|
-
i_j = i * 3 + j
|
62
|
-
if i_j >= len(images):
|
63
|
-
continue
|
64
|
-
axs[i, j].imshow(images[i_j])
|
65
|
-
axs[i, j].axis("off")
|
66
|
-
|
67
|
-
fig.tight_layout()
|
68
|
-
return fig
|
69
|
-
|
70
|
-
|
71
|
-
@dataclass(frozen=True)
|
72
|
-
class CoverageOutput(Output):
|
73
|
-
"""
|
74
|
-
Output class for :func:`.coverage` :term:`bias<Bias>` metric.
|
75
|
-
|
76
|
-
Attributes
|
77
|
-
----------
|
78
|
-
uncovered_indices : NDArray[np.intp]
|
79
|
-
Array of uncovered indices
|
80
|
-
critical_value_radii : NDArray[np.float64]
|
81
|
-
Array of critical value radii
|
82
|
-
coverage_radius : float
|
83
|
-
Radius for :term:`coverage<Coverage>`
|
84
|
-
"""
|
85
|
-
|
86
|
-
uncovered_indices: NDArray[np.intp]
|
87
|
-
critical_value_radii: NDArray[np.float64]
|
88
|
-
coverage_radius: float
|
89
|
-
|
90
|
-
def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
|
91
|
-
"""
|
92
|
-
Plot the top k images together for visualization
|
93
|
-
|
94
|
-
Parameters
|
95
|
-
----------
|
96
|
-
images : ArrayLike
|
97
|
-
Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
|
98
|
-
top_k : int, default 6
|
99
|
-
Number of images to plot (plotting assumes groups of 3)
|
100
|
-
|
101
|
-
Returns
|
102
|
-
-------
|
103
|
-
matplotlib.figure.Figure
|
104
|
-
"""
|
105
|
-
|
106
|
-
# Determine which images to plot
|
107
|
-
highest_uncovered_indices = self.uncovered_indices[:top_k]
|
108
|
-
|
109
|
-
# Grab the images
|
110
|
-
images = to_numpy(images)
|
111
|
-
selected_images = images[highest_uncovered_indices]
|
112
|
-
|
113
|
-
# Plot the images
|
114
|
-
fig = _plot(selected_images, top_k)
|
115
|
-
|
116
|
-
return fig
|
14
|
+
from dataeval.utils._array import ensure_embeddings, flatten
|
117
15
|
|
118
16
|
|
119
17
|
@set_metadata
|