dataeval 0.65.0__py3-none-any.whl → 0.66.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +13 -9
- dataeval/_internal/detectors/clusterer.py +24 -22
- dataeval/_internal/detectors/drift/base.py +206 -26
- dataeval/_internal/detectors/drift/cvm.py +25 -23
- dataeval/_internal/detectors/drift/ks.py +28 -25
- dataeval/_internal/detectors/drift/mmd.py +30 -29
- dataeval/_internal/detectors/drift/torch.py +66 -58
- dataeval/_internal/detectors/drift/uncertainty.py +28 -28
- dataeval/_internal/detectors/duplicates.py +28 -18
- dataeval/_internal/detectors/ood/ae.py +15 -29
- dataeval/_internal/detectors/ood/aegmm.py +33 -27
- dataeval/_internal/detectors/ood/base.py +61 -43
- dataeval/_internal/detectors/ood/llr.py +27 -24
- dataeval/_internal/detectors/ood/vae.py +32 -31
- dataeval/_internal/detectors/ood/vaegmm.py +34 -28
- dataeval/_internal/detectors/{linter.py → outliers.py} +33 -27
- dataeval/_internal/flags.py +5 -3
- dataeval/_internal/interop.py +4 -2
- dataeval/_internal/metrics/balance.py +33 -4
- dataeval/_internal/metrics/ber.py +6 -4
- dataeval/_internal/metrics/diversity.py +45 -12
- dataeval/_internal/metrics/parity.py +114 -26
- dataeval/_internal/metrics/stats.py +154 -16
- dataeval/_internal/metrics/uap.py +28 -2
- dataeval/_internal/metrics/utils.py +20 -18
- dataeval/_internal/models/pytorch/autoencoder.py +127 -22
- dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
- dataeval/_internal/models/tensorflow/gmm.py +4 -2
- dataeval/_internal/models/tensorflow/losses.py +15 -11
- dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
- dataeval/_internal/models/tensorflow/trainer.py +8 -6
- dataeval/_internal/models/tensorflow/utils.py +21 -19
- dataeval/_internal/output.py +13 -10
- dataeval/_internal/utils.py +5 -3
- dataeval/_internal/workflows/sufficiency.py +42 -30
- dataeval/detectors/__init__.py +6 -25
- dataeval/detectors/drift/__init__.py +16 -0
- dataeval/detectors/drift/kernels/__init__.py +6 -0
- dataeval/detectors/drift/updates/__init__.py +3 -0
- dataeval/detectors/linters/__init__.py +5 -0
- dataeval/detectors/ood/__init__.py +11 -0
- dataeval/metrics/__init__.py +2 -26
- dataeval/metrics/bias/__init__.py +14 -0
- dataeval/metrics/estimators/__init__.py +9 -0
- dataeval/metrics/stats/__init__.py +6 -0
- dataeval/tensorflow/__init__.py +3 -0
- dataeval/tensorflow/loss/__init__.py +3 -0
- dataeval/tensorflow/models/__init__.py +5 -0
- dataeval/tensorflow/recon/__init__.py +3 -0
- dataeval/torch/__init__.py +3 -0
- dataeval/{models/torch → torch/models}/__init__.py +1 -2
- dataeval/torch/trainer/__init__.py +3 -0
- dataeval/utils/__init__.py +3 -6
- dataeval/workflows/__init__.py +2 -4
- {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
- dataeval-0.66.0.dist-info/RECORD +72 -0
- dataeval/models/__init__.py +0 -15
- dataeval/models/tensorflow/__init__.py +0 -6
- dataeval-0.65.0.dist-info/RECORD +0 -60
- {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
@@ -6,10 +6,13 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
|
|
6
6
|
Licensed under Apache Software License (Apache 2.0)
|
7
7
|
"""
|
8
8
|
|
9
|
+
from __future__ import annotations
|
10
|
+
|
9
11
|
from typing import Callable
|
10
12
|
|
11
13
|
import keras
|
12
14
|
import numpy as np
|
15
|
+
import tensorflow as tf
|
13
16
|
from numpy.typing import ArrayLike
|
14
17
|
|
15
18
|
from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
|
@@ -21,17 +24,18 @@ from dataeval._internal.models.tensorflow.utils import predict_batch
|
|
21
24
|
|
22
25
|
|
23
26
|
class OOD_VAEGMM(OODGMMBase):
|
24
|
-
|
25
|
-
|
26
|
-
VAE with Gaussian Mixture Model based outlier detector.
|
27
|
+
"""
|
28
|
+
VAE with Gaussian Mixture Model based outlier detector.
|
27
29
|
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
30
|
+
Parameters
|
31
|
+
----------
|
32
|
+
model : VAEGMM
|
33
|
+
A VAEGMM model.
|
34
|
+
samples
|
35
|
+
Number of samples sampled to evaluate each instance.
|
36
|
+
"""
|
37
|
+
|
38
|
+
def __init__(self, model: VAEGMM, samples: int = 10) -> None:
|
35
39
|
super().__init__(model)
|
36
40
|
self.samples = samples
|
37
41
|
|
@@ -39,35 +43,37 @@ class OOD_VAEGMM(OODGMMBase):
|
|
39
43
|
self,
|
40
44
|
x_ref: ArrayLike,
|
41
45
|
threshold_perc: float = 100.0,
|
42
|
-
loss_fn: Callable =
|
46
|
+
loss_fn: Callable[..., tf.Tensor] | None = None,
|
43
47
|
optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
|
44
48
|
epochs: int = 20,
|
45
49
|
batch_size: int = 64,
|
46
50
|
verbose: bool = True,
|
47
51
|
) -> None:
|
52
|
+
if loss_fn is None:
|
53
|
+
loss_fn = LossGMM(elbo=Elbo(0.05))
|
54
|
+
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
55
|
+
|
56
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
|
48
57
|
"""
|
49
|
-
|
58
|
+
Compute the out-of-distribution (OOD) score for a given dataset.
|
50
59
|
|
51
60
|
Parameters
|
52
61
|
----------
|
53
62
|
X : ArrayLike
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
Loss function used for training.
|
59
|
-
optimizer : keras.optimizers.Optimizer, default keras.optimizers.Adam
|
60
|
-
Optimizer used for training.
|
61
|
-
epochs : int, default 20
|
62
|
-
Number of training epochs.
|
63
|
-
batch_size : int, default 64
|
64
|
-
Batch size used for training.
|
65
|
-
verbose : bool, default True
|
66
|
-
Whether to print training progress.
|
67
|
-
"""
|
68
|
-
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
63
|
+
Input data to score.
|
64
|
+
batch_size : int, default 1e10
|
65
|
+
Number of instances to process in each batch.
|
66
|
+
Use a smaller batch size if your dataset is large or if you encounter memory issues.
|
69
67
|
|
70
|
-
|
68
|
+
Returns
|
69
|
+
-------
|
70
|
+
OODScore
|
71
|
+
An object containing the instance-level OOD score.
|
72
|
+
|
73
|
+
Note
|
74
|
+
----
|
75
|
+
This model does not produce a feature level score like the OOD_AE or OOD_VAE models.
|
76
|
+
"""
|
71
77
|
self._validate(X := to_numpy(X))
|
72
78
|
|
73
79
|
# draw samples from latent space
|
@@ -1,17 +1,18 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from dataclasses import dataclass
|
2
|
-
from typing import
|
4
|
+
from typing import Iterable, Literal
|
3
5
|
|
4
6
|
import numpy as np
|
5
7
|
from numpy.typing import ArrayLike, NDArray
|
6
8
|
|
7
|
-
from dataeval._internal.flags import verify_supported
|
9
|
+
from dataeval._internal.flags import ImageStat, to_distinct, verify_supported
|
10
|
+
from dataeval._internal.metrics.stats import StatsOutput, imagestats
|
8
11
|
from dataeval._internal.output import OutputMetadata, set_metadata
|
9
|
-
from dataeval.flags import ImageStat
|
10
|
-
from dataeval.metrics import imagestats
|
11
12
|
|
12
13
|
|
13
14
|
@dataclass(frozen=True)
|
14
|
-
class
|
15
|
+
class OutliersOutput(OutputMetadata):
|
15
16
|
"""
|
16
17
|
Attributes
|
17
18
|
----------
|
@@ -20,11 +21,11 @@ class LinterOutput(OutputMetadata):
|
|
20
21
|
the issues and calculated values for the given index.
|
21
22
|
"""
|
22
23
|
|
23
|
-
issues:
|
24
|
+
issues: dict[int, dict[str, float]]
|
24
25
|
|
25
26
|
|
26
27
|
def _get_outlier_mask(
|
27
|
-
values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold:
|
28
|
+
values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: float | None
|
28
29
|
) -> NDArray:
|
29
30
|
if method == "zscore":
|
30
31
|
threshold = threshold if threshold else 3.0
|
@@ -46,7 +47,7 @@ def _get_outlier_mask(
|
|
46
47
|
raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
|
47
48
|
|
48
49
|
|
49
|
-
class
|
50
|
+
class Outliers:
|
50
51
|
r"""
|
51
52
|
Calculates statistical outliers of a dataset using various statistical tests applied to each image
|
52
53
|
|
@@ -92,28 +93,28 @@ class Linter:
|
|
92
93
|
|
93
94
|
Examples
|
94
95
|
--------
|
95
|
-
Initialize the
|
96
|
+
Initialize the Outliers class:
|
96
97
|
|
97
|
-
>>>
|
98
|
+
>>> outliers = Outliers()
|
98
99
|
|
99
100
|
Specifying specific metrics to analyze:
|
100
101
|
|
101
|
-
>>>
|
102
|
+
>>> outliers = Outliers(flags=ImageStat.SIZE | ImageStat.ALL_VISUALS)
|
102
103
|
|
103
104
|
Specifying an outlier method:
|
104
105
|
|
105
|
-
>>>
|
106
|
+
>>> outliers = Outliers(outlier_method="iqr")
|
106
107
|
|
107
108
|
Specifying an outlier method and threshold:
|
108
109
|
|
109
|
-
>>>
|
110
|
+
>>> outliers = Outliers(outlier_method="zscore", outlier_threshold=2.5)
|
110
111
|
"""
|
111
112
|
|
112
113
|
def __init__(
|
113
114
|
self,
|
114
115
|
flags: ImageStat = ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS,
|
115
116
|
outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
|
116
|
-
outlier_threshold:
|
117
|
+
outlier_threshold: float | None = None,
|
117
118
|
):
|
118
119
|
verify_supported(flags, ImageStat.ALL_STATS)
|
119
120
|
self.flags = flags
|
@@ -123,11 +124,9 @@ class Linter:
|
|
123
124
|
def _get_outliers(self) -> dict:
|
124
125
|
flagged_images = {}
|
125
126
|
stats_dict = self.stats.dict()
|
127
|
+
supported = to_distinct(ImageStat.ALL_STATS)
|
126
128
|
for stat, values in stats_dict.items():
|
127
|
-
if
|
128
|
-
continue
|
129
|
-
|
130
|
-
if values.ndim == 1 and np.std(values) != 0:
|
129
|
+
if stat in supported.values() and values.ndim == 1 and np.std(values) != 0:
|
131
130
|
mask = _get_outlier_mask(values, self.outlier_method, self.outlier_threshold)
|
132
131
|
indices = np.flatnonzero(mask)
|
133
132
|
for i, value in zip(indices, values[mask]):
|
@@ -136,19 +135,18 @@ class Linter:
|
|
136
135
|
return dict(sorted(flagged_images.items()))
|
137
136
|
|
138
137
|
@set_metadata("dataeval.detectors", ["flags", "outlier_method", "outlier_threshold"])
|
139
|
-
def evaluate(self,
|
138
|
+
def evaluate(self, data: Iterable[ArrayLike] | StatsOutput) -> OutliersOutput:
|
140
139
|
"""
|
141
140
|
Returns indices of outliers with the issues identified for each
|
142
141
|
|
143
142
|
Parameters
|
144
143
|
----------
|
145
|
-
|
146
|
-
A dataset in an ArrayLike format
|
147
|
-
Function expects the data to have 3 dimensions, CxHxW.
|
144
|
+
data : Iterable[ArrayLike], shape - (C, H, W) | StatsOutput
|
145
|
+
A dataset of images in an ArrayLike format or the output from an imagestats metric analysis
|
148
146
|
|
149
147
|
Returns
|
150
148
|
-------
|
151
|
-
|
149
|
+
OutliersOutput
|
152
150
|
Output class containing the indices of outliers and a dictionary showing
|
153
151
|
the issues and calculated values for the given index.
|
154
152
|
|
@@ -156,8 +154,16 @@ class Linter:
|
|
156
154
|
-------
|
157
155
|
Evaluate the dataset:
|
158
156
|
|
159
|
-
>>>
|
160
|
-
|
157
|
+
>>> outliers.evaluate(images)
|
158
|
+
OutliersOutput(issues={18: {'brightness': 0.78}, 25: {'brightness': 0.98}})
|
161
159
|
"""
|
162
|
-
|
163
|
-
|
160
|
+
if isinstance(data, StatsOutput):
|
161
|
+
flags = set(to_distinct(self.flags).values())
|
162
|
+
stats = set(data.dict())
|
163
|
+
missing = flags - stats
|
164
|
+
if missing:
|
165
|
+
raise ValueError(f"StatsOutput is missing {missing} from the required stats: {flags}.")
|
166
|
+
self.stats = data
|
167
|
+
else:
|
168
|
+
self.stats = imagestats(data, self.flags)
|
169
|
+
return OutliersOutput(self._get_outliers())
|
dataeval/_internal/flags.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from enum import IntFlag, auto
|
2
4
|
from functools import reduce
|
3
|
-
from typing import
|
5
|
+
from typing import Iterable, TypeVar, cast
|
4
6
|
|
5
7
|
TFlag = TypeVar("TFlag", bound=IntFlag)
|
6
8
|
|
@@ -47,7 +49,7 @@ def is_distinct(flag: IntFlag) -> bool:
|
|
47
49
|
return (flag & (flag - 1) == 0) and flag != 0
|
48
50
|
|
49
51
|
|
50
|
-
def to_distinct(flag: TFlag) ->
|
52
|
+
def to_distinct(flag: TFlag) -> dict[TFlag, str]:
|
51
53
|
"""
|
52
54
|
Returns a distinct set of all flags set on the input flag and their names
|
53
55
|
|
@@ -61,7 +63,7 @@ def to_distinct(flag: TFlag) -> Dict[TFlag, str]:
|
|
61
63
|
return {f: f.name.lower() for f in list(flag.__class__) if f & flag and is_distinct(f) and f.name}
|
62
64
|
|
63
65
|
|
64
|
-
def verify_supported(flag: TFlag, flags:
|
66
|
+
def verify_supported(flag: TFlag, flags: TFlag | Iterable[TFlag]):
|
65
67
|
supported = flags if isinstance(flags, flag.__class__) else cast(TFlag, reduce(lambda a, b: a | b, flags)) # type: ignore
|
66
68
|
unsupported = flag & ~supported
|
67
69
|
if unsupported:
|
dataeval/_internal/interop.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from importlib import import_module
|
2
|
-
from typing import Iterable
|
4
|
+
from typing import Iterable
|
3
5
|
|
4
6
|
import numpy as np
|
5
7
|
from numpy.typing import ArrayLike, NDArray
|
@@ -20,7 +22,7 @@ def try_import(module_name):
|
|
20
22
|
return module
|
21
23
|
|
22
24
|
|
23
|
-
def to_numpy(array:
|
25
|
+
def to_numpy(array: ArrayLike | None) -> NDArray:
|
24
26
|
if array is None:
|
25
27
|
return np.ndarray([])
|
26
28
|
|
@@ -1,6 +1,8 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import warnings
|
2
4
|
from dataclasses import dataclass
|
3
|
-
from typing import
|
5
|
+
from typing import Sequence
|
4
6
|
|
5
7
|
import numpy as np
|
6
8
|
from numpy.typing import NDArray
|
@@ -43,7 +45,7 @@ def validate_num_neighbors(num_neighbors: int) -> int:
|
|
43
45
|
|
44
46
|
|
45
47
|
@set_metadata("dataeval.metrics")
|
46
|
-
def balance(class_labels: Sequence[int], metadata:
|
48
|
+
def balance(class_labels: Sequence[int], metadata: list[dict], num_neighbors: int = 5) -> BalanceOutput:
|
47
49
|
"""
|
48
50
|
Mutual information (MI) between factors (class label, metadata, label/image properties)
|
49
51
|
|
@@ -71,6 +73,22 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
|
|
71
73
|
we attempt to infer whether a variable is categorical by the fraction of unique
|
72
74
|
values in the dataset.
|
73
75
|
|
76
|
+
Example
|
77
|
+
-------
|
78
|
+
Return balance (mutual information) of factors with class_labels
|
79
|
+
|
80
|
+
>>> balance(class_labels, metadata).mutual_information[0]
|
81
|
+
array([0.99999822, 0.13363788, 0. , 0.02994455])
|
82
|
+
|
83
|
+
Return balance (mutual information) of metadata factors with class_labels
|
84
|
+
and each other
|
85
|
+
|
86
|
+
>>> balance(class_labels, metadata).mutual_information
|
87
|
+
array([[0.99999822, 0.13363788, 0. , 0.02994455],
|
88
|
+
[0.13363788, 0.99999843, 0.01389763, 0.09725766],
|
89
|
+
[0. , 0.01389763, 0.48549233, 0.15314612],
|
90
|
+
[0.02994455, 0.09725766, 0.15314612, 0.99999856]])
|
91
|
+
|
74
92
|
See Also
|
75
93
|
--------
|
76
94
|
sklearn.feature_selection.mutual_info_classif
|
@@ -96,14 +114,15 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
|
|
96
114
|
tgt,
|
97
115
|
discrete_features=is_categorical, # type: ignore
|
98
116
|
n_neighbors=num_neighbors,
|
117
|
+
random_state=0,
|
99
118
|
)
|
100
119
|
else:
|
101
|
-
# continuous variables
|
102
120
|
mi[idx, :] = mutual_info_regression(
|
103
121
|
data,
|
104
122
|
tgt,
|
105
123
|
discrete_features=is_categorical, # type: ignore
|
106
124
|
n_neighbors=num_neighbors,
|
125
|
+
random_state=0,
|
107
126
|
)
|
108
127
|
|
109
128
|
ent_all = entropy(data, names, is_categorical, normalized=False)
|
@@ -115,7 +134,7 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
|
|
115
134
|
|
116
135
|
|
117
136
|
@set_metadata("dataeval.metrics")
|
118
|
-
def balance_classwise(class_labels: Sequence[int], metadata:
|
137
|
+
def balance_classwise(class_labels: Sequence[int], metadata: list[dict], num_neighbors: int = 5) -> BalanceOutput:
|
119
138
|
"""
|
120
139
|
Compute mutual information (analogous to correlation) between metadata factors
|
121
140
|
(class label, metadata, label/image properties) with individual class labels.
|
@@ -143,6 +162,15 @@ def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_nei
|
|
143
162
|
(num_classes x num_factors) estimate of mutual information between
|
144
163
|
num_factors metadata factors and individual class labels.
|
145
164
|
|
165
|
+
Example
|
166
|
+
-------
|
167
|
+
Return classwise balance (mutual information) of factors with individual class_labels
|
168
|
+
|
169
|
+
>>> balance_classwise(class_labels, metadata).mutual_information
|
170
|
+
array([[0.13363788, 0.54085156, 0. ],
|
171
|
+
[0.13363788, 0.54085156, 0. ]])
|
172
|
+
|
173
|
+
|
146
174
|
See Also
|
147
175
|
--------
|
148
176
|
sklearn.feature_selection.mutual_info_classif
|
@@ -177,6 +205,7 @@ def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_nei
|
|
177
205
|
tgt,
|
178
206
|
discrete_features=cat_mask, # type: ignore
|
179
207
|
n_neighbors=num_neighbors,
|
208
|
+
random_state=0,
|
180
209
|
)
|
181
210
|
|
182
211
|
# let this recompute for all features including class label
|
@@ -7,8 +7,10 @@ Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4)
|
|
7
7
|
https://arxiv.org/abs/1811.06419
|
8
8
|
"""
|
9
9
|
|
10
|
+
from __future__ import annotations
|
11
|
+
|
10
12
|
from dataclasses import dataclass
|
11
|
-
from typing import Literal
|
13
|
+
from typing import Literal
|
12
14
|
|
13
15
|
import numpy as np
|
14
16
|
from numpy.typing import ArrayLike, NDArray
|
@@ -35,7 +37,7 @@ class BEROutput(OutputMetadata):
|
|
35
37
|
ber_lower: float
|
36
38
|
|
37
39
|
|
38
|
-
def ber_mst(X: NDArray, y: NDArray) ->
|
40
|
+
def ber_mst(X: NDArray, y: NDArray) -> tuple[float, float]:
|
39
41
|
"""Calculates the Bayes Error Rate using a minimum spanning tree
|
40
42
|
|
41
43
|
Parameters
|
@@ -60,7 +62,7 @@ def ber_mst(X: NDArray, y: NDArray) -> Tuple[float, float]:
|
|
60
62
|
return upper, lower
|
61
63
|
|
62
64
|
|
63
|
-
def ber_knn(X: NDArray, y: NDArray, k: int) ->
|
65
|
+
def ber_knn(X: NDArray, y: NDArray, k: int) -> tuple[float, float]:
|
64
66
|
"""Calculates the Bayes Error Rate using K-nearest neighbors
|
65
67
|
|
66
68
|
Parameters
|
@@ -135,7 +137,7 @@ def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN",
|
|
135
137
|
Examples
|
136
138
|
--------
|
137
139
|
>>> import sklearn.datasets as dsets
|
138
|
-
>>> from dataeval.metrics import ber
|
140
|
+
>>> from dataeval.metrics.estimators import ber
|
139
141
|
|
140
142
|
>>> images, labels = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)
|
141
143
|
|
@@ -1,5 +1,7 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from dataclasses import dataclass
|
2
|
-
from typing import
|
4
|
+
from typing import Literal, Sequence
|
3
5
|
|
4
6
|
import numpy as np
|
5
7
|
from numpy.typing import NDArray
|
@@ -22,9 +24,9 @@ class DiversityOutput(OutputMetadata):
|
|
22
24
|
|
23
25
|
def diversity_shannon(
|
24
26
|
data: NDArray,
|
25
|
-
names:
|
26
|
-
is_categorical:
|
27
|
-
subset_mask:
|
27
|
+
names: list[str],
|
28
|
+
is_categorical: list[bool],
|
29
|
+
subset_mask: NDArray[np.bool_] | None = None,
|
28
30
|
) -> NDArray:
|
29
31
|
"""
|
30
32
|
Compute diversity for discrete/categorical variables and, through standard
|
@@ -37,7 +39,7 @@ def diversity_shannon(
|
|
37
39
|
|
38
40
|
Parameters
|
39
41
|
----------
|
40
|
-
subset_mask:
|
42
|
+
subset_mask: NDArray[np.bool_] | None
|
41
43
|
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
42
44
|
|
43
45
|
Notes
|
@@ -58,14 +60,17 @@ def diversity_shannon(
|
|
58
60
|
ent_unnormalized = entropy(data, names, is_categorical, normalized=False, subset_mask=subset_mask)
|
59
61
|
# normalize by global counts rather than classwise counts
|
60
62
|
num_bins = get_num_bins(data, names, is_categorical=is_categorical, subset_mask=subset_mask)
|
61
|
-
|
63
|
+
ent_norm = np.empty(ent_unnormalized.shape)
|
64
|
+
ent_norm[num_bins != 1] = ent_unnormalized[num_bins != 1] / np.log(num_bins[num_bins != 1])
|
65
|
+
ent_norm[num_bins == 1] = 0
|
66
|
+
return ent_norm
|
62
67
|
|
63
68
|
|
64
69
|
def diversity_simpson(
|
65
70
|
data: NDArray,
|
66
|
-
names:
|
67
|
-
is_categorical:
|
68
|
-
subset_mask:
|
71
|
+
names: list[str],
|
72
|
+
is_categorical: list[bool],
|
73
|
+
subset_mask: NDArray[np.bool_] | None = None,
|
69
74
|
) -> NDArray:
|
70
75
|
"""
|
71
76
|
Compute diversity for discrete/categorical variables and, through standard
|
@@ -79,7 +84,7 @@ def diversity_simpson(
|
|
79
84
|
|
80
85
|
Parameters
|
81
86
|
----------
|
82
|
-
subset_mask:
|
87
|
+
subset_mask: NDArray[np.bool_] | None
|
83
88
|
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
84
89
|
|
85
90
|
Notes
|
@@ -121,7 +126,7 @@ DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
|
|
121
126
|
|
122
127
|
@set_metadata("dataeval.metrics")
|
123
128
|
def diversity(
|
124
|
-
class_labels: Sequence[int], metadata:
|
129
|
+
class_labels: Sequence[int], metadata: list[dict], method: Literal["shannon", "simpson"] = "simpson"
|
125
130
|
) -> DiversityOutput:
|
126
131
|
"""
|
127
132
|
Compute diversity for discrete/categorical variables and, through standard
|
@@ -149,6 +154,19 @@ def diversity(
|
|
149
154
|
DiversityOutput
|
150
155
|
Diversity index per column of self.data or each factor in self.names
|
151
156
|
|
157
|
+
Example
|
158
|
+
-------
|
159
|
+
Compute Simpson diversity index of metadata and class labels
|
160
|
+
|
161
|
+
>>> diversity(class_labels, metadata, method="simpson").diversity_index
|
162
|
+
array([0.34482759, 0.34482759, 0.90909091])
|
163
|
+
|
164
|
+
Compute Shannon diversity index of metadata and class labels
|
165
|
+
|
166
|
+
>>> diversity(class_labels, metadata, method="shannon").diversity_index
|
167
|
+
array([0.37955133, 0.37955133, 0.96748876])
|
168
|
+
|
169
|
+
|
152
170
|
See Also
|
153
171
|
--------
|
154
172
|
numpy.histogram
|
@@ -161,7 +179,7 @@ def diversity(
|
|
161
179
|
|
162
180
|
@set_metadata("dataeval.metrics")
|
163
181
|
def diversity_classwise(
|
164
|
-
class_labels: Sequence[int], metadata:
|
182
|
+
class_labels: Sequence[int], metadata: list[dict], method: Literal["shannon", "simpson"] = "simpson"
|
165
183
|
) -> DiversityOutput:
|
166
184
|
"""
|
167
185
|
Compute diversity for discrete/categorical variables and, through standard
|
@@ -191,6 +209,21 @@ def diversity_classwise(
|
|
191
209
|
DiversityOutput
|
192
210
|
Diversity index [n_class x n_factor]
|
193
211
|
|
212
|
+
Example
|
213
|
+
-------
|
214
|
+
Compute classwise Simpson diversity index of metadata and class labels
|
215
|
+
|
216
|
+
>>> diversity_classwise(class_labels, metadata, method="simpson").diversity_index
|
217
|
+
array([[0.33793103, 0.51578947],
|
218
|
+
[0.36 , 0.36 ]])
|
219
|
+
|
220
|
+
Compute classwise Shannon diversity index of metadata and class labels
|
221
|
+
|
222
|
+
>>> diversity_classwise(class_labels, metadata, method="shannon").diversity_index
|
223
|
+
array([[0.43156028, 0.83224889],
|
224
|
+
[0.57938016, 0.57938016]])
|
225
|
+
|
226
|
+
|
194
227
|
See Also
|
195
228
|
--------
|
196
229
|
numpy.histogram
|