dataeval 0.76.0__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +52 -43
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -63
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +25 -25
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +198 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/METADATA +44 -15
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.0.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
@@ -5,4 +5,4 @@ can then be analyzed in the context of a given problem.
5
5
 
6
6
  __all__ = ["bias", "estimators", "stats"]
7
7
 
8
- from dataeval.metrics import bias, estimators, stats
8
+ from . import bias, estimators, stats
@@ -7,6 +7,7 @@ __all__ = [
7
7
  "BalanceOutput",
8
8
  "CoverageOutput",
9
9
  "DiversityOutput",
10
+ "LabelParityOutput",
10
11
  "ParityOutput",
11
12
  "balance",
12
13
  "coverage",
@@ -15,7 +16,7 @@ __all__ = [
15
16
  "parity",
16
17
  ]
17
18
 
18
- from dataeval.metrics.bias.balance import BalanceOutput, balance
19
- from dataeval.metrics.bias.coverage import CoverageOutput, coverage
20
- from dataeval.metrics.bias.diversity import DiversityOutput, diversity
21
- from dataeval.metrics.bias.parity import ParityOutput, label_parity, parity
19
+ from dataeval.metrics.bias._balance import BalanceOutput, balance
20
+ from dataeval.metrics.bias._coverage import CoverageOutput, coverage
21
+ from dataeval.metrics.bias._diversity import DiversityOutput, diversity
22
+ from dataeval.metrics.bias._parity import LabelParityOutput, ParityOutput, label_parity, parity
@@ -5,16 +5,17 @@ __all__ = []
5
5
  import contextlib
6
6
  import warnings
7
7
  from dataclasses import dataclass
8
- from typing import Any
8
+ from typing import Any, Literal, overload
9
9
 
10
10
  import numpy as np
11
11
  import scipy as sp
12
12
  from numpy.typing import NDArray
13
13
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
14
14
 
15
- from dataeval.output import Output, set_metadata
16
- from dataeval.utils.metadata import Metadata, get_counts
17
- from dataeval.utils.plot import heatmap
15
+ from dataeval._output import Output, set_metadata
16
+ from dataeval.utils._bin import get_counts
17
+ from dataeval.utils._plot import heatmap
18
+ from dataeval.utils.data import Metadata
18
19
 
19
20
  with contextlib.suppress(ImportError):
20
21
  from matplotlib.figure import Figure
@@ -23,8 +24,8 @@ with contextlib.suppress(ImportError):
23
24
  @dataclass(frozen=True)
24
25
  class BalanceOutput(Output):
25
26
  """
26
- Output class for :func:`balance` :term:`bias<Bias>` metric.
27
-
27
+ Output class for :func:`.balance` :term:`bias<Bias>` metric.
28
+
28
29
  Attributes
29
30
  ----------
30
31
  balance : NDArray[np.float64]
@@ -35,21 +36,62 @@ class BalanceOutput(Output):
35
36
  Estimate of mutual information between metadata factors and individual class labels
36
37
  factor_names : list[str]
37
38
  Names of each metadata factor
38
- class_list : NDArray
39
- Array of the class labels present in the dataset
39
+ class_names : list[str]
40
+ List of the class labels present in the dataset
40
41
  """
41
42
 
42
43
  balance: NDArray[np.float64]
43
44
  factors: NDArray[np.float64]
44
45
  classwise: NDArray[np.float64]
45
46
  factor_names: list[str]
46
- class_list: NDArray[Any]
47
+ class_names: list[str]
48
+
49
+ @overload
50
+ def _by_factor_type(
51
+ self,
52
+ attr: Literal["factor_names"],
53
+ factor_type: Literal["discrete", "continuous", "both"],
54
+ ) -> list[str]: ...
55
+
56
+ @overload
57
+ def _by_factor_type(
58
+ self,
59
+ attr: Literal["balance", "factors", "classwise"],
60
+ factor_type: Literal["discrete", "continuous", "both"],
61
+ ) -> NDArray[np.float64]: ...
62
+
63
+ def _by_factor_type(
64
+ self,
65
+ attr: Literal["balance", "factors", "classwise", "factor_names"],
66
+ factor_type: Literal["discrete", "continuous", "both"],
67
+ ) -> NDArray[np.float64] | list[str]:
68
+ # if not filtering by factor_type then just return the requested attribute without mask
69
+ if factor_type == "both":
70
+ return getattr(self, attr)
71
+
72
+ # create the mask for the selected factor_type
73
+ mask_lambda = (
74
+ (lambda x: "-continuous" not in x) if factor_type == "discrete" else (lambda x: "-discrete" not in x)
75
+ )
76
+
77
+ # return the masked attribute
78
+ if attr == "factor_names":
79
+ return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
80
+ else:
81
+ factor_type_mask = [mask_lambda(x) for x in self.factor_names]
82
+ if attr == "factors":
83
+ return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
84
+ elif attr == "balance":
85
+ return self.balance[factor_type_mask]
86
+ elif attr == "classwise":
87
+ return self.classwise[:, factor_type_mask]
47
88
 
48
89
  def plot(
49
90
  self,
50
91
  row_labels: list[Any] | NDArray[Any] | None = None,
51
92
  col_labels: list[Any] | NDArray[Any] | None = None,
52
93
  plot_classwise: bool = False,
94
+ factor_type: Literal["discrete", "continuous", "both"] = "discrete",
53
95
  ) -> Figure:
54
96
  """
55
97
  Plot a heatmap of balance information
@@ -62,15 +104,17 @@ class BalanceOutput(Output):
62
104
  List/Array containing the labels for columns in the histogram
63
105
  plot_classwise : bool, default False
64
106
  Whether to plot per-class balance instead of global balance
107
+ factor_type : "discrete", "continuous", or "both", default "discrete"
108
+ Whether to plot discretized values, continuous values, or to include both
65
109
  """
66
110
  if plot_classwise:
67
111
  if row_labels is None:
68
- row_labels = self.class_list
112
+ row_labels = self.class_names
69
113
  if col_labels is None:
70
- col_labels = self.factor_names
114
+ col_labels = self._by_factor_type("factor_names", factor_type)
71
115
 
72
116
  fig = heatmap(
73
- self.classwise,
117
+ self._by_factor_type("classwise", factor_type),
74
118
  row_labels,
75
119
  col_labels,
76
120
  xlabel="Factors",
@@ -79,13 +123,19 @@ class BalanceOutput(Output):
79
123
  )
80
124
  else:
81
125
  # Combine balance and factors results
82
- data = np.concatenate([self.balance[np.newaxis, 1:], self.factors], axis=0)
126
+ data = np.concatenate(
127
+ [
128
+ self._by_factor_type("balance", factor_type)[np.newaxis, 1:],
129
+ self._by_factor_type("factors", factor_type),
130
+ ],
131
+ axis=0,
132
+ )
83
133
  # Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
84
134
  mask = np.triu(data + 1, k=0) < 1
85
135
  # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
86
136
  heat_data = np.where(mask, np.nan, data)[:-1]
87
137
  # Creating label array for heat map axes
88
- heat_labels = self.factor_names
138
+ heat_labels = self._by_factor_type("factor_names", factor_type)
89
139
 
90
140
  if row_labels is None:
91
141
  row_labels = heat_labels[:-1]
@@ -128,7 +178,7 @@ def balance(
128
178
  Parameters
129
179
  ----------
130
180
  metadata : Metadata
131
- Preprocessed metadata from :func:`dataeval.utils.metadata.preprocess`
181
+ Preprocessed metadata
132
182
  num_neighbors : int, default 5
133
183
  Number of points to consider as neighbors
134
184
 
@@ -184,7 +234,7 @@ def balance(
184
234
  mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
185
235
  data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
186
236
  discretized_data = data
187
- if metadata.continuous_data is not None:
237
+ if len(metadata.continuous_data):
188
238
  data = np.hstack((data, metadata.continuous_data))
189
239
  discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
190
240
  discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
@@ -218,7 +268,7 @@ def balance(
218
268
  factors = nmi[1:, 1:]
219
269
 
220
270
  # assume class is a factor
221
- num_classes = metadata.class_names.size
271
+ num_classes = len(metadata.class_names)
222
272
  classwise_mi = np.full((num_classes, num_factors), np.nan, dtype=np.float32)
223
273
 
224
274
  # classwise targets
@@ -8,12 +8,12 @@ from dataclasses import dataclass
8
8
  from typing import Any, Literal
9
9
 
10
10
  import numpy as np
11
- from numpy.typing import ArrayLike, NDArray
11
+ from numpy.typing import NDArray
12
12
  from scipy.spatial.distance import pdist, squareform
13
13
 
14
- from dataeval.interop import to_numpy
15
- from dataeval.output import Output, set_metadata
16
- from dataeval.utils.shared import flatten
14
+ from dataeval._output import Output, set_metadata
15
+ from dataeval.typing import ArrayLike
16
+ from dataeval.utils._array import ensure_embeddings, flatten, to_numpy
17
17
 
18
18
  with contextlib.suppress(ImportError):
19
19
  from matplotlib.figure import Figure
@@ -71,21 +71,21 @@ def _plot(images: NDArray[Any], num_images: int) -> Figure:
71
71
  @dataclass(frozen=True)
72
72
  class CoverageOutput(Output):
73
73
  """
74
- Output class for :func:`coverage` :term:`bias<Bias>` metric.
74
+ Output class for :func:`.coverage` :term:`bias<Bias>` metric.
75
75
 
76
76
  Attributes
77
77
  ----------
78
- indices : NDArray[np.intp]
78
+ uncovered_indices : NDArray[np.intp]
79
79
  Array of uncovered indices
80
- radii : NDArray[np.float64]
80
+ critical_value_radii : NDArray[np.float64]
81
81
  Array of critical value radii
82
- critical_value : float
82
+ coverage_radius : float
83
83
  Radius for :term:`coverage<Coverage>`
84
84
  """
85
85
 
86
- indices: NDArray[np.intp]
87
- radii: NDArray[np.float64]
88
- critical_value: float
86
+ uncovered_indices: NDArray[np.intp]
87
+ critical_value_radii: NDArray[np.float64]
88
+ coverage_radius: float
89
89
 
90
90
  def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
91
91
  """
@@ -102,8 +102,9 @@ class CoverageOutput(Output):
102
102
  -------
103
103
  matplotlib.figure.Figure
104
104
  """
105
+
105
106
  # Determine which images to plot
106
- highest_uncovered_indices = self.indices[:top_k]
107
+ highest_uncovered_indices = self.uncovered_indices[:top_k]
107
108
 
108
109
  # Grab the images
109
110
  images = to_numpy(images)
@@ -119,7 +120,7 @@ class CoverageOutput(Output):
119
120
  def coverage(
120
121
  embeddings: ArrayLike,
121
122
  radius_type: Literal["adaptive", "naive"] = "adaptive",
122
- k: int = 20,
123
+ num_observations: int = 20,
123
124
  percent: float = 0.01,
124
125
  ) -> CoverageOutput:
125
126
  """
@@ -128,11 +129,11 @@ def coverage(
128
129
  Parameters
129
130
  ----------
130
131
  embeddings : ArrayLike, shape - (N, P)
131
- A dataset in an ArrayLike format.
132
- Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
132
+ Dataset embeddings as unit interval [0, 1].
133
+ Function expects the data to have 2 dimensions, N number of observations in a P-dimensional space.
133
134
  radius_type : {"adaptive", "naive"}, default "adaptive"
134
135
  The function used to determine radius.
135
- k : int, default 20
136
+ num_observations : int, default 20
136
137
  Number of observations required in order to be covered.
137
138
  [1] suggests that a minimum of 20-50 samples is necessary.
138
139
  percent : float, default 0.01
@@ -146,7 +147,9 @@ def coverage(
146
147
  Raises
147
148
  ------
148
149
  ValueError
149
- If length of :term:`embeddings<Embeddings>` is less than or equal to k
150
+ If embeddings are not unit interval [0-1]
151
+ ValueError
152
+ If length of :term:`embeddings<Embeddings>` is less than or equal to num_observations
150
153
  ValueError
151
154
  If radius_type is unknown
152
155
 
@@ -157,10 +160,10 @@ def coverage(
157
160
  Example
158
161
  -------
159
162
  >>> results = coverage(embeddings)
160
- >>> results.indices
163
+ >>> results.uncovered_indices
161
164
  array([447, 412, 8, 32, 63])
162
- >>> results.critical_value
163
- 0.8459038956941765
165
+ >>> results.coverage_radius
166
+ 0.17592147193757596
164
167
 
165
168
  Reference
166
169
  ---------
@@ -169,26 +172,29 @@ def coverage(
169
172
  [1] Seymour Sudman. 1976. Applied sampling. Academic Press New York (1976).
170
173
  """
171
174
 
172
- # Calculate distance matrix, look at the (k+1)th farthest neighbor for each image.
173
- embeddings = to_numpy(embeddings)
174
- n = len(embeddings)
175
- if n <= k:
175
+ # Calculate distance matrix, look at the (num_observations + 1)th farthest neighbor for each image.
176
+ embeddings = ensure_embeddings(embeddings, dtype=np.float64, unit_interval=True)
177
+ len_embeddings = len(embeddings)
178
+ if len_embeddings <= num_observations:
176
179
  raise ValueError(
177
- f"Number of observations n={n} is less than or equal to the specified number of neighbors k={k}."
180
+ f"Length of embeddings ({len_embeddings}) is less than or equal to the specified number of \
181
+ observations ({num_observations})."
178
182
  )
179
- mat = squareform(pdist(flatten(embeddings))).astype(np.float64)
180
- sorted_dists = np.sort(mat, axis=1)
181
- crit = sorted_dists[:, k + 1]
183
+ embeddings_matrix = squareform(pdist(flatten(embeddings))).astype(np.float64)
184
+ sorted_dists = np.sort(embeddings_matrix, axis=1)
185
+ critical_value_radii = sorted_dists[:, num_observations + 1]
182
186
 
183
187
  d = embeddings.shape[1]
184
188
  if radius_type == "naive":
185
- rho = (1 / math.sqrt(math.pi)) * ((2 * k * math.gamma(d / 2 + 1)) / (n)) ** (1 / d)
186
- pvals = np.where(crit > rho)[0]
189
+ coverage_radius = (1 / math.sqrt(math.pi)) * (
190
+ (2 * num_observations * math.gamma(d / 2 + 1)) / (len_embeddings)
191
+ ) ** (1 / d)
192
+ uncovered_indices = np.where(critical_value_radii > coverage_radius)[0]
187
193
  elif radius_type == "adaptive":
188
- # Use data adaptive cutoff as rho
189
- selection = int(max(n * percent, 1))
190
- pvals = np.argsort(crit)[::-1][:selection]
191
- rho = float(np.mean(np.sort(crit)[::-1][selection - 1 : selection + 1]))
194
+ # Use data adaptive cutoff as coverage_radius
195
+ selection = int(max(len_embeddings * percent, 1))
196
+ uncovered_indices = np.argsort(critical_value_radii)[::-1][:selection]
197
+ coverage_radius = float(np.mean(np.sort(critical_value_radii)[::-1][selection - 1 : selection + 1]))
192
198
  else:
193
199
  raise ValueError(f"{radius_type} is an invalid radius type. Expected 'adaptive' or 'naive'")
194
- return CoverageOutput(pvals, crit, rho)
200
+ return CoverageOutput(uncovered_indices, critical_value_radii, coverage_radius)
@@ -8,12 +8,14 @@ from typing import Any, Literal
8
8
 
9
9
  import numpy as np
10
10
  import scipy as sp
11
- from numpy.typing import ArrayLike, NDArray
11
+ from numpy.typing import NDArray
12
12
 
13
- from dataeval.output import Output, set_metadata
14
- from dataeval.utils.metadata import Metadata, get_counts
15
- from dataeval.utils.plot import heatmap
16
- from dataeval.utils.shared import get_method
13
+ from dataeval._output import Output, set_metadata
14
+ from dataeval.typing import ArrayLike
15
+ from dataeval.utils._bin import get_counts
16
+ from dataeval.utils._method import get_method
17
+ from dataeval.utils._plot import heatmap
18
+ from dataeval.utils.data import Metadata
17
19
 
18
20
  with contextlib.suppress(ImportError):
19
21
  from matplotlib.figure import Figure
@@ -37,7 +39,7 @@ def _plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
37
39
  """
38
40
  import matplotlib.pyplot as plt
39
41
 
40
- fig, ax = plt.subplots(figsize=(10, 10))
42
+ fig, ax = plt.subplots(figsize=(8, 8))
41
43
 
42
44
  ax.bar(labels, bar_heights)
43
45
  ax.set_xlabel("Factors")
@@ -51,7 +53,7 @@ def _plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
51
53
  @dataclass(frozen=True)
52
54
  class DiversityOutput(Output):
53
55
  """
54
- Output class for :func:`diversity` :term:`bias<Bias>` metric.
56
+ Output class for :func:`.diversity` :term:`bias<Bias>` metric.
55
57
 
56
58
  Attributes
57
59
  ----------
@@ -61,14 +63,14 @@ class DiversityOutput(Output):
61
63
  Classwise diversity index [n_class x n_factor]
62
64
  factor_names : list[str]
63
65
  Names of each metadata factor
64
- class_list : NDArray[Any]
66
+ class_names : list[str]
65
67
  Class labels for each value in the dataset
66
68
  """
67
69
 
68
70
  diversity_index: NDArray[np.double]
69
71
  classwise: NDArray[np.double]
70
72
  factor_names: list[str]
71
- class_list: NDArray[Any]
73
+ class_names: list[str]
72
74
 
73
75
  def plot(
74
76
  self,
@@ -90,7 +92,7 @@ class DiversityOutput(Output):
90
92
  """
91
93
  if plot_classwise:
92
94
  if row_labels is None:
93
- row_labels = self.class_list
95
+ row_labels = self.class_names
94
96
  if col_labels is None:
95
97
  col_labels = self.factor_names
96
98
 
@@ -191,6 +193,9 @@ def diversity_simpson(
191
193
  return ev_index
192
194
 
193
195
 
196
+ _DIVERSITY_FN_MAP = {"simpson": diversity_simpson, "shannon": diversity_shannon}
197
+
198
+
194
199
  @set_metadata
195
200
  def diversity(
196
201
  metadata: Metadata,
@@ -210,7 +215,7 @@ def diversity(
210
215
  Parameters
211
216
  ----------
212
217
  metadata : Metadata
213
- Preprocessed metadata from :func:`dataeval.utils.metadata.preprocess`
218
+ Preprocessed metadata
214
219
  method : "simpson" or "shannon", default "simpson"
215
220
  The methodology used for defining diversity
216
221
 
@@ -251,7 +256,7 @@ def diversity(
251
256
  --------
252
257
  scipy.stats.entropy
253
258
  """
254
- diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
259
+ diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
255
260
  discretized_data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
256
261
  cnts = get_counts(discretized_data)
257
262
  num_bins = np.bincount(np.nonzero(cnts)[1])