dataeval 0.65.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (61)
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +24 -22
  3. dataeval/_internal/detectors/drift/base.py +206 -26
  4. dataeval/_internal/detectors/drift/cvm.py +25 -23
  5. dataeval/_internal/detectors/drift/ks.py +28 -25
  6. dataeval/_internal/detectors/drift/mmd.py +30 -29
  7. dataeval/_internal/detectors/drift/torch.py +66 -58
  8. dataeval/_internal/detectors/drift/uncertainty.py +28 -28
  9. dataeval/_internal/detectors/duplicates.py +28 -18
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +61 -43
  13. dataeval/_internal/detectors/ood/llr.py +27 -24
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +33 -27
  17. dataeval/_internal/flags.py +5 -3
  18. dataeval/_internal/interop.py +4 -2
  19. dataeval/_internal/metrics/balance.py +33 -4
  20. dataeval/_internal/metrics/ber.py +6 -4
  21. dataeval/_internal/metrics/diversity.py +45 -12
  22. dataeval/_internal/metrics/parity.py +114 -26
  23. dataeval/_internal/metrics/stats.py +154 -16
  24. dataeval/_internal/metrics/uap.py +28 -2
  25. dataeval/_internal/metrics/utils.py +20 -18
  26. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  27. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  28. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  29. dataeval/_internal/models/tensorflow/losses.py +15 -11
  30. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  31. dataeval/_internal/models/tensorflow/trainer.py +8 -6
  32. dataeval/_internal/models/tensorflow/utils.py +21 -19
  33. dataeval/_internal/output.py +13 -10
  34. dataeval/_internal/utils.py +5 -3
  35. dataeval/_internal/workflows/sufficiency.py +42 -30
  36. dataeval/detectors/__init__.py +6 -25
  37. dataeval/detectors/drift/__init__.py +16 -0
  38. dataeval/detectors/drift/kernels/__init__.py +6 -0
  39. dataeval/detectors/drift/updates/__init__.py +3 -0
  40. dataeval/detectors/linters/__init__.py +5 -0
  41. dataeval/detectors/ood/__init__.py +11 -0
  42. dataeval/metrics/__init__.py +2 -26
  43. dataeval/metrics/bias/__init__.py +14 -0
  44. dataeval/metrics/estimators/__init__.py +9 -0
  45. dataeval/metrics/stats/__init__.py +6 -0
  46. dataeval/tensorflow/__init__.py +3 -0
  47. dataeval/tensorflow/loss/__init__.py +3 -0
  48. dataeval/tensorflow/models/__init__.py +5 -0
  49. dataeval/tensorflow/recon/__init__.py +3 -0
  50. dataeval/torch/__init__.py +3 -0
  51. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  52. dataeval/torch/trainer/__init__.py +3 -0
  53. dataeval/utils/__init__.py +3 -6
  54. dataeval/workflows/__init__.py +2 -4
  55. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  56. dataeval-0.66.0.dist-info/RECORD +72 -0
  57. dataeval/models/__init__.py +0 -15
  58. dataeval/models/tensorflow/__init__.py +0 -6
  59. dataeval-0.65.0.dist-info/RECORD +0 -60
  60. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  61. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
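
The renames above reorganize the public API into task-oriented subpackages (detectors/drift, detectors/linters, detectors/ood, metrics/bias, metrics/estimators, metrics/stats) and move framework-specific code under dataeval.tensorflow and dataeval.torch. A rough migration sketch follows; only the ber move is confirmed by a docstring change later in this diff, the other paths are inferred from the new __init__.py files listed above and should be checked against the 0.66.0 package:

    # Hypothetical before/after import sketch for 0.65.0 -> 0.66.0.

    # 0.65.0
    # from dataeval.metrics import ber, imagestats

    # 0.66.0
    from dataeval.metrics.estimators import ber       # confirmed by the ber docstring change below
    from dataeval.metrics.stats import imagestats     # inferred from dataeval/metrics/stats/__init__.py
    from dataeval.detectors.linters import Outliers   # inferred from dataeval/detectors/linters/__init__.py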
dataeval/_internal/detectors/ood/vaegmm.py
@@ -6,10 +6,13 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """
 
+from __future__ import annotations
+
 from typing import Callable
 
 import keras
 import numpy as np
+import tensorflow as tf
 from numpy.typing import ArrayLike
 
 from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
@@ -21,17 +24,18 @@ from dataeval._internal.models.tensorflow.utils import predict_batch
 
 
 class OOD_VAEGMM(OODGMMBase):
-    def __init__(self, model: VAEGMM, samples: int = 10) -> None:
-        """
-        VAE with Gaussian Mixture Model based outlier detector.
+    """
+    VAE with Gaussian Mixture Model based outlier detector.
 
-        Parameters
-        ----------
-        model : VAEGMM
-            A VAEGMM model.
-        samples
-            Number of samples sampled to evaluate each instance.
-        """
+    Parameters
+    ----------
+    model : VAEGMM
+        A VAEGMM model.
+    samples
+        Number of samples sampled to evaluate each instance.
+    """
+
+    def __init__(self, model: VAEGMM, samples: int = 10) -> None:
         super().__init__(model)
         self.samples = samples
 
@@ -39,35 +43,37 @@ class OOD_VAEGMM(OODGMMBase):
         self,
         x_ref: ArrayLike,
         threshold_perc: float = 100.0,
-        loss_fn: Callable = LossGMM(elbo=Elbo(0.05)),
+        loss_fn: Callable[..., tf.Tensor] | None = None,
         optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
         epochs: int = 20,
         batch_size: int = 64,
         verbose: bool = True,
     ) -> None:
+        if loss_fn is None:
+            loss_fn = LossGMM(elbo=Elbo(0.05))
+        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
+
+    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
         """
-        Train the AE model with recommended loss function and optimizer.
+        Compute the out-of-distribution (OOD) score for a given dataset.
 
         Parameters
        ----------
         X : ArrayLike
-            Training batch.
-        threshold_perc : float, default 100.0
-            Percentage of reference data that is normal.
-        loss_fn : Callable, default LossGMM(elbo=Elbo(0.05))
-            Loss function used for training.
-        optimizer : keras.optimizers.Optimizer, default keras.optimizers.Adam
-            Optimizer used for training.
-        epochs : int, default 20
-            Number of training epochs.
-        batch_size : int, default 64
-            Batch size used for training.
-        verbose : bool, default True
-            Whether to print training progress.
-        """
-        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
+            Input data to score.
+        batch_size : int, default 1e10
+            Number of instances to process in each batch.
+            Use a smaller batch size if your dataset is large or if you encounter memory issues.
 
-    def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScore:
+        Returns
+        -------
+        OODScore
+            An object containing the instance-level OOD score.
+
+        Note
+        ----
+        This model does not produce a feature level score like the OOD_AE or OOD_VAE models.
+        """
         self._validate(X := to_numpy(X))
 
         # draw samples from latent space
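
The mutable default loss_fn=LossGMM(elbo=Elbo(0.05)) is replaced by a None sentinel that fit() resolves internally, and score() now documents that it returns only an instance-level score. A minimal usage sketch, assuming a VAEGMM model has already been constructed (model construction is not shown in this diff):

    import numpy as np

    from dataeval._internal.detectors.ood.vaegmm import OOD_VAEGMM

    x_ref = np.random.rand(128, 32, 32, 1).astype(np.float32)  # toy reference data

    detector = OOD_VAEGMM(model=vaegmm_model, samples=10)  # vaegmm_model: a prebuilt VAEGMM

    # Passing no loss_fn now means "use LossGMM(elbo=Elbo(0.05))", resolved
    # inside fit() instead of being evaluated once at import time.
    detector.fit(x_ref, threshold_perc=99.0, epochs=5, batch_size=64)

    scores = detector.score(x_ref, batch_size=64)  # instance-level OODScore only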
dataeval/_internal/detectors/{linter.py → outliers.py}
@@ -1,17 +1,18 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Dict, Iterable, Literal, Optional
+from typing import Iterable, Literal
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
-from dataeval._internal.flags import verify_supported
+from dataeval._internal.flags import ImageStat, to_distinct, verify_supported
+from dataeval._internal.metrics.stats import StatsOutput, imagestats
 from dataeval._internal.output import OutputMetadata, set_metadata
-from dataeval.flags import ImageStat
-from dataeval.metrics import imagestats
 
 
 @dataclass(frozen=True)
-class LinterOutput(OutputMetadata):
+class OutliersOutput(OutputMetadata):
     """
     Attributes
     ----------
@@ -20,11 +21,11 @@ class LinterOutput(OutputMetadata):
         the issues and calculated values for the given index.
     """
 
-    issues: Dict[int, Dict[str, float]]
+    issues: dict[int, dict[str, float]]
 
 
 def _get_outlier_mask(
-    values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: Optional[float]
+    values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: float | None
 ) -> NDArray:
     if method == "zscore":
         threshold = threshold if threshold else 3.0
@@ -46,7 +47,7 @@ def _get_outlier_mask(
         raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
 
 
-class Linter:
+class Outliers:
     r"""
     Calculates statistical outliers of a dataset using various statistical tests applied to each image
 
@@ -92,28 +93,28 @@
 
     Examples
     --------
-    Initialize the Linter class:
+    Initialize the Outliers class:
 
-    >>> lint = Linter()
+    >>> outliers = Outliers()
 
     Specifying specific metrics to analyze:
 
-    >>> lint = Linter(flags=ImageStat.SIZE | ImageStat.ALL_VISUALS)
+    >>> outliers = Outliers(flags=ImageStat.SIZE | ImageStat.ALL_VISUALS)
 
     Specifying an outlier method:
 
-    >>> lint = Linter(outlier_method="iqr")
+    >>> outliers = Outliers(outlier_method="iqr")
 
     Specifying an outlier method and threshold:
 
-    >>> lint = Linter(outlier_method="zscore", outlier_threshold=2.5)
+    >>> outliers = Outliers(outlier_method="zscore", outlier_threshold=2.5)
     """
 
     def __init__(
         self,
         flags: ImageStat = ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS,
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
-        outlier_threshold: Optional[float] = None,
+        outlier_threshold: float | None = None,
     ):
         verify_supported(flags, ImageStat.ALL_STATS)
         self.flags = flags
@@ -123,11 +124,9 @@
     def _get_outliers(self) -> dict:
        flagged_images = {}
         stats_dict = self.stats.dict()
+        supported = to_distinct(ImageStat.ALL_STATS)
         for stat, values in stats_dict.items():
-            if not isinstance(values, np.ndarray):
-                continue
-
-            if values.ndim == 1 and np.std(values) != 0:
+            if stat in supported.values() and values.ndim == 1 and np.std(values) != 0:
                 mask = _get_outlier_mask(values, self.outlier_method, self.outlier_threshold)
                 indices = np.flatnonzero(mask)
                 for i, value in zip(indices, values[mask]):
@@ -136,19 +135,18 @@
         return dict(sorted(flagged_images.items()))
 
     @set_metadata("dataeval.detectors", ["flags", "outlier_method", "outlier_threshold"])
-    def evaluate(self, images: Iterable[ArrayLike]) -> LinterOutput:
+    def evaluate(self, data: Iterable[ArrayLike] | StatsOutput) -> OutliersOutput:
         """
         Returns indices of outliers with the issues identified for each
 
         Parameters
         ----------
-        images : Iterable[ArrayLike], shape - (N, C, H, W)
-            A dataset in an ArrayLike format.
-            Function expects the data to have 3 dimensions, CxHxW.
+        data : Iterable[ArrayLike], shape - (C, H, W) | StatsOutput
+            A dataset of images in an ArrayLike format or the output from an imagestats metric analysis
 
         Returns
         -------
-        LinterOutput
+        OutliersOutput
             Output class containing the indices of outliers and a dictionary showing
             the issues and calculated values for the given index.
 
@@ -156,8 +154,16 @@
        -------
         Evaluate the dataset:
 
-        >>> lint.evaluate(images)
-        LinterOutput(issues={18: {'brightness': 0.78}, 25: {'brightness': 0.98}})
+        >>> outliers.evaluate(images)
+        OutliersOutput(issues={18: {'brightness': 0.78}, 25: {'brightness': 0.98}})
         """
-        self.stats = imagestats(images, self.flags)
-        return LinterOutput(self._get_outliers())
+        if isinstance(data, StatsOutput):
+            flags = set(to_distinct(self.flags).values())
+            stats = set(data.dict())
+            missing = flags - stats
+            if missing:
+                raise ValueError(f"StatsOutput is missing {missing} from the required stats: {flags}.")
+            self.stats = data
+        else:
+            self.stats = imagestats(data, self.flags)
+        return OutliersOutput(self._get_outliers())
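
Besides the Linter → Outliers rename, evaluate() now accepts either raw images or a precomputed StatsOutput, validating that the required statistics are present. A short sketch of both call forms, using the internal module paths shown in this hunk (the public re-exports may differ):

    import numpy as np

    from dataeval._internal.detectors.outliers import Outliers
    from dataeval._internal.flags import ImageStat
    from dataeval._internal.metrics.stats import imagestats

    images = np.random.rand(100, 3, 28, 28)  # toy dataset, one (C, H, W) image per row

    outliers = Outliers(flags=ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS)

    # Form 1: raw images; stats are computed internally via imagestats()
    result = outliers.evaluate(images)

    # Form 2: reuse a precomputed StatsOutput; evaluate() raises ValueError
    # if the StatsOutput lacks any stat required by the configured flags
    stats = imagestats(images, ImageStat.ALL_PROPERTIES | ImageStat.ALL_VISUALS)
    result = outliers.evaluate(stats)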
dataeval/_internal/flags.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from enum import IntFlag, auto
 from functools import reduce
-from typing import Dict, Iterable, TypeVar, Union, cast
+from typing import Iterable, TypeVar, cast
 
 TFlag = TypeVar("TFlag", bound=IntFlag)
 
@@ -47,7 +49,7 @@ def is_distinct(flag: IntFlag) -> bool:
     return (flag & (flag - 1) == 0) and flag != 0
 
 
-def to_distinct(flag: TFlag) -> Dict[TFlag, str]:
+def to_distinct(flag: TFlag) -> dict[TFlag, str]:
     """
     Returns a distinct set of all flags set on the input flag and their names
 
@@ -61,7 +63,7 @@ def to_distinct(flag: TFlag) -> Dict[TFlag, str]:
     return {f: f.name.lower() for f in list(flag.__class__) if f & flag and is_distinct(f) and f.name}
 
 
-def verify_supported(flag: TFlag, flags: Union[TFlag, Iterable[TFlag]]):
+def verify_supported(flag: TFlag, flags: TFlag | Iterable[TFlag]):
     supported = flags if isinstance(flags, flag.__class__) else cast(TFlag, reduce(lambda a, b: a | b, flags))  # type: ignore
     unsupported = flag & ~supported
     if unsupported:
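
The helpers above are easiest to see on a toy flag. A sketch with a hypothetical Color enum; only the helper functions come from flags.py:

    from enum import IntFlag, auto

    from dataeval._internal.flags import is_distinct, to_distinct, verify_supported

    class Color(IntFlag):  # hypothetical enum for illustration
        RED = auto()
        GREEN = auto()
        BLUE = auto()
        ALL = RED | GREEN | BLUE

    assert is_distinct(Color.RED)      # exactly one bit set
    assert not is_distinct(Color.ALL)  # composite flag

    # Distinct members set on a composite flag, mapped to lowercase names
    assert to_distinct(Color.RED | Color.BLUE) == {Color.RED: "red", Color.BLUE: "blue"}

    verify_supported(Color.RED, Color.ALL)  # passes: RED is within ALL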
dataeval/_internal/interop.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from importlib import import_module
-from typing import Iterable, Optional
+from typing import Iterable
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
@@ -20,7 +22,7 @@ def try_import(module_name):
     return module
 
 
-def to_numpy(array: Optional[ArrayLike]) -> NDArray:
+def to_numpy(array: ArrayLike | None) -> NDArray:
     if array is None:
         return np.ndarray([])
 
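Only the None branch of to_numpy is visible in this hunk; its behavior is unchanged apart from the PEP 604 annotation. A tiny behavior sketch:

    from dataeval._internal.interop import to_numpy

    # None maps to an empty 0-d ndarray rather than raising
    assert to_numpy(None).shape == ()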
dataeval/_internal/metrics/balance.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Sequence
+from typing import Sequence
 
 import numpy as np
 from numpy.typing import NDArray
@@ -43,7 +45,7 @@ def validate_num_neighbors(num_neighbors: int) -> int:
 
 
 @set_metadata("dataeval.metrics")
-def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
+def balance(class_labels: Sequence[int], metadata: list[dict], num_neighbors: int = 5) -> BalanceOutput:
     """
     Mutual information (MI) between factors (class label, metadata, label/image properties)
 
@@ -71,6 +73,22 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
     we attempt to infer whether a variable is categorical by the fraction of unique
     values in the dataset.
 
+    Example
+    -------
+    Return balance (mutual information) of factors with class_labels
+
+    >>> balance(class_labels, metadata).mutual_information[0]
+    array([0.99999822, 0.13363788, 0.        , 0.02994455])
+
+    Return balance (mutual information) of metadata factors with class_labels
+    and each other
+
+    >>> balance(class_labels, metadata).mutual_information
+    array([[0.99999822, 0.13363788, 0.        , 0.02994455],
+           [0.13363788, 0.99999843, 0.01389763, 0.09725766],
+           [0.        , 0.01389763, 0.48549233, 0.15314612],
+           [0.02994455, 0.09725766, 0.15314612, 0.99999856]])
+
     See Also
     --------
     sklearn.feature_selection.mutual_info_classif
@@ -96,14 +114,15 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
                 tgt,
                 discrete_features=is_categorical,  # type: ignore
                 n_neighbors=num_neighbors,
+                random_state=0,
             )
         else:
-            # continuous variables
             mi[idx, :] = mutual_info_regression(
                 data,
                 tgt,
                 discrete_features=is_categorical,  # type: ignore
                 n_neighbors=num_neighbors,
+                random_state=0,
             )
 
     ent_all = entropy(data, names, is_categorical, normalized=False)
@@ -115,7 +134,7 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
 
 
 @set_metadata("dataeval.metrics")
-def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
+def balance_classwise(class_labels: Sequence[int], metadata: list[dict], num_neighbors: int = 5) -> BalanceOutput:
     """
     Compute mutual information (analogous to correlation) between metadata factors
     (class label, metadata, label/image properties) with individual class labels.
@@ -143,6 +162,15 @@ def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_nei
     (num_classes x num_factors) estimate of mutual information between
     num_factors metadata factors and individual class labels.
 
+    Example
+    -------
+    Return classwise balance (mutual information) of factors with individual class_labels
+
+    >>> balance_classwise(class_labels, metadata).mutual_information
+    array([[0.13363788, 0.54085156, 0.        ],
+           [0.13363788, 0.54085156, 0.        ]])
+
+
     See Also
     --------
     sklearn.feature_selection.mutual_info_classif
@@ -177,6 +205,7 @@ def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_nei
         tgt,
         discrete_features=cat_mask,  # type: ignore
        n_neighbors=num_neighbors,
+        random_state=0,
    )
 
     # let this recompute for all features including class label
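
Both balance functions now pin random_state=0 on the scikit-learn mutual-information estimators, so repeated calls on identical inputs return identical results. A sketch of the expected call pattern, with made-up factor names and values; the layout of one factor dict per sample is an assumption consistent with the list[dict] annotation:

    from dataeval._internal.metrics.balance import balance

    class_labels = [0, 1, 0, 1, 1, 0, 1, 0]
    metadata = [  # one dict of factors per sample (illustrative values)
        {"lighting": "indoor", "altitude": 100.0},
        {"lighting": "outdoor", "altitude": 140.0},
        {"lighting": "indoor", "altitude": 120.0},
        {"lighting": "outdoor", "altitude": 160.0},
        {"lighting": "outdoor", "altitude": 150.0},
        {"lighting": "indoor", "altitude": 110.0},
        {"lighting": "outdoor", "altitude": 155.0},
        {"lighting": "indoor", "altitude": 105.0},
    ]

    out = balance(class_labels, metadata, num_neighbors=5)
    mi = out.mutual_information  # per the docstring example, row 0 relates factors to the class label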
dataeval/_internal/metrics/ber.py
@@ -7,8 +7,10 @@ Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4)
 https://arxiv.org/abs/1811.06419
 """
 
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Literal, Tuple
+from typing import Literal
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
@@ -35,7 +37,7 @@ class BEROutput(OutputMetadata):
     ber_lower: float
 
 
-def ber_mst(X: NDArray, y: NDArray) -> Tuple[float, float]:
+def ber_mst(X: NDArray, y: NDArray) -> tuple[float, float]:
     """Calculates the Bayes Error Rate using a minimum spanning tree
 
     Parameters
@@ -60,7 +62,7 @@ def ber_mst(X: NDArray, y: NDArray) -> Tuple[float, float]:
     return upper, lower
 
 
-def ber_knn(X: NDArray, y: NDArray, k: int) -> Tuple[float, float]:
+def ber_knn(X: NDArray, y: NDArray, k: int) -> tuple[float, float]:
     """Calculates the Bayes Error Rate using K-nearest neighbors
 
     Parameters
@@ -135,7 +137,7 @@ def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN",
     Examples
     --------
     >>> import sklearn.datasets as dsets
-    >>> from dataeval.metrics import ber
+    >>> from dataeval.metrics.estimators import ber
 
     >>> images, labels = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)
 
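The only functional change here is the documented import path; the Tuple → tuple moves are enabled by the new from __future__ import annotations. A sketch extending the docstring example with an actual call, assuming BEROutput exposes a ber field alongside the ber_lower shown above:

    import sklearn.datasets as dsets

    from dataeval.metrics.estimators import ber  # new import path per this diff

    images, labels = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)

    result = ber(images, labels, k=1, method="KNN")
    # ber_lower appears in the BEROutput hunk above; ber (the upper bound) is assumed
    print(result.ber, result.ber_lower)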
dataeval/_internal/metrics/diversity.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Dict, List, Literal, Optional, Sequence
+from typing import Literal, Sequence
 
 import numpy as np
 from numpy.typing import NDArray
@@ -22,9 +24,9 @@ class DiversityOutput(OutputMetadata):
 
 def diversity_shannon(
     data: NDArray,
-    names: List[str],
-    is_categorical: List[bool],
-    subset_mask: Optional[NDArray[np.bool_]] = None,
+    names: list[str],
+    is_categorical: list[bool],
+    subset_mask: NDArray[np.bool_] | None = None,
 ) -> NDArray:
     """
     Compute diversity for discrete/categorical variables and, through standard
@@ -37,7 +39,7 @@ def diversity_shannon(
 
     Parameters
     ----------
-    subset_mask: Optional[NDArray[np.bool_]]
+    subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
     Notes
@@ -58,14 +60,17 @@ def diversity_shannon(
     ent_unnormalized = entropy(data, names, is_categorical, normalized=False, subset_mask=subset_mask)
     # normalize by global counts rather than classwise counts
     num_bins = get_num_bins(data, names, is_categorical=is_categorical, subset_mask=subset_mask)
-    return ent_unnormalized / np.log(num_bins)
+    ent_norm = np.empty(ent_unnormalized.shape)
+    ent_norm[num_bins != 1] = ent_unnormalized[num_bins != 1] / np.log(num_bins[num_bins != 1])
+    ent_norm[num_bins == 1] = 0
+    return ent_norm
 
 
 def diversity_simpson(
     data: NDArray,
-    names: List[str],
-    is_categorical: List[bool],
-    subset_mask: Optional[NDArray[np.bool_]] = None,
+    names: list[str],
+    is_categorical: list[bool],
+    subset_mask: NDArray[np.bool_] | None = None,
 ) -> NDArray:
     """
     Compute diversity for discrete/categorical variables and, through standard
@@ -79,7 +84,7 @@ def diversity_simpson(
 
     Parameters
     ----------
-    subset_mask: Optional[NDArray[np.bool_]]
+    subset_mask: NDArray[np.bool_] | None
         Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
 
     Notes
@@ -121,7 +126,7 @@ DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
 
 @set_metadata("dataeval.metrics")
 def diversity(
-    class_labels: Sequence[int], metadata: List[Dict], method: Literal["shannon", "simpson"] = "simpson"
+    class_labels: Sequence[int], metadata: list[dict], method: Literal["shannon", "simpson"] = "simpson"
 ) -> DiversityOutput:
     """
     Compute diversity for discrete/categorical variables and, through standard
@@ -149,6 +154,19 @@ def diversity(
     DiversityOutput
         Diversity index per column of self.data or each factor in self.names
 
+    Example
+    -------
+    Compute Simpson diversity index of metadata and class labels
+
+    >>> diversity(class_labels, metadata, method="simpson").diversity_index
+    array([0.34482759, 0.34482759, 0.90909091])
+
+    Compute Shannon diversity index of metadata and class labels
+
+    >>> diversity(class_labels, metadata, method="shannon").diversity_index
+    array([0.37955133, 0.37955133, 0.96748876])
+
+
     See Also
     --------
     numpy.histogram
@@ -161,7 +179,7 @@ def diversity(
 
 @set_metadata("dataeval.metrics")
 def diversity_classwise(
-    class_labels: Sequence[int], metadata: List[Dict], method: Literal["shannon", "simpson"] = "simpson"
+    class_labels: Sequence[int], metadata: list[dict], method: Literal["shannon", "simpson"] = "simpson"
 ) -> DiversityOutput:
     """
     Compute diversity for discrete/categorical variables and, through standard
@@ -191,6 +209,21 @@ def diversity_classwise(
     DiversityOutput
         Diversity index [n_class x n_factor]
 
+    Example
+    -------
+    Compute classwise Simpson diversity index of metadata and class labels
+
+    >>> diversity_classwise(class_labels, metadata, method="simpson").diversity_index
+    array([[0.33793103, 0.51578947],
+           [0.36      , 0.36      ]])
+
+    Compute classwise Shannon diversity index of metadata and class labels
+
+    >>> diversity_classwise(class_labels, metadata, method="shannon").diversity_index
+    array([[0.43156028, 0.83224889],
+           [0.57938016, 0.57938016]])
+
+
     See Also
     --------
     numpy.histogram
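
The diversity_shannon change guards the normalization against factors that produce a single bin, where dividing by np.log(1) == 0 previously yielded nan or inf. A standalone sketch of the new guard with made-up entropy values:

    import numpy as np

    # Made-up unnormalized entropies and bin counts; the first factor is
    # constant, so it has exactly one bin
    ent_unnormalized = np.array([0.0, 0.69, 1.10])
    num_bins = np.array([1, 2, 3])

    ent_norm = np.empty(ent_unnormalized.shape)
    mask = num_bins != 1
    ent_norm[mask] = ent_unnormalized[mask] / np.log(num_bins[mask])
    ent_norm[~mask] = 0  # single-bin factors get 0 instead of nan/inf
    print(ent_norm)  # approximately [0., 0.9955, 1.0013]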