dataeval 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +4 -4
- dataeval/detectors/__init__.py +4 -3
- dataeval/detectors/drift/__init__.py +9 -10
- dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
- dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
- dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
- dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
- dataeval/detectors/drift/updates.py +61 -0
- dataeval/detectors/linters/__init__.py +3 -3
- dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
- dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
- dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
- dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
- dataeval/detectors/ood/__init__.py +6 -6
- dataeval/{_internal/detectors → detectors}/ood/ae.py +20 -12
- dataeval/detectors/ood/aegmm.py +66 -0
- dataeval/{_internal/detectors → detectors}/ood/base.py +33 -21
- dataeval/{_internal/detectors → detectors}/ood/llr.py +43 -33
- dataeval/detectors/ood/metadata_ks_compare.py +99 -0
- dataeval/detectors/ood/metadata_least_likely.py +119 -0
- dataeval/detectors/ood/metadata_ood_mi.py +92 -0
- dataeval/{_internal/detectors → detectors}/ood/vae.py +23 -17
- dataeval/detectors/ood/vaegmm.py +75 -0
- dataeval/interop.py +56 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +4 -4
- dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -13
- dataeval/{_internal/metrics → metrics/bias}/coverage.py +41 -7
- dataeval/{_internal/metrics → metrics/bias}/diversity.py +75 -18
- dataeval/metrics/bias/metadata.py +358 -0
- dataeval/{_internal/metrics → metrics/bias}/parity.py +54 -44
- dataeval/metrics/estimators/__init__.py +3 -3
- dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
- dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
- dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
- dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
- dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
- dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
- dataeval/metrics/stats/hashstats.py +156 -0
- dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
- dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
- dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
- dataeval/{_internal/output.py → output.py} +26 -6
- dataeval/utils/__init__.py +8 -3
- dataeval/utils/image.py +71 -0
- dataeval/utils/lazy.py +26 -0
- dataeval/utils/metadata.py +258 -0
- dataeval/utils/shared.py +151 -0
- dataeval/{_internal → utils}/split_dataset.py +98 -33
- dataeval/utils/tensorflow/__init__.py +7 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +8 -2
- dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +28 -18
- dataeval/{_internal/models/tensorflow/pixelcnn.py → utils/tensorflow/_internal/models.py} +387 -97
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +15 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +84 -85
- dataeval/utils/tensorflow/loss/__init__.py +6 -2
- dataeval/utils/torch/__init__.py +7 -3
- dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
- dataeval/{_internal → utils/torch}/datasets.py +48 -42
- dataeval/utils/torch/models.py +138 -0
- dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
- dataeval/{_internal → utils/torch}/utils.py +3 -1
- dataeval/workflows/__init__.py +1 -1
- dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/METADATA +4 -3
- dataeval-0.73.0.dist-info/RECORD +73 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/detectors/ood/aegmm.py +0 -78
- dataeval/_internal/detectors/ood/vaegmm.py +0 -89
- dataeval/_internal/interop.py +0 -49
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/stats/hashstats.py +0 -75
- dataeval/_internal/metrics/utils.py +0 -447
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/utils.py +0 -67
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/models/tensorflow/autoencoder.py +0 -320
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/detectors/drift/kernels/__init__.py +0 -10
- dataeval/detectors/drift/updates/__init__.py +0 -8
- dataeval/utils/tensorflow/models/__init__.py +0 -9
- dataeval/utils/tensorflow/recon/__init__.py +0 -3
- dataeval/utils/torch/datasets/__init__.py +0 -12
- dataeval/utils/torch/models/__init__.py +0 -11
- dataeval/utils/torch/trainer/__init__.py +0 -7
- dataeval-0.72.1.dist-info/RECORD +0 -81
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
dataeval/metrics/__init__.py
CHANGED
@@ -3,6 +3,6 @@ Metrics are a way to measure the performance of your models or datasets that
|
|
3
3
|
can then be analyzed in the context of a given problem.
|
4
4
|
"""
|
5
5
|
|
6
|
-
from . import bias, estimators, stats
|
6
|
+
from dataeval.metrics import bias, estimators, stats
|
7
7
|
|
8
8
|
__all__ = ["bias", "estimators", "stats"]
|
@@ -3,10 +3,10 @@ Bias metrics check for skewed or imbalanced datasets and incomplete feature
|
|
3
3
|
representation which may impact model performance.
|
4
4
|
"""
|
5
5
|
|
6
|
-
from dataeval.
|
7
|
-
from dataeval.
|
8
|
-
from dataeval.
|
9
|
-
from dataeval.
|
6
|
+
from dataeval.metrics.bias.balance import BalanceOutput, balance
|
7
|
+
from dataeval.metrics.bias.coverage import CoverageOutput, coverage
|
8
|
+
from dataeval.metrics.bias.diversity import DiversityOutput, diversity
|
9
|
+
from dataeval.metrics.bias.parity import ParityOutput, label_parity, parity
|
10
10
|
|
11
11
|
__all__ = [
|
12
12
|
"balance",
|
@@ -1,35 +1,98 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["BalanceOutput", "balance"]
|
4
|
+
|
5
|
+
import contextlib
|
3
6
|
import warnings
|
4
7
|
from dataclasses import dataclass
|
5
|
-
from typing import Mapping
|
8
|
+
from typing import Any, Mapping
|
6
9
|
|
7
10
|
import numpy as np
|
8
11
|
from numpy.typing import ArrayLike, NDArray
|
9
12
|
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
|
10
13
|
|
11
|
-
from dataeval.
|
12
|
-
from dataeval.
|
14
|
+
from dataeval.metrics.bias.metadata import entropy, heatmap, preprocess_metadata
|
15
|
+
from dataeval.output import OutputMetadata, set_metadata
|
16
|
+
|
17
|
+
with contextlib.suppress(ImportError):
|
18
|
+
from matplotlib.figure import Figure
|
13
19
|
|
14
20
|
|
15
21
|
@dataclass(frozen=True)
|
16
22
|
class BalanceOutput(OutputMetadata):
|
17
23
|
"""
|
18
|
-
Output class for :func:`balance`
|
24
|
+
Output class for :func:`balance` bias metric
|
19
25
|
|
20
26
|
Attributes
|
21
27
|
----------
|
22
28
|
balance : NDArray[np.float64]
|
23
|
-
Estimate of
|
29
|
+
Estimate of mutual information between metadata factors and class label
|
24
30
|
factors : NDArray[np.float64]
|
25
31
|
Estimate of inter/intra-factor mutual information
|
26
32
|
classwise : NDArray[np.float64]
|
27
33
|
Estimate of mutual information between metadata factors and individual class labels
|
34
|
+
class_list: NDArray
|
35
|
+
Array of the class labels present in the dataset
|
36
|
+
metadata_names: list[str]
|
37
|
+
Names of each metadata factor
|
28
38
|
"""
|
29
39
|
|
30
40
|
balance: NDArray[np.float64]
|
31
41
|
factors: NDArray[np.float64]
|
32
42
|
classwise: NDArray[np.float64]
|
43
|
+
class_list: NDArray[Any]
|
44
|
+
metadata_names: list[str]
|
45
|
+
|
46
|
+
def plot(
|
47
|
+
self,
|
48
|
+
row_labels: list[Any] | NDArray[Any] | None = None,
|
49
|
+
col_labels: list[Any] | NDArray[Any] | None = None,
|
50
|
+
plot_classwise: bool = False,
|
51
|
+
) -> Figure:
|
52
|
+
"""
|
53
|
+
Plot a heatmap of balance information
|
54
|
+
|
55
|
+
Parameters
|
56
|
+
----------
|
57
|
+
row_labels : ArrayLike | None, default None
|
58
|
+
List/Array containing the labels for rows in the histogram
|
59
|
+
col_labels : ArrayLike | None, default None
|
60
|
+
List/Array containing the labels for columns in the histogram
|
61
|
+
plot_classwise : bool, default False
|
62
|
+
Whether to plot per-class balance instead of global balance
|
63
|
+
"""
|
64
|
+
if plot_classwise:
|
65
|
+
if row_labels is None:
|
66
|
+
row_labels = self.class_list
|
67
|
+
if col_labels is None:
|
68
|
+
col_labels = np.concatenate((["class"], self.metadata_names))
|
69
|
+
|
70
|
+
fig = heatmap(
|
71
|
+
self.classwise,
|
72
|
+
row_labels,
|
73
|
+
col_labels,
|
74
|
+
xlabel="Factors",
|
75
|
+
ylabel="Class",
|
76
|
+
cbarlabel="Normalized Mutual Information",
|
77
|
+
)
|
78
|
+
else:
|
79
|
+
# Combine balance and factors results
|
80
|
+
data = np.concatenate([self.balance[np.newaxis, 1:], self.factors], axis=0)
|
81
|
+
# Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
|
82
|
+
mask = np.triu(data + 1, k=0) < 1
|
83
|
+
# Finalize the data for the plot, last row is last factor x last factor so it gets dropped
|
84
|
+
heat_data = np.where(mask, np.nan, data)[:-1]
|
85
|
+
# Creating label array for heat map axes
|
86
|
+
heat_labels = np.concatenate((["class"], self.metadata_names))
|
87
|
+
|
88
|
+
if row_labels is None:
|
89
|
+
row_labels = heat_labels[:-1]
|
90
|
+
if col_labels is None:
|
91
|
+
col_labels = heat_labels[1:]
|
92
|
+
|
93
|
+
fig = heatmap(heat_data, row_labels, col_labels, cbarlabel="Normalized Mutual Information")
|
94
|
+
|
95
|
+
return fig
|
33
96
|
|
34
97
|
|
35
98
|
def validate_num_neighbors(num_neighbors: int) -> int:
|
@@ -55,7 +118,7 @@ def validate_num_neighbors(num_neighbors: int) -> int:
|
|
55
118
|
@set_metadata("dataeval.metrics")
|
56
119
|
def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neighbors: int = 5) -> BalanceOutput:
|
57
120
|
"""
|
58
|
-
|
121
|
+
Mutual information (MI) between factors (class label, metadata, label/image properties)
|
59
122
|
|
60
123
|
Parameters
|
61
124
|
----------
|
@@ -70,7 +133,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
70
133
|
Returns
|
71
134
|
-------
|
72
135
|
BalanceOutput
|
73
|
-
(num_factors+1) x (num_factors+1) estimate of
|
136
|
+
(num_factors+1) x (num_factors+1) estimate of mutual information
|
74
137
|
between num_factors metadata factors and class label. Symmetry is enforced.
|
75
138
|
|
76
139
|
Note
|
@@ -83,7 +146,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
83
146
|
|
84
147
|
Example
|
85
148
|
-------
|
86
|
-
Return
|
149
|
+
Return balance (mutual information) of factors with class_labels
|
87
150
|
|
88
151
|
>>> bal = balance(class_labels, metadata)
|
89
152
|
>>> bal.balance
|
@@ -109,7 +172,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
109
172
|
sklearn.metrics.mutual_info_score
|
110
173
|
"""
|
111
174
|
num_neighbors = validate_num_neighbors(num_neighbors)
|
112
|
-
data, names, is_categorical = preprocess_metadata(class_labels, metadata)
|
175
|
+
data, names, is_categorical, unique_labels = preprocess_metadata(class_labels, metadata)
|
113
176
|
num_factors = len(names)
|
114
177
|
mi = np.empty((num_factors, num_factors))
|
115
178
|
mi[:] = np.nan
|
@@ -143,8 +206,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
143
206
|
|
144
207
|
# unique class labels
|
145
208
|
class_idx = names.index("class_label")
|
146
|
-
|
147
|
-
u_cls = np.unique(class_data)
|
209
|
+
u_cls = np.unique(data[:, class_idx])
|
148
210
|
num_classes = len(u_cls)
|
149
211
|
|
150
212
|
# assume class is a factor
|
@@ -154,7 +216,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
154
216
|
# categorical variables, excluding class label
|
155
217
|
cat_mask = np.concatenate((is_categorical[:class_idx], is_categorical[(class_idx + 1) :]), axis=0).astype(int)
|
156
218
|
|
157
|
-
tgt_bin = np.stack([
|
219
|
+
tgt_bin = np.stack([data[:, class_idx] == cls for cls in u_cls]).T.astype(int)
|
158
220
|
ent_tgt_bin = entropy(
|
159
221
|
tgt_bin, names=[str(idx) for idx in range(num_classes)], is_categorical=[True for idx in range(num_classes)]
|
160
222
|
)
|
@@ -174,4 +236,4 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
174
236
|
norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_all) + 1e-6
|
175
237
|
classwise = classwise_mi / norm_factor
|
176
238
|
|
177
|
-
return BalanceOutput(balance, factors, classwise)
|
239
|
+
return BalanceOutput(balance, factors, classwise, unique_labels, list(metadata.keys()))
|
@@ -1,16 +1,23 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["CoverageOutput", "coverage"]
|
4
|
+
|
5
|
+
import contextlib
|
3
6
|
import math
|
4
7
|
from dataclasses import dataclass
|
5
|
-
from typing import Literal
|
8
|
+
from typing import Any, Literal
|
6
9
|
|
7
10
|
import numpy as np
|
8
11
|
from numpy.typing import ArrayLike, NDArray
|
9
12
|
from scipy.spatial.distance import pdist, squareform
|
10
13
|
|
11
|
-
from dataeval.
|
12
|
-
from dataeval.
|
13
|
-
from dataeval.
|
14
|
+
from dataeval.interop import to_numpy
|
15
|
+
from dataeval.metrics.bias.metadata import coverage_plot
|
16
|
+
from dataeval.output import OutputMetadata, set_metadata
|
17
|
+
from dataeval.utils.shared import flatten
|
18
|
+
|
19
|
+
with contextlib.suppress(ImportError):
|
20
|
+
from matplotlib.figure import Figure
|
14
21
|
|
15
22
|
|
16
23
|
@dataclass(frozen=True)
|
@@ -32,13 +39,40 @@ class CoverageOutput(OutputMetadata):
|
|
32
39
|
radii: NDArray[np.float64]
|
33
40
|
critical_value: float
|
34
41
|
|
42
|
+
def plot(
|
43
|
+
self,
|
44
|
+
images: NDArray[Any],
|
45
|
+
top_k: int = 6,
|
46
|
+
) -> Figure:
|
47
|
+
"""
|
48
|
+
Plot the top k images together for visualization
|
49
|
+
|
50
|
+
Parameters
|
51
|
+
----------
|
52
|
+
images : ArrayLike
|
53
|
+
Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
|
54
|
+
top_k : int, default 6
|
55
|
+
Number of images to plot (plotting assumes groups of 3)
|
56
|
+
"""
|
57
|
+
# Determine which images to plot
|
58
|
+
highest_uncovered_indices = self.indices[:top_k]
|
59
|
+
|
60
|
+
# Grab the images
|
61
|
+
images = to_numpy(images)
|
62
|
+
selected_images = images[highest_uncovered_indices]
|
63
|
+
|
64
|
+
# Plot the images
|
65
|
+
fig = coverage_plot(selected_images, top_k)
|
66
|
+
|
67
|
+
return fig
|
68
|
+
|
35
69
|
|
36
|
-
@set_metadata(
|
70
|
+
@set_metadata()
|
37
71
|
def coverage(
|
38
72
|
embeddings: ArrayLike,
|
39
73
|
radius_type: Literal["adaptive", "naive"] = "adaptive",
|
40
74
|
k: int = 20,
|
41
|
-
percent:
|
75
|
+
percent: float = 0.01,
|
42
76
|
) -> CoverageOutput:
|
43
77
|
"""
|
44
78
|
Class for evaluating :term:`coverage<Coverage>` and identifying images/samples that are in undercovered regions.
|
@@ -53,7 +87,7 @@ def coverage(
|
|
53
87
|
k: int, default 20
|
54
88
|
Number of observations required in order to be covered.
|
55
89
|
[1] suggests that a minimum of 20-50 samples is necessary.
|
56
|
-
percent:
|
90
|
+
percent: float, default 0.01
|
57
91
|
Percent of observations to be considered uncovered. Only applies to adaptive radius.
|
58
92
|
|
59
93
|
Returns
|
@@ -1,13 +1,27 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["DiversityOutput", "diversity"]
|
4
|
+
|
5
|
+
import contextlib
|
3
6
|
from dataclasses import dataclass
|
4
|
-
from typing import Literal, Mapping
|
7
|
+
from typing import Any, Literal, Mapping
|
5
8
|
|
6
9
|
import numpy as np
|
7
10
|
from numpy.typing import ArrayLike, NDArray
|
8
11
|
|
9
|
-
from dataeval.
|
10
|
-
|
12
|
+
from dataeval.metrics.bias.metadata import (
|
13
|
+
diversity_bar_plot,
|
14
|
+
entropy,
|
15
|
+
get_counts,
|
16
|
+
get_num_bins,
|
17
|
+
heatmap,
|
18
|
+
preprocess_metadata,
|
19
|
+
)
|
20
|
+
from dataeval.output import OutputMetadata, set_metadata
|
21
|
+
from dataeval.utils.shared import get_method
|
22
|
+
|
23
|
+
with contextlib.suppress(ImportError):
|
24
|
+
from matplotlib.figure import Figure
|
11
25
|
|
12
26
|
|
13
27
|
@dataclass(frozen=True)
|
@@ -21,18 +35,66 @@ class DiversityOutput(OutputMetadata):
|
|
21
35
|
:term:`Diversity` index for classes and factors
|
22
36
|
classwise : NDArray[np.float64]
|
23
37
|
Classwise diversity index [n_class x n_factor]
|
38
|
+
class_list: NDArray[np.int64]
|
39
|
+
Class labels for each value in the dataset
|
40
|
+
metadata_names: list[str]
|
41
|
+
Names of each metadata factor
|
24
42
|
"""
|
25
43
|
|
26
44
|
diversity_index: NDArray[np.float64]
|
27
45
|
classwise: NDArray[np.float64]
|
46
|
+
class_list: NDArray[Any]
|
47
|
+
metadata_names: list[str]
|
48
|
+
method: Literal["shannon", "simpson"]
|
49
|
+
|
50
|
+
def plot(
|
51
|
+
self,
|
52
|
+
row_labels: list[Any] | NDArray[Any] | None = None,
|
53
|
+
col_labels: list[Any] | NDArray[Any] | None = None,
|
54
|
+
plot_classwise: bool = False,
|
55
|
+
) -> Figure:
|
56
|
+
"""
|
57
|
+
Plot a heatmap of diversity information
|
58
|
+
|
59
|
+
Parameters
|
60
|
+
----------
|
61
|
+
row_labels : ArrayLike | None, default None
|
62
|
+
List/Array containing the labels for rows in the histogram
|
63
|
+
col_labels : ArrayLike | None, default None
|
64
|
+
List/Array containing the labels for columns in the histogram
|
65
|
+
plot_classwise : bool, default False
|
66
|
+
Whether to plot per-class balance instead of global balance
|
67
|
+
"""
|
68
|
+
if plot_classwise:
|
69
|
+
if row_labels is None:
|
70
|
+
row_labels = self.class_list
|
71
|
+
if col_labels is None:
|
72
|
+
col_labels = self.metadata_names
|
73
|
+
|
74
|
+
fig = heatmap(
|
75
|
+
self.classwise,
|
76
|
+
row_labels,
|
77
|
+
col_labels,
|
78
|
+
xlabel="Factors",
|
79
|
+
ylabel="Class",
|
80
|
+
cbarlabel=f"Normalized {self.method.title()} Index",
|
81
|
+
)
|
82
|
+
|
83
|
+
else:
|
84
|
+
# Creating label array for heat map axes
|
85
|
+
heat_labels = np.concatenate((["class"], self.metadata_names))
|
86
|
+
|
87
|
+
fig = diversity_bar_plot(heat_labels, self.diversity_index)
|
88
|
+
|
89
|
+
return fig
|
28
90
|
|
29
91
|
|
30
92
|
def diversity_shannon(
|
31
|
-
data: NDArray,
|
93
|
+
data: NDArray[Any],
|
32
94
|
names: list[str],
|
33
95
|
is_categorical: list[bool],
|
34
96
|
subset_mask: NDArray[np.bool_] | None = None,
|
35
|
-
) -> NDArray:
|
97
|
+
) -> NDArray[np.float64]:
|
36
98
|
"""
|
37
99
|
Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
|
38
100
|
histogram binning, for continuous variables.
|
@@ -79,11 +141,11 @@ def diversity_shannon(
|
|
79
141
|
|
80
142
|
|
81
143
|
def diversity_simpson(
|
82
|
-
data: NDArray,
|
144
|
+
data: NDArray[Any],
|
83
145
|
names: list[str],
|
84
146
|
is_categorical: list[bool],
|
85
147
|
subset_mask: NDArray[np.bool_] | None = None,
|
86
|
-
) -> NDArray:
|
148
|
+
) -> NDArray[np.float64]:
|
87
149
|
"""
|
88
150
|
Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
|
89
151
|
histogram binning, for continuous variables.
|
@@ -139,10 +201,7 @@ def diversity_simpson(
|
|
139
201
|
return ev_index
|
140
202
|
|
141
203
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
@set_metadata("dataeval.metrics")
|
204
|
+
@set_metadata()
|
146
205
|
def diversity(
|
147
206
|
class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], method: Literal["shannon", "simpson"] = "simpson"
|
148
207
|
) -> DiversityOutput:
|
@@ -202,20 +261,18 @@ def diversity(
|
|
202
261
|
--------
|
203
262
|
numpy.histogram
|
204
263
|
"""
|
205
|
-
diversity_fn = get_method(
|
206
|
-
data, names, is_categorical = preprocess_metadata(class_labels, metadata)
|
264
|
+
diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
|
265
|
+
data, names, is_categorical, unique_labels = preprocess_metadata(class_labels, metadata)
|
207
266
|
diversity_index = diversity_fn(data, names, is_categorical, None).astype(np.float64)
|
208
267
|
|
209
268
|
class_idx = names.index("class_label")
|
210
|
-
|
211
|
-
|
212
|
-
u_classes = np.unique(class_lbl)
|
269
|
+
u_classes = np.unique(data[:, class_idx])
|
213
270
|
num_factors = len(names)
|
214
271
|
diversity = np.empty((len(u_classes), num_factors))
|
215
272
|
diversity[:] = np.nan
|
216
273
|
for idx, cls in enumerate(u_classes):
|
217
|
-
subset_mask =
|
274
|
+
subset_mask = data[:, class_idx] == cls
|
218
275
|
diversity[idx, :] = diversity_fn(data, names, is_categorical, subset_mask)
|
219
276
|
div_no_class = np.concatenate((diversity[:, :class_idx], diversity[:, (class_idx + 1) :]), axis=1)
|
220
277
|
|
221
|
-
return DiversityOutput(diversity_index, div_no_class)
|
278
|
+
return DiversityOutput(diversity_index, div_no_class, unique_labels, list(metadata.keys()), method)
|