dataeval 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +9 -10
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
  16. dataeval/detectors/ood/__init__.py +6 -6
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +20 -12
  18. dataeval/detectors/ood/aegmm.py +66 -0
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +33 -21
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +43 -33
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +23 -17
  25. dataeval/detectors/ood/vaegmm.py +75 -0
  26. dataeval/interop.py +56 -0
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -13
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +41 -7
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +75 -18
  32. dataeval/metrics/bias/metadata.py +358 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +54 -44
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +8 -3
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/lazy.py +26 -0
  51. dataeval/utils/metadata.py +258 -0
  52. dataeval/utils/shared.py +151 -0
  53. dataeval/{_internal → utils}/split_dataset.py +98 -33
  54. dataeval/utils/tensorflow/__init__.py +7 -6
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +8 -2
  56. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +28 -18
  57. dataeval/{_internal/models/tensorflow/pixelcnn.py → utils/tensorflow/_internal/models.py} +387 -97
  58. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +15 -6
  59. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +84 -85
  60. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  61. dataeval/utils/torch/__init__.py +7 -3
  62. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  63. dataeval/{_internal → utils/torch}/datasets.py +48 -42
  64. dataeval/utils/torch/models.py +138 -0
  65. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
  66. dataeval/{_internal → utils/torch}/utils.py +3 -1
  67. dataeval/workflows/__init__.py +1 -1
  68. dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
  69. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/METADATA +4 -3
  70. dataeval-0.73.0.dist-info/RECORD +73 -0
  71. dataeval/_internal/detectors/__init__.py +0 -0
  72. dataeval/_internal/detectors/drift/__init__.py +0 -0
  73. dataeval/_internal/detectors/ood/__init__.py +0 -0
  74. dataeval/_internal/detectors/ood/aegmm.py +0 -78
  75. dataeval/_internal/detectors/ood/vaegmm.py +0 -89
  76. dataeval/_internal/interop.py +0 -49
  77. dataeval/_internal/metrics/__init__.py +0 -0
  78. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  79. dataeval/_internal/metrics/utils.py +0 -447
  80. dataeval/_internal/models/__init__.py +0 -0
  81. dataeval/_internal/models/pytorch/__init__.py +0 -0
  82. dataeval/_internal/models/pytorch/utils.py +0 -67
  83. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  84. dataeval/_internal/models/tensorflow/autoencoder.py +0 -320
  85. dataeval/_internal/workflows/__init__.py +0 -0
  86. dataeval/detectors/drift/kernels/__init__.py +0 -10
  87. dataeval/detectors/drift/updates/__init__.py +0 -8
  88. dataeval/utils/tensorflow/models/__init__.py +0 -9
  89. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  90. dataeval/utils/torch/datasets/__init__.py +0 -12
  91. dataeval/utils/torch/models/__init__.py +0 -11
  92. dataeval/utils/torch/trainer/__init__.py +0 -7
  93. dataeval-0.72.1.dist-info/RECORD +0 -81
  94. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
  95. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
@@ -3,6 +3,6 @@ Metrics are a way to measure the performance of your models or datasets that
3
3
  can then be analyzed in the context of a given problem.
4
4
  """
5
5
 
6
- from . import bias, estimators, stats
6
+ from dataeval.metrics import bias, estimators, stats
7
7
 
8
8
  __all__ = ["bias", "estimators", "stats"]
@@ -3,10 +3,10 @@ Bias metrics check for skewed or imbalanced datasets and incomplete feature
3
3
  representation which may impact model performance.
4
4
  """
5
5
 
6
- from dataeval._internal.metrics.balance import BalanceOutput, balance
7
- from dataeval._internal.metrics.coverage import CoverageOutput, coverage
8
- from dataeval._internal.metrics.diversity import DiversityOutput, diversity
9
- from dataeval._internal.metrics.parity import ParityOutput, label_parity, parity
6
+ from dataeval.metrics.bias.balance import BalanceOutput, balance
7
+ from dataeval.metrics.bias.coverage import CoverageOutput, coverage
8
+ from dataeval.metrics.bias.diversity import DiversityOutput, diversity
9
+ from dataeval.metrics.bias.parity import ParityOutput, label_parity, parity
10
10
 
