dataeval 0.64.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +63 -49
  3. dataeval/_internal/detectors/drift/base.py +248 -51
  4. dataeval/_internal/detectors/drift/cvm.py +28 -26
  5. dataeval/_internal/detectors/drift/ks.py +31 -28
  6. dataeval/_internal/detectors/drift/mmd.py +62 -42
  7. dataeval/_internal/detectors/drift/torch.py +69 -60
  8. dataeval/_internal/detectors/drift/uncertainty.py +32 -32
  9. dataeval/_internal/detectors/duplicates.py +67 -31
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +86 -47
  13. dataeval/_internal/detectors/ood/llr.py +34 -31
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +60 -38
  17. dataeval/_internal/flags.py +44 -21
  18. dataeval/_internal/interop.py +5 -3
  19. dataeval/_internal/metrics/balance.py +42 -5
  20. dataeval/_internal/metrics/ber.py +11 -8
  21. dataeval/_internal/metrics/coverage.py +15 -8
  22. dataeval/_internal/metrics/divergence.py +41 -7
  23. dataeval/_internal/metrics/diversity.py +57 -19
  24. dataeval/_internal/metrics/parity.py +141 -66
  25. dataeval/_internal/metrics/stats.py +330 -313
  26. dataeval/_internal/metrics/uap.py +33 -4
  27. dataeval/_internal/metrics/utils.py +79 -40
  28. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  29. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  30. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  31. dataeval/_internal/models/tensorflow/losses.py +17 -13
  32. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  33. dataeval/_internal/models/tensorflow/trainer.py +10 -7
  34. dataeval/_internal/models/tensorflow/utils.py +23 -20
  35. dataeval/_internal/output.py +85 -0
  36. dataeval/_internal/utils.py +5 -3
  37. dataeval/_internal/workflows/sufficiency.py +122 -121
  38. dataeval/detectors/__init__.py +6 -25
  39. dataeval/detectors/drift/__init__.py +16 -0
  40. dataeval/detectors/drift/kernels/__init__.py +6 -0
  41. dataeval/detectors/drift/updates/__init__.py +3 -0
  42. dataeval/detectors/linters/__init__.py +5 -0
  43. dataeval/detectors/ood/__init__.py +11 -0
  44. dataeval/flags/__init__.py +2 -2
  45. dataeval/metrics/__init__.py +2 -26
  46. dataeval/metrics/bias/__init__.py +14 -0
  47. dataeval/metrics/estimators/__init__.py +9 -0
  48. dataeval/metrics/stats/__init__.py +6 -0
  49. dataeval/tensorflow/__init__.py +3 -0
  50. dataeval/tensorflow/loss/__init__.py +3 -0
  51. dataeval/tensorflow/models/__init__.py +5 -0
  52. dataeval/tensorflow/recon/__init__.py +3 -0
  53. dataeval/torch/__init__.py +3 -0
  54. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  55. dataeval/torch/trainer/__init__.py +3 -0
  56. dataeval/utils/__init__.py +3 -6
  57. dataeval/workflows/__init__.py +2 -4
  58. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  59. dataeval-0.66.0.dist-info/RECORD +72 -0
  60. dataeval/_internal/metrics/base.py +0 -10
  61. dataeval/models/__init__.py +0 -15
  62. dataeval/models/tensorflow/__init__.py +0 -6
  63. dataeval-0.64.0.dist-info/RECORD +0 -60
  64. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  65. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
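The wheel layout is reorganized in 0.66.0: detectors gain drift, linters, and ood subpackages (with _internal/detectors/linter.py renamed to outliers.py), metrics split into bias, estimators, and stats, framework-specific models move under dataeval.tensorflow and dataeval.torch, and a new dataeval/_internal/output.py module backs the output classes used throughout. A minimal sketch of the import changes this implies; only the ber move is confirmed by the doctest update in ber.py below, the commented paths are inferred from the added __init__.py files and may not match the actual re-exports:

```python
# Import-path changes implied by the 0.66.0 layout. Only the `ber` move is
# confirmed by the updated doctest in ber.py; the commented lines are inferred
# from the new __init__.py files and are assumptions about what they re-export.

# 0.64.0
# from dataeval.metrics import ber

# 0.66.0
from dataeval.metrics.estimators import ber             # confirmed by the ber.py doctest change
# from dataeval.metrics.bias import balance, diversity  # inferred from dataeval/metrics/bias/__init__.py
# from dataeval.detectors.linters import ...            # inferred from detectors/linters/__init__.py; exports not shown
```

The hunks reproduced below cover the metric modules: balance.py, ber.py, coverage.py, divergence.py, and diversity.py.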
dataeval/_internal/metrics/balance.py

@@ -1,14 +1,19 @@
+ from __future__ import annotations
+
  import warnings
- from typing import Dict, List, NamedTuple, Sequence
+ from dataclasses import dataclass
+ from typing import Sequence

  import numpy as np
  from numpy.typing import NDArray
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression

  from dataeval._internal.metrics.utils import entropy, preprocess_metadata
+ from dataeval._internal.output import OutputMetadata, set_metadata


- class BalanceOutput(NamedTuple):
+ @dataclass(frozen=True)
+ class BalanceOutput(OutputMetadata):
  """
  Attributes
  ----------
@@ -39,7 +44,8 @@ def validate_num_neighbors(num_neighbors: int) -> int:
  return num_neighbors


- def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
+ @set_metadata("dataeval.metrics")
+ def balance(class_labels: Sequence[int], metadata: list[dict], num_neighbors: int = 5) -> BalanceOutput:
  """
  Mutual information (MI) between factors (class label, metadata, label/image properties)

@@ -67,6 +73,22 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
  we attempt to infer whether a variable is categorical by the fraction of unique
  values in the dataset.

+ Example
+ -------
+ Return balance (mutual information) of factors with class_labels
+
+ >>> balance(class_labels, metadata).mutual_information[0]
+ array([0.99999822, 0.13363788, 0. , 0.02994455])
+
+ Return balance (mutual information) of metadata factors with class_labels
+ and each other
+
+ >>> balance(class_labels, metadata).mutual_information
+ array([[0.99999822, 0.13363788, 0. , 0.02994455],
+ [0.13363788, 0.99999843, 0.01389763, 0.09725766],
+ [0. , 0.01389763, 0.48549233, 0.15314612],
+ [0.02994455, 0.09725766, 0.15314612, 0.99999856]])
+
  See Also
  --------
  sklearn.feature_selection.mutual_info_classif
@@ -83,20 +105,24 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
  tgt = data[:, idx]

  if is_categorical[idx]:
+ if tgt.dtype == float:
+ # map to unique integers if categorical
+ _, tgt = np.unique(tgt, return_inverse=True)
  # categorical target
  mi[idx, :] = mutual_info_classif(
  data,
  tgt,
  discrete_features=is_categorical, # type: ignore
  n_neighbors=num_neighbors,
+ random_state=0,
  )
  else:
- # continuous variables
  mi[idx, :] = mutual_info_regression(
  data,
  tgt,
  discrete_features=is_categorical, # type: ignore
  n_neighbors=num_neighbors,
+ random_state=0,
  )

  ent_all = entropy(data, names, is_categorical, normalized=False)
@@ -107,7 +133,8 @@ def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: in
  return BalanceOutput(nmi)


- def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
+ @set_metadata("dataeval.metrics")
+ def balance_classwise(class_labels: Sequence[int], metadata: list[dict], num_neighbors: int = 5) -> BalanceOutput:
  """
  Compute mutual information (analogous to correlation) between metadata factors
  (class label, metadata, label/image properties) with individual class labels.
@@ -135,6 +162,15 @@ def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_nei
  (num_classes x num_factors) estimate of mutual information between
  num_factors metadata factors and individual class labels.

+ Example
+ -------
+ Return classwise balance (mutual information) of factors with individual class_labels
+
+ >>> balance_classwise(class_labels, metadata).mutual_information
+ array([[0.13363788, 0.54085156, 0. ],
+ [0.13363788, 0.54085156, 0. ]])
+
+
  See Also
  --------
  sklearn.feature_selection.mutual_info_classif
@@ -169,6 +205,7 @@ def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_nei
  tgt,
  discrete_features=cat_mask, # type: ignore
  n_neighbors=num_neighbors,
+ random_state=0,
  )

  # let this recompute for all features including class label
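The balance.py hunks establish the pattern repeated in every metric module below: the NamedTuple output becomes a frozen dataclass inheriting from the new OutputMetadata base (defined in the added dataeval/_internal/output.py, whose body is not part of this diff), the public functions gain a @set_metadata("dataeval.metrics") decorator, and the mutual-information calls are pinned with random_state=0 so repeated runs return identical matrices. A standalone sketch of what the dataclass change means for callers, using a stand-in class rather than the real BalanceOutput:

```python
# Stand-in illustration of the NamedTuple -> frozen dataclass change; the real
# BalanceOutput also inherits from OutputMetadata, which is not shown in this diff.
from dataclasses import FrozenInstanceError, dataclass

import numpy as np
from numpy.typing import NDArray


@dataclass(frozen=True)
class BalanceOutputSketch:          # illustrative stand-in for BalanceOutput
    mutual_information: NDArray     # the NMI matrix in the real BalanceOutput


out = BalanceOutputSketch(mutual_information=np.eye(4))
print(out.mutual_information[0])    # attribute access, as used in the new doctests

try:
    out.mutual_information = np.zeros((4, 4))   # frozen dataclasses reject assignment
except FrozenInstanceError:
    print("outputs are immutable")
```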
dataeval/_internal/metrics/ber.py

@@ -7,7 +7,10 @@ Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4)
  https://arxiv.org/abs/1811.06419
  """

- from typing import Literal, NamedTuple, Tuple
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Literal

  import numpy as np
  from numpy.typing import ArrayLike, NDArray
@@ -16,9 +19,11 @@ from scipy.stats import mode

  from dataeval._internal.interop import to_numpy
  from dataeval._internal.metrics.utils import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
+ from dataeval._internal.output import OutputMetadata, set_metadata


- class BEROutput(NamedTuple):
+ @dataclass(frozen=True)
+ class BEROutput(OutputMetadata):
  """
  Attributes
  ----------
@@ -32,7 +37,7 @@ class BEROutput(NamedTuple):
  ber_lower: float


- def ber_mst(X: NDArray, y: NDArray) -> Tuple[float, float]:
+ def ber_mst(X: NDArray, y: NDArray) -> tuple[float, float]:
  """Calculates the Bayes Error Rate using a minimum spanning tree

  Parameters
@@ -57,7 +62,7 @@ def ber_mst(X: NDArray, y: NDArray) -> Tuple[float, float]:
  return upper, lower


- def ber_knn(X: NDArray, y: NDArray, k: int) -> Tuple[float, float]:
+ def ber_knn(X: NDArray, y: NDArray, k: int) -> tuple[float, float]:
  """Calculates the Bayes Error Rate using K-nearest neighbors

  Parameters
@@ -73,9 +78,6 @@ def ber_knn(X: NDArray, y: NDArray, k: int) -> Tuple[float, float]:
  The upper and lower bounds of the bayes error rate
  """
  M, N = get_classes_counts(y)
-
- # All features belong on second dimension
- X = X.reshape((X.shape[0], -1))
  nn_indices = compute_neighbors(X, X, k=k)
  nn_indices = np.expand_dims(nn_indices, axis=1) if nn_indices.ndim == 1 else nn_indices
  modal_class = mode(y[nn_indices], axis=1, keepdims=True).mode.squeeze()
@@ -107,6 +109,7 @@ def knn_lowerbound(value: float, classes: int, k: int) -> float:
  BER_FN_MAP = {"KNN": ber_knn, "MST": ber_mst}


+ @set_metadata("dataeval.metrics")
  def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
  """
  An estimator for Multi-class Bayes Error Rate using FR or KNN test statistic basis
@@ -134,7 +137,7 @@ def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN",
  Examples
  --------
  >>> import sklearn.datasets as dsets
- >>> from dataeval.metrics import ber
+ >>> from dataeval.metrics.estimators import ber

  >>> images, labels = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)
dataeval/_internal/metrics/coverage.py

@@ -1,20 +1,24 @@
  import math
- from typing import Literal, NamedTuple
+ from dataclasses import dataclass
+ from typing import Literal

  import numpy as np
  from numpy.typing import ArrayLike, NDArray
  from scipy.spatial.distance import pdist, squareform

  from dataeval._internal.interop import to_numpy
+ from dataeval._internal.metrics.utils import flatten
+ from dataeval._internal.output import OutputMetadata, set_metadata


- class CoverageOutput(NamedTuple):
+ @dataclass(frozen=True)
+ class CoverageOutput(OutputMetadata):
  """
  Attributes
  ----------
- indices : np.ndarray
+ indices : NDArray
  Array of uncovered indices
- radii : np.ndarray
+ radii : NDArray
  Array of critical value radii
  critical_value : float
  Radius for coverage
@@ -25,6 +29,7 @@ class CoverageOutput(NamedTuple):
  critical_value: float


+ @set_metadata("dataeval.metrics")
  def coverage(
  embeddings: ArrayLike,
  radius_type: Literal["adaptive", "naive"] = "adaptive",
@@ -87,12 +92,14 @@ def coverage(
  embeddings = to_numpy(embeddings)
  n = len(embeddings)
  if n <= k:
- raise ValueError("Number of observations less than or equal to the specified number of neighbors.")
- mat = squareform(pdist(embeddings)).astype(np.float64)
+ raise ValueError(
+ f"Number of observations n={n} is less than or equal to the specified number of neighbors k={k}."
+ )
+ mat = squareform(pdist(flatten(embeddings))).astype(np.float64)
  sorted_dists = np.sort(mat, axis=1)
  crit = sorted_dists[:, k + 1]

- d = np.shape(embeddings)[1]
+ d = embeddings.shape[1]
  if radius_type == "naive":
  rho = (1 / math.sqrt(math.pi)) * ((2 * k * math.gamma(d / 2 + 1)) / (n)) ** (1 / d)
  pvals = np.where(crit > rho)[0]
@@ -101,5 +108,5 @@ def coverage(
  rho = int(n * percent)
  pvals = np.argsort(crit)[::-1][:rho]
  else:
- raise ValueError("Invalid radius type.")
+ raise ValueError(f"{radius_type} is an invalid radius type. Expected 'adaptive' or 'naive'")
  return CoverageOutput(pvals, crit, rho)
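coverage() now flattens the embeddings before the pairwise-distance computation (flatten is imported from the shared utils) and reports the offending n and k values when too few observations are supplied. The "naive" branch keeps the same closed-form critical radius; it is restated below as a standalone helper so the formula is easier to read (the helper name and example arguments are illustrative, not part of the package):

```python
# Standalone restatement of the "naive" critical radius from the hunk above.
# A d-dimensional ball of this radius has volume 2k/n, i.e. on a unit-volume
# support it is the scale at which roughly 2k of n uniform points fall inside.
import math


def naive_radius(n: int, d: int, k: int) -> float:
    return (1 / math.sqrt(math.pi)) * ((2 * k * math.gamma(d / 2 + 1)) / n) ** (1 / d)


print(naive_radius(n=1000, d=16, k=20))
```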
dataeval/_internal/metrics/divergence.py

@@ -3,16 +3,19 @@ This module contains the implementation of HP Divergence
  using the Fast Nearest Neighbor and Minimum Spanning Tree algorithms
  """

- from typing import Literal, NamedTuple
+ from dataclasses import dataclass
+ from typing import Literal

  import numpy as np
- from numpy.typing import ArrayLike
+ from numpy.typing import ArrayLike, NDArray

  from dataeval._internal.interop import to_numpy
  from dataeval._internal.metrics.utils import compute_neighbors, get_method, minimum_spanning_tree
+ from dataeval._internal.output import OutputMetadata, set_metadata


- class DivergenceOutput(NamedTuple):
+ @dataclass(frozen=True)
+ class DivergenceOutput(OutputMetadata):
  """
  Attributes
  ----------
@@ -26,14 +29,44 @@ class DivergenceOutput(NamedTuple):
  errors: int


- def divergence_mst(data: np.ndarray, labels: np.ndarray) -> int:
+ def divergence_mst(data: NDArray, labels: NDArray) -> int:
+ """
+ Calculates the estimated label errors based on the minimum spanning tree
+
+ Parameters
+ ----------
+ data : NDArray, shape - (N, ... )
+ Input images to be grouped
+ labels : NDArray
+ Corresponding labels for each data point
+
+ Returns
+ -------
+ int
+ Number of label errors when creating the minimum spanning tree
+ """
  mst = minimum_spanning_tree(data).toarray()
  edgelist = np.transpose(np.nonzero(mst))
  errors = np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
  return errors


- def divergence_fnn(data: np.ndarray, labels: np.ndarray) -> int:
+ def divergence_fnn(data: NDArray, labels: NDArray) -> int:
+ """
+ Calculates the estimated label errors based on their nearest neighbors
+
+ Parameters
+ ----------
+ data : NDArray, shape - (N, ... )
+ Input images to be grouped
+ labels : NDArray
+ Corresponding labels for each data point
+
+ Returns
+ -------
+ int
+ Number of label errors when finding nearest neighbors
+ """
  nn_indices = compute_neighbors(data, data)
  errors = np.sum(np.abs(labels[nn_indices] - labels))
  return errors
@@ -42,6 +75,7 @@ def divergence_fnn(data: np.ndarray, labels: np.ndarray) -> int:
  DIVERGENCE_FN_MAP = {"FNN": divergence_fnn, "MST": divergence_mst}


+ @set_metadata("dataeval.metrics")
  def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST"] = "FNN") -> DivergenceOutput:
  """
  Calculates the divergence and any errors between the datasets
@@ -50,10 +84,10 @@ def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST
  ----------
  data_a : ArrayLike, shape - (N, P)
  A dataset in an ArrayLike format to compare.
- Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
+ Function expects the data to have 2 dimensions, N number of observations in a P-dimensionial space.
  data_b : ArrayLike, shape - (N, P)
  A dataset in an ArrayLike format to compare.
- Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
+ Function expects the data to have 2 dimensions, N number of observations in a P-dimensionial space.
  method : Literal["MST, "FNN"], default "FNN"
  Method used to estimate dataset divergence

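divergence_mst and divergence_fnn pick up docstrings and NDArray hints, and divergence() gets the same @set_metadata decorator. A usage sketch under stated assumptions: the internal import path and the toy arrays are illustrative, and only the errors field of DivergenceOutput is visible in these hunks.

```python
# Usage sketch for the updated divergence(); the signature comes from the diff,
# while the import path, toy data, and printed field are illustrative (only the
# `errors` attribute of DivergenceOutput appears in the hunks above).
import numpy as np
from dataeval._internal.metrics.divergence import divergence

rng = np.random.default_rng(0)
data_a = rng.normal(0.0, 1.0, size=(100, 8))   # N observations in a P-dimensional space
data_b = rng.normal(0.5, 1.0, size=(100, 8))

result = divergence(data_a, data_b, method="FNN")   # DivergenceOutput, a frozen dataclass
print(result.errors)   # nearest-neighbor label disagreements between the two samples
```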
dataeval/_internal/metrics/diversity.py

@@ -1,12 +1,17 @@
- from typing import Dict, List, Literal, NamedTuple, Optional, Sequence
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Literal, Sequence

  import numpy as np
  from numpy.typing import NDArray

  from dataeval._internal.metrics.utils import entropy, get_counts, get_method, get_num_bins, preprocess_metadata
+ from dataeval._internal.output import OutputMetadata, set_metadata


- class DiversityOutput(NamedTuple):
+ @dataclass(frozen=True)
+ class DiversityOutput(OutputMetadata):
  """
  Attributes
  ----------
@@ -18,11 +23,11 @@ class DiversityOutput(NamedTuple):


  def diversity_shannon(
- data: np.ndarray,
- names: List[str],
- is_categorical: List[bool],
- subset_mask: Optional[np.ndarray] = None,
- ) -> np.ndarray:
+ data: NDArray,
+ names: list[str],
+ is_categorical: list[bool],
+ subset_mask: NDArray[np.bool_] | None = None,
+ ) -> NDArray:
  """
  Compute diversity for discrete/categorical variables and, through standard
  histogram binning, for continuous variables.
@@ -34,7 +39,7 @@ def diversity_shannon(

  Parameters
  ----------
- subset_mask: Optional[np.ndarray[bool]]
+ subset_mask: NDArray[np.bool_] | None
  Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts

  Notes
@@ -43,7 +48,7 @@ def diversity_shannon(


  Returns
- diversity_index: np.ndarray
+ diversity_index: NDArray
  Diversity index per column of X

  See Also
@@ -55,15 +60,18 @@ def diversity_shannon(
  ent_unnormalized = entropy(data, names, is_categorical, normalized=False, subset_mask=subset_mask)
  # normalize by global counts rather than classwise counts
  num_bins = get_num_bins(data, names, is_categorical=is_categorical, subset_mask=subset_mask)
- return ent_unnormalized / np.log(num_bins)
+ ent_norm = np.empty(ent_unnormalized.shape)
+ ent_norm[num_bins != 1] = ent_unnormalized[num_bins != 1] / np.log(num_bins[num_bins != 1])
+ ent_norm[num_bins == 1] = 0
+ return ent_norm


  def diversity_simpson(
- data: np.ndarray,
- names: List[str],
- is_categorical: List[bool],
- subset_mask: Optional[np.ndarray] = None,
- ) -> np.ndarray:
+ data: NDArray,
+ names: list[str],
+ is_categorical: list[bool],
+ subset_mask: NDArray[np.bool_] | None = None,
+ ) -> NDArray:
  """
  Compute diversity for discrete/categorical variables and, through standard
  histogram binning, for continuous variables.
@@ -76,7 +84,7 @@ def diversity_simpson(

  Parameters
  ----------
- subset_mask: Optional[np.ndarray[bool]]
+ subset_mask: NDArray[np.bool_] | None
  Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts

  Notes
@@ -90,7 +98,7 @@ def diversity_simpson(


  Returns
- np.ndarray
+ NDArray
  Diversity index per column of X

  See Also
@@ -116,8 +124,9 @@ def diversity_simpson(
  DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}


+ @set_metadata("dataeval.metrics")
  def diversity(
- class_labels: Sequence[int], metadata: List[Dict], method: Literal["shannon", "simpson"] = "simpson"
+ class_labels: Sequence[int], metadata: list[dict], method: Literal["shannon", "simpson"] = "simpson"
  ) -> DiversityOutput:
  """
  Compute diversity for discrete/categorical variables and, through standard
@@ -145,6 +154,19 @@ def diversity(
  DiversityOutput
  Diversity index per column of self.data or each factor in self.names

+ Example
+ -------
+ Compute Simpson diversity index of metadata and class labels
+
+ >>> diversity(class_labels, metadata, method="simpson").diversity_index
+ array([0.34482759, 0.34482759, 0.90909091])
+
+ Compute Shannon diversity index of metadata and class labels
+
+ >>> diversity(class_labels, metadata, method="shannon").diversity_index
+ array([0.37955133, 0.37955133, 0.96748876])
+
+
  See Also
  --------
  numpy.histogram
@@ -155,8 +177,9 @@ def diversity(
  return DiversityOutput(diversity_index)


+ @set_metadata("dataeval.metrics")
  def diversity_classwise(
- class_labels: Sequence[int], metadata: List[Dict], method: Literal["shannon", "simpson"] = "simpson"
+ class_labels: Sequence[int], metadata: list[dict], method: Literal["shannon", "simpson"] = "simpson"
  ) -> DiversityOutput:
  """
  Compute diversity for discrete/categorical variables and, through standard
@@ -186,6 +209,21 @@ def diversity_classwise(
  DiversityOutput
  Diversity index [n_class x n_factor]

+ Example
+ -------
+ Compute classwise Simpson diversity index of metadata and class labels
+
+ >>> diversity_classwise(class_labels, metadata, method="simpson").diversity_index
+ array([[0.33793103, 0.51578947],
+ [0.36 , 0.36 ]])
+
+ Compute classwise Shannon diversity index of metadata and class labels
+
+ >>> diversity_classwise(class_labels, metadata, method="shannon").diversity_index
+ array([[0.43156028, 0.83224889],
+ [0.57938016, 0.57938016]])
+
+
  See Also
  --------
  numpy.histogram
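Beyond the doctest additions, the substantive fix in diversity.py is the normalization guard in diversity_shannon: a factor that bins to a single value used to divide by log(1) = 0; such factors are now assigned a diversity of 0. A standalone illustration with toy values:

```python
# Toy illustration of the guard added to diversity_shannon: factors with a single
# bin no longer divide by log(1) == 0 and are assigned a diversity of 0 instead.
import numpy as np

ent_unnormalized = np.array([0.0, 0.9, 1.2])   # toy unnormalized entropies
num_bins = np.array([1, 3, 5])                 # the first factor bins to a single value

ent_norm = np.empty(ent_unnormalized.shape)
mask = num_bins != 1
ent_norm[mask] = ent_unnormalized[mask] / np.log(num_bins[mask])
ent_norm[~mask] = 0                            # previously: 0.0 / np.log(1) -> invalid division
print(ent_norm)
```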