11
11
  __all__ = [
12
12
  "balance",
@@ -1,35 +1,98 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = ["BalanceOutput", "balance"]
4
+
5
+ import contextlib
3
6
  import warnings
4
7
  from dataclasses import dataclass
5
- from typing import Mapping
8
+ from typing import Any, Mapping
6
9
 
7
10
  import numpy as np
8
11
  from numpy.typing import ArrayLike, NDArray
9
12
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
10
13
 
11
- from dataeval._internal.metrics.utils import entropy, preprocess_metadata
12
- from dataeval._internal.output import OutputMetadata, set_metadata
14
+ from dataeval.metrics.bias.metadata import entropy, heatmap, preprocess_metadata
15
+ from dataeval.output import OutputMetadata, set_metadata
16
+
17
+ with contextlib.suppress(ImportError):
18
+ from matplotlib.figure import Figure
13
19
 
14
20
 
15
21
  @dataclass(frozen=True)
16
22
  class BalanceOutput(OutputMetadata):
17
23
  """
18
- Output class for :func:`balance` :term:`Bias` metric
24
+ Output class for :func:`balance` bias metric
19
25
 
20
26
  Attributes
21
27
  ----------
22
28
  balance : NDArray[np.float64]
23
- Estimate of :term:`mutual information<Mutual Information (MI)>` between metadata factors and class label
29
+ Estimate of mutual information between metadata factors and class label
24
30
  factors : NDArray[np.float64]
25
31
  Estimate of inter/intra-factor mutual information
26
32
  classwise : NDArray[np.float64]
27
33
  Estimate of mutual information between metadata factors and individual class labels
34
+ class_list: NDArray
35
+ Array of the class labels present in the dataset
36
+ metadata_names: list[str]
37
+ Names of each metadata factor
28
38
  """
29
39
 
30
40
  balance: NDArray[np.float64]
31
41
  factors: NDArray[np.float64]
32
42
  classwise: NDArray[np.float64]
43
+ class_list: NDArray[Any]
44
+ metadata_names: list[str]
45
+
46
+ def plot(
47
+ self,
48
+ row_labels: list[Any] | NDArray[Any] | None = None,
49
+ col_labels: list[Any] | NDArray[Any] | None = None,
50
+ plot_classwise: bool = False,
51
+ ) -> Figure:
52
+ """
53
+ Plot a heatmap of balance information
54
+
55
+ Parameters
56
+ ----------
57
+ row_labels : ArrayLike | None, default None
58
+ List/Array containing the labels for rows in the histogram
59
+ col_labels : ArrayLike | None, default None
60
+ List/Array containing the labels for columns in the histogram
61
+ plot_classwise : bool, default False
62
+ Whether to plot per-class balance instead of global balance
63
+ """
64
+ if plot_classwise:
65
+ if row_labels is None:
66
+ row_labels = self.class_list
67
+ if col_labels is None:
68
+ col_labels = np.concatenate((["class"], self.metadata_names))
69
+
70
+ fig = heatmap(
71
+ self.classwise,
72
+ row_labels,
73
+ col_labels,
74
+ xlabel="Factors",
75
+ ylabel="Class",
76
+ cbarlabel="Normalized Mutual Information",
77
+ )
78
+ else:
79
+ # Combine balance and factors results
80
+ data = np.concatenate([self.balance[np.newaxis, 1:], self.factors], axis=0)
81
+ # Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
82
+ mask = np.triu(data + 1, k=0) < 1
83
+ # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
84
+ heat_data = np.where(mask, np.nan, data)[:-1]
85
+ # Creating label array for heat map axes
86
+ heat_labels = np.concatenate((["class"], self.metadata_names))
87
+
88
+ if row_labels is None:
89
+ row_labels = heat_labels[:-1]
90
+ if col_labels is None:
91
+ col_labels = heat_labels[1:]
92
+
93
+ fig = heatmap(heat_data, row_labels, col_labels, cbarlabel="Normalized Mutual Information")
94
+
95
+ return fig
33
96
 
34
97
 
35
98
  def validate_num_neighbors(num_neighbors: int) -> int:
@@ -55,7 +118,7 @@ def validate_num_neighbors(num_neighbors: int) -> int:
55
118
  @set_metadata("dataeval.metrics")
56
119
  def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neighbors: int = 5) -> BalanceOutput:
57
120
  """
58
- :term:`Mutual information (MI)` between factors (class label, metadata, label/image properties)
121
+ Mutual information (MI) between factors (class label, metadata, label/image properties)
59
122
 
60
123
  Parameters
61
124
  ----------
@@ -70,7 +133,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
70
133
  Returns
71
134
  -------
72
135
  BalanceOutput
73
- (num_factors+1) x (num_factors+1) estimate of :term:`mutual information<Mutual Information (MI)>`
136
+ (num_factors+1) x (num_factors+1) estimate of mutual information
74
137
  between num_factors metadata factors and class label. Symmetry is enforced.
75
138
 
76
139
  Note
@@ -83,7 +146,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
83
146
 
84
147
  Example
85
148
  -------
86
- Return :term:`balance<Balance>` (:term:`mutual information<Mutual Information (MI)>`) of factors with class_labels
149
+ Return balance (mutual information) of factors with class_labels
87
150
 
88
151
  >>> bal = balance(class_labels, metadata)
89
152
  >>> bal.balance
@@ -109,7 +172,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
109
172
  sklearn.metrics.mutual_info_score
110
173
  """
111
174
  num_neighbors = validate_num_neighbors(num_neighbors)
112
- data, names, is_categorical = preprocess_metadata(class_labels, metadata)
175
+ data, names, is_categorical, unique_labels = preprocess_metadata(class_labels, metadata)
113
176
  num_factors = len(names)
114
177
  mi = np.empty((num_factors, num_factors))
115
178
  mi[:] = np.nan
@@ -143,8 +206,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
143
206
 
144
207
  # unique class labels
145
208
  class_idx = names.index("class_label")
146
- class_data = data[:, class_idx].astype(int)
147
- u_cls = np.unique(class_data)
209
+ u_cls = np.unique(data[:, class_idx])
148
210
  num_classes = len(u_cls)
149
211
 
150
212
  # assume class is a factor
@@ -154,7 +216,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
154
216
  # categorical variables, excluding class label
155
217
  cat_mask = np.concatenate((is_categorical[:class_idx], is_categorical[(class_idx + 1) :]), axis=0).astype(int)
156
218
 
157
- tgt_bin = np.stack([class_data == cls for cls in u_cls]).T.astype(int)
219
+ tgt_bin = np.stack([data[:, class_idx] == cls for cls in u_cls]).T.astype(int)
158
220
  ent_tgt_bin = entropy(
159
221
  tgt_bin, names=[str(idx) for idx in range(num_classes)], is_categorical=[True for idx in range(num_classes)]
160
222
  )
@@ -174,4 +236,4 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
174
236
  norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_all) + 1e-6
175
237
  classwise = classwise_mi / norm_factor
176
238
 
177
- return BalanceOutput(balance, factors, classwise)
239
+ return BalanceOutput(balance, factors, classwise, unique_labels, list(metadata.keys()))
@@ -1,16 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = ["CoverageOutput", "coverage"]
4
+
5
+ import contextlib
3
6
  import math
4
7
  from dataclasses import dataclass
5
- from typing import Literal
8
+ from typing import Any, Literal
6
9
 
7
10
  import numpy as np
8
11
  from numpy.typing import ArrayLike, NDArray
9
12
  from scipy.spatial.distance import pdist, squareform
10
13
 
11
- from dataeval._internal.interop import to_numpy
12
- from dataeval._internal.metrics.utils import flatten
13
- from dataeval._internal.output import OutputMetadata, set_metadata
14
+ from dataeval.interop import to_numpy
15
+ from dataeval.metrics.bias.metadata import coverage_plot
16
+ from dataeval.output import OutputMetadata, set_metadata
17
+ from dataeval.utils.shared import flatten
18
+
19
+ with contextlib.suppress(ImportError):
20
+ from matplotlib.figure import Figure
14
21
 
15
22
 
16
23
  @dataclass(frozen=True)
@@ -32,13 +39,40 @@ class CoverageOutput(OutputMetadata):
32
39
  radii: NDArray[np.float64]
33
40
  critical_value: float
34
41
 
42
+ def plot(
43
+ self,
44
+ images: NDArray[Any],
45
+ top_k: int = 6,
46
+ ) -> Figure:
47
+ """
48
+ Plot the top k images together for visualization
49
+
50
+ Parameters
51
+ ----------
52
+ images : ArrayLike
53
+ Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
54
+ top_k : int, default 6
55
+ Number of images to plot (plotting assumes groups of 3)
56
+ """
57
+ # Determine which images to plot
58
+ highest_uncovered_indices = self.indices[:top_k]
59
+
60
+ # Grab the images
61
+ images = to_numpy(images)
62
+ selected_images = images[highest_uncovered_indices]
63
+
64
+ # Plot the images
65
+ fig = coverage_plot(selected_images, top_k)
66
+
67
+ return fig
68
+
35
69
 
36
- @set_metadata("dataeval.metrics")
70
+ @set_metadata()
37
71
  def coverage(
38
72
  embeddings: ArrayLike,
39
73
  radius_type: Literal["adaptive", "naive"] = "adaptive",
40
74
  k: int = 20,
41
- percent: np.float64 = np.float64(0.01),
75
+ percent: float = 0.01,
42
76
  ) -> CoverageOutput:
43
77
  """
44
78
  Class for evaluating :term:`coverage<Coverage>` and identifying images/samples that are in undercovered regions.
@@ -53,7 +87,7 @@ def coverage(
53
87
  k: int, default 20
54
88
  Number of observations required in order to be covered.
55
89
  [1] suggests that a minimum of 20-50 samples is necessary.
56
- percent: np.float64, default np.float(0.01)
90
+ percent: float, default 0.01
57
91
  Percent of observations to be considered uncovered. Only applies to adaptive radius.
58
92
 
59
93
  Returns
@@ -1,13 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = ["DiversityOutput", "diversity"]
4
+
5
+ import contextlib
3
6
  from dataclasses import dataclass
4
- from typing import Literal, Mapping
7
+ from typing import Any, Literal, Mapping
5
8
 
6
9
  import numpy as np
7
10
  from numpy.typing import ArrayLike, NDArray
8
11
 
9
- from dataeval._internal.metrics.utils import entropy, get_counts, get_method, get_num_bins, preprocess_metadata
10
- from dataeval._internal.output import OutputMetadata, set_metadata
12
+ from dataeval.metrics.bias.metadata import (
13
+ diversity_bar_plot,
14
+ entropy,
15
+ get_counts,
16
+ get_num_bins,
17
+ heatmap,
18
+ preprocess_metadata,
19
+ )
20
+ from dataeval.output import OutputMetadata, set_metadata
21
+ from dataeval.utils.shared import get_method
22
+
23
+ with contextlib.suppress(ImportError):
24
+ from matplotlib.figure import Figure
11
25
 
12
26
 
13
27
  @dataclass(frozen=True)
@@ -21,18 +35,66 @@ class DiversityOutput(OutputMetadata):
21
35
  :term:`Diversity` index for classes and factors
22
36
  classwise : NDArray[np.float64]
23
37
  Classwise diversity index [n_class x n_factor]
38
+ class_list: NDArray[np.int64]
39
+ Class labels for each value in the dataset
40
+ metadata_names: list[str]
41
+ Names of each metadata factor
24
42
  """
25
43
 
26
44
  diversity_index: NDArray[np.float64]
27
45
  classwise: NDArray[np.float64]
46
+ class_list: NDArray[Any]
47
+ metadata_names: list[str]
48
+ method: Literal["shannon", "simpson"]
49
+
50
+ def plot(
51
+ self,
52
+ row_labels: list[Any] | NDArray[Any] | None = None,
53
+ col_labels: list[Any] | NDArray[Any] | None = None,
54
+ plot_classwise: bool = False,
55
+ ) -> Figure:
56
+ """
57
+ Plot a heatmap of diversity information
58
+
59
+ Parameters
60
+ ----------
61
+ row_labels : ArrayLike | None, default None
62
+ List/Array containing the labels for rows in the histogram
63
+ col_labels : ArrayLike | None, default None
64
+ List/Array containing the labels for columns in the histogram
65
+ plot_classwise : bool, default False
66
+ Whether to plot per-class balance instead of global balance
67
+ """
68
+ if plot_classwise:
69
+ if row_labels is None:
70
+ row_labels = self.class_list
71
+ if col_labels is None:
72
+ col_labels = self.metadata_names
73
+
74
+ fig = heatmap(
75
+ self.classwise,
76
+ row_labels,
77
+ col_labels,
78
+ xlabel="Factors",
79
+ ylabel="Class",
80
+ cbarlabel=f"Normalized {self.method.title()} Index",
81
+ )
82
+
83
+ else:
84
+ # Creating label array for heat map axes
85
+ heat_labels = np.concatenate((["class"], self.metadata_names))
86
+
87
+ fig = diversity_bar_plot(heat_labels, self.diversity_index)
88
+
89
+ return fig
28
90
 
29
91
 
30
92
  def diversity_shannon(
31
- data: NDArray,
93
+ data: NDArray[Any],
32
94
  names: list[str],
33
95
  is_categorical: list[bool],
34
96
  subset_mask: NDArray[np.bool_] | None = None,
35
- ) -> NDArray:
97
+ ) -> NDArray[np.float64]:
36
98
  """
37
99
  Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
38
100
  histogram binning, for continuous variables.
@@ -79,11 +141,11 @@ def diversity_shannon(
79
141
 
80
142
 
81
143
  def diversity_simpson(
82
- data: NDArray,
144
+ data: NDArray[Any],
83
145
  names: list[str],
84
146
  is_categorical: list[bool],
85
147
  subset_mask: NDArray[np.bool_] | None = None,
86
- ) -> NDArray:
148
+ ) -> NDArray[np.float64]:
87
149
  """
88
150
  Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
89
151
  histogram binning, for continuous variables.
@@ -139,10 +201,7 @@ def diversity_simpson(
139
201
  return ev_index
140
202
 
141
203
 
142
- DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
143
-
144
-
145
- @set_metadata("dataeval.metrics")
204
+ @set_metadata()
146
205
  def diversity(
147
206
  class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], method: Literal["shannon", "simpson"] = "simpson"
148
207
  ) -> DiversityOutput:
@@ -202,20 +261,18 @@ def diversity(
202
261
  --------
203
262
  numpy.histogram
204
263
  """
205
- diversity_fn = get_method(DIVERSITY_FN_MAP, method)
206
- data, names, is_categorical = preprocess_metadata(class_labels, metadata)
264
+ diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
265
+ data, names, is_categorical, unique_labels = preprocess_metadata(class_labels, metadata)
207
266
  diversity_index = diversity_fn(data, names, is_categorical, None).astype(np.float64)
208
267
 
209
268
  class_idx = names.index("class_label")
210
- class_lbl = data[:, class_idx]
211
-
212
- u_classes = np.unique(class_lbl)
269
+ u_classes = np.unique(data[:, class_idx])
213
270
  num_factors = len(names)
214
271
  diversity = np.empty((len(u_classes), num_factors))
215
272
  diversity[:] = np.nan
216
273
  for idx, cls in enumerate(u_classes):
217
- subset_mask = class_lbl == cls
274
+ subset_mask = data[:, class_idx] == cls
218
275
  diversity[idx, :] = diversity_fn(data, names, is_categorical, subset_mask)
219
276
  div_no_class = np.concatenate((diversity[:, :class_idx], diversity[:, (class_idx + 1) :]), axis=1)
220
277
 
221
- return DiversityOutput(diversity_index, div_no_class)
278
+ return DiversityOutput(diversity_index, div_no_class, unique_labels, list(metadata.keys()), method)