dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. dataeval/__init__.py +3 -9
  2. dataeval/detectors/__init__.py +2 -10
  3. dataeval/detectors/drift/base.py +3 -3
  4. dataeval/detectors/drift/mmd.py +1 -1
  5. dataeval/detectors/drift/torch.py +1 -101
  6. dataeval/detectors/linters/clusterer.py +3 -3
  7. dataeval/detectors/linters/duplicates.py +4 -4
  8. dataeval/detectors/linters/outliers.py +4 -4
  9. dataeval/detectors/ood/__init__.py +9 -9
  10. dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
  11. dataeval/detectors/ood/base.py +63 -113
  12. dataeval/detectors/ood/base_torch.py +109 -0
  13. dataeval/detectors/ood/metadata_ks_compare.py +52 -14
  14. dataeval/interop.py +1 -1
  15. dataeval/metrics/bias/__init__.py +3 -0
  16. dataeval/metrics/bias/balance.py +73 -70
  17. dataeval/metrics/bias/coverage.py +4 -4
  18. dataeval/metrics/bias/diversity.py +67 -136
  19. dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  20. dataeval/metrics/bias/metadata_utils.py +229 -0
  21. dataeval/metrics/bias/parity.py +51 -161
  22. dataeval/metrics/estimators/ber.py +3 -3
  23. dataeval/metrics/estimators/divergence.py +3 -3
  24. dataeval/metrics/estimators/uap.py +3 -3
  25. dataeval/metrics/stats/base.py +2 -2
  26. dataeval/metrics/stats/boxratiostats.py +1 -1
  27. dataeval/metrics/stats/datasetstats.py +6 -6
  28. dataeval/metrics/stats/dimensionstats.py +1 -1
  29. dataeval/metrics/stats/hashstats.py +1 -1
  30. dataeval/metrics/stats/labelstats.py +3 -3
  31. dataeval/metrics/stats/pixelstats.py +1 -1
  32. dataeval/metrics/stats/visualstats.py +1 -1
  33. dataeval/output.py +77 -53
  34. dataeval/utils/__init__.py +1 -7
  35. dataeval/utils/gmm.py +26 -0
  36. dataeval/utils/metadata.py +29 -9
  37. dataeval/utils/torch/gmm.py +98 -0
  38. dataeval/utils/torch/models.py +192 -0
  39. dataeval/utils/torch/trainer.py +84 -5
  40. dataeval/utils/torch/utils.py +107 -1
  41. dataeval/workflows/sufficiency.py +4 -4
  42. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
  43. dataeval-0.74.1.dist-info/RECORD +65 -0
  44. dataeval/detectors/ood/aegmm.py +0 -66
  45. dataeval/detectors/ood/llr.py +0 -302
  46. dataeval/detectors/ood/vae.py +0 -97
  47. dataeval/detectors/ood/vaegmm.py +0 -75
  48. dataeval/metrics/bias/metadata.py +0 -440
  49. dataeval/utils/lazy.py +0 -26
  50. dataeval/utils/tensorflow/__init__.py +0 -19
  51. dataeval/utils/tensorflow/_internal/gmm.py +0 -123
  52. dataeval/utils/tensorflow/_internal/loss.py +0 -121
  53. dataeval/utils/tensorflow/_internal/models.py +0 -1394
  54. dataeval/utils/tensorflow/_internal/trainer.py +0 -114
  55. dataeval/utils/tensorflow/_internal/utils.py +0 -256
  56. dataeval/utils/tensorflow/loss/__init__.py +0 -11
  57. dataeval-0.73.1.dist-info/RECORD +0 -73
  58. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
  59. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
dataeval/detectors/ood/base_torch.py ADDED
@@ -0,0 +1,109 @@
1
+ """
2
+ Source code derived from Alibi-Detect 0.11.4
3
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
4
+
5
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
+ Licensed under Apache Software License (Apache 2.0)
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Callable, cast
12
+
13
+ import torch
14
+ from numpy.typing import ArrayLike
15
+
16
+ from dataeval.detectors.drift.torch import get_device
17
+ from dataeval.detectors.ood.base import OODBaseMixin, OODFitMixin, OODGMMMixin
18
+ from dataeval.interop import to_numpy
19
+ from dataeval.utils.torch.gmm import gmm_params
20
+ from dataeval.utils.torch.trainer import trainer
21
+
22
+
23
+ class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
24
+ def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
25
+ self.device: torch.device = get_device(device)
26
+ super().__init__(model)
27
+
28
+ def fit(
29
+ self,
30
+ x_ref: ArrayLike,
31
+ threshold_perc: float,
32
+ loss_fn: Callable[..., torch.nn.Module] | None,
33
+ optimizer: torch.optim.Optimizer | None,
34
+ epochs: int,
35
+ batch_size: int,
36
+ verbose: bool,
37
+ ) -> None:
38
+ """
39
+ Train the model and infer the threshold value.
40
+
41
+ Parameters
42
+ ----------
43
+ x_ref : ArrayLike
44
+ Training data.
45
+ threshold_perc : float, default 100.0
46
+ Percentage of reference data that is normal.
47
+ loss_fn : Callable | None, default None
48
+ Loss function used for training.
49
+ optimizer : Optimizer, default keras.optimizers.Adam
50
+ Optimizer used for training.
51
+ epochs : int, default 20
52
+ Number of training epochs.
53
+ batch_size : int, default 64
54
+ Batch size used for training.
55
+ verbose : bool, default True
56
+ Whether to print training progress.
57
+ """
58
+
59
+ # Train the model
60
+ trainer(
61
+ model=self.model,
62
+ x_train=to_numpy(x_ref),
63
+ y_train=None,
64
+ loss_fn=loss_fn,
65
+ optimizer=optimizer,
66
+ preprocess_fn=None,
67
+ epochs=epochs,
68
+ batch_size=batch_size,
69
+ device=self.device,
70
+ verbose=verbose,
71
+ )
72
+
73
+ # Infer the threshold values
74
+ self._ref_score = self.score(x_ref, batch_size)
75
+ self._threshold_perc = threshold_perc
76
+
77
+
78
+ class OODBaseGMM(OODBase, OODGMMMixin[torch.Tensor]):
79
+ def fit(
80
+ self,
81
+ x_ref: ArrayLike,
82
+ threshold_perc: float,
83
+ loss_fn: Callable[..., torch.nn.Module] | None,
84
+ optimizer: torch.optim.Optimizer | None,
85
+ epochs: int,
86
+ batch_size: int,
87
+ verbose: bool,
88
+ ) -> None:
89
+ # Train the model
90
+ trainer(
91
+ model=self.model,
92
+ x_train=to_numpy(x_ref),
93
+ y_train=None,
94
+ loss_fn=loss_fn,
95
+ optimizer=optimizer,
96
+ preprocess_fn=None,
97
+ epochs=epochs,
98
+ batch_size=batch_size,
99
+ device=self.device,
100
+ verbose=verbose,
101
+ )
102
+
103
+ # Calculate the GMM parameters
104
+ _, z, gamma = cast(tuple[torch.Tensor, torch.Tensor, torch.Tensor], self.model(x_ref))
105
+ self._gmm_params = gmm_params(z, gamma)
106
+
107
+ # Infer the threshold values
108
+ self._ref_score = self.score(x_ref, batch_size)
109
+ self._threshold_perc = threshold_perc
dataeval/detectors/ood/metadata_ks_compare.py CHANGED
@@ -2,17 +2,44 @@ from __future__ import annotations
2
2
 
3
3
  import numbers
4
4
  import warnings
5
- from typing import Any, Mapping
5
+ from typing import Any, Mapping, NamedTuple
6
6
 
7
7
  import numpy as np
8
8
  from numpy.typing import NDArray
9
9
  from scipy.stats import iqr, ks_2samp
10
10
  from scipy.stats import wasserstein_distance as emd
11
11
 
12
+ from dataeval.output import MappingOutput, set_metadata
12
13
 
14
+
15
+ class MetadataKSResult(NamedTuple):
16
+ statistic: float
17
+ statistic_location: float
18
+ shift_magnitude: float
19
+ pvalue: float
20
+
21
+
22
+ class KSOutput(MappingOutput[str, MetadataKSResult]):
23
+ """
24
+ Output dictionary class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
25
+
26
+ Attributes
27
+ ----------
28
+ key: str
29
+ Metadata feature names
30
+ value: NamedTuple[float, float, float, float]
31
+ Each value contains four floats, which are:
32
+ - statistic: the KS statistic itself
33
+ - statistic_location: its location within the range of the reference metadata
34
+ - shift_magnitude: the shift of new metadata relative to reference
35
+ - pvalue: the p-value from the KS two-sample test
36
+ """
37
+
38
+
39
+ @set_metadata
13
40
  def meta_distribution_compare(
14
41
  md0: Mapping[str, list[Any] | NDArray[Any]], md1: Mapping[str, list[Any] | NDArray[Any]]
15
- ) -> dict[str, dict[str, float]]:
42
+ ) -> KSOutput:
16
43
  """Measures the featurewise distance between two metadata distributions, and computes a p-value to evaluate its
17
44
  significance.
18
45
 
@@ -43,27 +70,29 @@ def meta_distribution_compare(
43
70
  >>> import numpy
44
71
  >>> md0 = {"time": [1.2, 3.4, 5.6], "altitude": [235, 6789, 101112]}
45
72
  >>> md1 = {"time": [7.8, 9.10, 11.12], "altitude": [532, 9876, 211101]}
46
- >>> md_out = meta_distribution_compare(md0, md1)
73
+ >>> md_out = meta_distribution_compare(md0, md1).mdc
47
74
  >>> for k, v in md_out.items():
48
75
  >>> print(k)
49
76
  >>> for kv in v:
50
77
  >>> print("\t", f"{kv}: {v[kv]:.3f}")
51
78
  time
52
- statistic_location: 0.444
53
- shift_magnitude: 2.700
54
- pvalue: 0.000
79
+ statistic: 1.000
80
+ statistic_location: 0.444
81
+ shift_magnitude: 2.700
82
+ pvalue: 0.000
55
83
  altitude
56
- statistic_location: 0.478
57
- shift_magnitude: 0.749
58
- pvalue: 0.944
84
+ statistic: 0.333
85
+ statistic_location: 0.478
86
+ shift_magnitude: 0.749
87
+ pvalue: 0.944
59
88
  """
60
89
 
61
90
  if (metadata_keys := md0.keys()) != md1.keys():
62
91
  raise ValueError(f"Both sets of metadata keys must be identical: {list(md0)}, {list(md1)}")
63
92
 
64
- mdc_dict = {} # output dict
93
+ mdc = {} # output dict
65
94
  for k in metadata_keys:
66
- mdc_dict.update({k: {}})
95
+ mdc.update({k: {}})
67
96
 
68
97
  x0, x1 = list(md0[k]), list(md1[k])
69
98
 
@@ -81,7 +110,9 @@ def meta_distribution_compare(
81
110
 
82
111
  xmin, xmax = min(allx), max(allx)
83
112
  if xmin == xmax: # only one value in this feature, so fill in the obvious results for feature k
84
- mdc_dict[k].update({"statistic_location": 0.0, "shift_magnitude": 0.0, "pvalue": 1.0})
113
+ mdc[k] = MetadataKSResult(
114
+ **{"statistic": 0.0, "statistic_location": 0.0, "shift_magnitude": 0.0, "pvalue": 1.0}
115
+ )
85
116
  continue
86
117
 
87
118
  ks_result = ks_2samp(x0, x1, method="asymp")
@@ -94,6 +125,13 @@ def meta_distribution_compare(
94
125
 
95
126
  drift = emd(x0, x1) / dX
96
127
 
97
- mdc_dict[k].update({"statistic_location": loc, "shift_magnitude": drift, "pvalue": ks_result.pvalue}) # pyright: ignore
128
+ mdc[k] = MetadataKSResult(
129
+ **{
130
+ "statistic": ks_result.statistic, # pyright: ignore
131
+ "statistic_location": loc,
132
+ "shift_magnitude": drift,
133
+ "pvalue": ks_result.pvalue, # pyright: ignore
134
+ }
135
+ )
98
136
 
99
- return mdc_dict
137
+ return KSOutput(mdc)
dataeval/interop.py CHANGED
@@ -47,7 +47,7 @@ def to_numpy(array: ArrayLike | None, copy: bool = True) -> NDArray[Any]:
47
47
  if torch and isinstance(array, torch.Tensor):
48
48
  return array.detach().cpu().numpy().copy() if copy else array.detach().cpu().numpy() # type: ignore
49
49
 
50
- return np.array(array, copy=copy)
50
+ return np.array(array) if copy else np.asarray(array)
51
51
 
52
52
 
53
53
  def to_numpy_iter(iterable: Iterable[ArrayLike]) -> Iterator[NDArray[Any]]:
dataeval/metrics/bias/__init__.py CHANGED
@@ -6,6 +6,7 @@ representation which may impact model performance.
6
6
  from dataeval.metrics.bias.balance import BalanceOutput, balance
7
7
  from dataeval.metrics.bias.coverage import CoverageOutput, coverage
8
8
  from dataeval.metrics.bias.diversity import DiversityOutput, diversity
9
+ from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput, metadata_preprocessing
9
10
  from dataeval.metrics.bias.parity import ParityOutput, label_parity, parity
10
11
 
11
12
  __all__ = [
@@ -14,8 +15,10 @@ __all__ = [
14
15
  "diversity",
15
16
  "label_parity",
16
17
  "parity",
18
+ "metadata_preprocessing",
17
19
  "BalanceOutput",
18
20
  "CoverageOutput",
19
21
  "DiversityOutput",
20
22
  "ParityOutput",
23
+ "MetadataOutput",
21
24
  ]
dataeval/metrics/bias/balance.py CHANGED
@@ -5,21 +5,23 @@ __all__ = ["BalanceOutput", "balance"]
5
5
  import contextlib
6
6
  import warnings
7
7
  from dataclasses import dataclass
8
- from typing import Any, Mapping
8
+ from typing import Any
9
9
 
10
10
  import numpy as np
11
- from numpy.typing import ArrayLike, NDArray
11
+ import scipy as sp
12
+ from numpy.typing import NDArray
12
13
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
13
14
 
14
- from dataeval.metrics.bias.metadata import CLASS_LABEL, entropy, heatmap, preprocess_metadata
15
- from dataeval.output import OutputMetadata, set_metadata
15
+ from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
16
+ from dataeval.metrics.bias.metadata_utils import get_counts, heatmap
17
+ from dataeval.output import Output, set_metadata
16
18
 
17
19
  with contextlib.suppress(ImportError):
18
20
  from matplotlib.figure import Figure
19
21
 
20
22
 
21
23
  @dataclass(frozen=True)
22
- class BalanceOutput(OutputMetadata):
24
+ class BalanceOutput(Output):
23
25
  """
24
26
  Output class for :func:`balance` bias metric
25
27
 
@@ -31,17 +33,17 @@ class BalanceOutput(OutputMetadata):
31
33
  Estimate of inter/intra-factor mutual information
32
34
  classwise : NDArray[np.float64]
33
35
  Estimate of mutual information between metadata factors and individual class labels
36
+ factor_names : list[str]
37
+ Names of each metadata factor
34
38
  class_list : NDArray
35
39
  Array of the class labels present in the dataset
36
- metadata_names : list[str]
37
- Names of each metadata factor
38
40
  """
39
41
 
40
42
  balance: NDArray[np.float64]
41
43
  factors: NDArray[np.float64]
42
44
  classwise: NDArray[np.float64]
45
+ factor_names: list[str]
43
46
  class_list: NDArray[Any]
44
- metadata_names: list[str]
45
47
 
46
48
  def plot(
47
49
  self,
@@ -65,7 +67,7 @@ class BalanceOutput(OutputMetadata):
65
67
  if row_labels is None:
66
68
  row_labels = self.class_list
67
69
  if col_labels is None:
68
- col_labels = np.concatenate((["class"], self.metadata_names))
70
+ col_labels = self.factor_names
69
71
 
70
72
  fig = heatmap(
71
73
  self.classwise,
@@ -83,7 +85,7 @@ class BalanceOutput(OutputMetadata):
83
85
  # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
84
86
  heat_data = np.where(mask, np.nan, data)[:-1]
85
87
  # Creating label array for heat map axes
86
- heat_labels = np.concatenate((["class"], self.metadata_names))
88
+ heat_labels = self.factor_names
87
89
 
88
90
  if row_labels is None:
89
91
  row_labels = heat_labels[:-1]
@@ -95,7 +97,7 @@ class BalanceOutput(OutputMetadata):
95
97
  return fig
96
98
 
97
99
 
98
- def validate_num_neighbors(num_neighbors: int) -> int:
100
+ def _validate_num_neighbors(num_neighbors: int) -> int:
99
101
  if not isinstance(num_neighbors, (int, float)):
100
102
  raise TypeError(
101
103
  f"Variable {num_neighbors} is not real-valued numeric type."
@@ -115,30 +117,18 @@ def validate_num_neighbors(num_neighbors: int) -> int:
115
117
  return num_neighbors
116
118
 
117
119
 
118
- @set_metadata("dataeval.metrics")
120
+ @set_metadata
119
121
  def balance(
120
- class_labels: ArrayLike,
121
- metadata: Mapping[str, ArrayLike],
122
+ metadata: MetadataOutput,
122
123
  num_neighbors: int = 5,
123
- continuous_factor_bincounts: Mapping[str, int] | None = None,
124
124
  ) -> BalanceOutput:
125
125
  """
126
126
  Mutual information (MI) between factors (class label, metadata, label/image properties)
127
127
 
128
128
  Parameters
129
129
  ----------
130
- class_labels : ArrayLike
131
- List of class labels for each image
132
- metadata : Mapping[str, ArrayLike]
133
- Dict of lists of metadata factors for each image
134
- num_neighbors : int, default 5
135
- Number of nearest neighbors to use for computing MI between discrete
136
- and continuous variables.
137
- continuous_factor_bincounts : Mapping[str, int] or None, default None
138
- The factors in metadata that have continuous values and the array of bin counts to
139
- discretize values into. All factors are treated as having discrete values unless they
140
- are specified as keys in this dictionary. Each element of this array must occur as a key
141
- in metadata.
130
+ metadata : MetadataOutput
131
+ Output after running `metadata_preprocessing`
142
132
 
143
133
  Returns
144
134
  -------
@@ -150,30 +140,33 @@ def balance(
150
140
  ----
151
141
  We use `mutual_info_classif` from sklearn since class label is categorical.
152
142
  `mutual_info_classif` outputs are consistent up to O(1e-4) and depend on a random
153
- seed. MI is computed differently for categorical and continuous variables, and
154
- we attempt to infer whether a variable is categorical by the fraction of unique
155
- values in the dataset.
143
+ seed. MI is computed differently for categorical and continuous variables.
156
144
 
157
145
  Example
158
146
  -------
159
147
  Return balance (mutual information) of factors with class_labels
160
148
 
161
- >>> bal = balance(class_labels, metadata, continuous_factor_bincounts=continuous_factor_bincounts)
149
+ >>> bal = balance(metadata)
162
150
  >>> bal.balance
163
- array([0.99999822, 0.13363788, 0.04505382, 0.02994455])
151
+ array([0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0. ,
152
+ 0. ])
164
153
 
165
154
  Return intra/interfactor balance (mutual information)
166
155
 
167
156
  >>> bal.factors
168
- array([[0.99999843, 0.04133555, 0.09725766],
169
- [0.04133555, 0.08433558, 0.1301489 ],
170
- [0.09725766, 0.1301489 , 0.99999856]])
157
+ array([[0.99999935, 0.31360499, 0.26925848, 0.85201924, 0.36653548],
158
+ [0.31360499, 0.99999856, 0.09725766, 0.15836905, 1.98031993],
159
+ [0.26925848, 0.09725766, 0.99999846, 0.03713108, 0.01544656],
160
+ [0.85201924, 0.15836905, 0.03713108, 0.47450653, 0.25509664],
161
+ [0.36653548, 1.98031993, 0.01544656, 0.25509664, 1.06260686]])
171
162
 
172
163
  Return classwise balance (mutual information) of factors with individual class_labels
173
164
 
174
165
  >>> bal.classwise
175
- array([[0.99999822, 0.13363788, 0. , 0. ],
176
- [0.99999822, 0.13363788, 0. , 0. ]])
166
+ array([[0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0. ,
167
+ 0. ],
168
+ [0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0. ,
169
+ 0. ]])
177
170
 
178
171
 
179
172
  See Also
@@ -182,68 +175,78 @@ def balance(
182
175
  sklearn.feature_selection.mutual_info_regression
183
176
  sklearn.metrics.mutual_info_score
184
177
  """
185
- num_neighbors = validate_num_neighbors(num_neighbors)
186
- data, names, is_categorical, unique_labels = preprocess_metadata(class_labels, metadata)
187
- num_factors = len(names)
188
- mi = np.empty((num_factors, num_factors))
189
- mi[:] = np.nan
178
+ num_neighbors = _validate_num_neighbors(num_neighbors)
179
+
180
+ num_factors = metadata.total_num_factors
181
+ is_discrete = [True] * (len(metadata.discrete_factor_names) + 1) + [False] * len(metadata.continuous_factor_names)
182
+ mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
183
+ data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
184
+ discretized_data = data
185
+ if metadata.continuous_data is not None:
186
+ data = np.hstack((data, metadata.continuous_data))
187
+ discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
188
+ discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
190
189
 
191
190
  for idx in range(num_factors):
192
- tgt = data[:, idx].astype(np.intp)
193
-
194
- if continuous_factor_bincounts and names[idx] not in continuous_factor_bincounts:
195
- mi[idx, :] = mutual_info_classif(
191
+ if idx >= len(metadata.discrete_factor_names) + 1:
192
+ mi[idx, :] = mutual_info_regression(
196
193
  data,
197
- tgt,
198
- discrete_features=is_categorical, # type: ignore
194
+ data[:, idx],
195
+ discrete_features=is_discrete, # type: ignore
199
196
  n_neighbors=num_neighbors,
200
197
  random_state=0,
201
198
  )
202
199
  else:
203
- mi[idx, :] = mutual_info_regression(
200
+ mi[idx, :] = mutual_info_classif(
204
201
  data,
205
- tgt,
206
- discrete_features=is_categorical, # type: ignore
202
+ data[:, idx],
203
+ discrete_features=is_discrete, # type: ignore
207
204
  n_neighbors=num_neighbors,
208
205
  random_state=0,
209
206
  )
210
207
 
211
- ent_all = entropy(data, names, continuous_factor_bincounts, normalized=False)
212
- norm_factor = 0.5 * np.add.outer(ent_all, ent_all) + 1e-6
208
+ # Normalization via entropy
209
+ bin_cnts = get_counts(discretized_data)
210
+ ent_factor = sp.stats.entropy(bin_cnts, axis=0)
211
+ norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + 1e-6
212
+
213
213
  # in principle MI should be symmetric, but it is not in practice.
214
214
  nmi = 0.5 * (mi + mi.T) / norm_factor
215
215
  balance = nmi[0]
216
216
  factors = nmi[1:, 1:]
217
217
 
218
- # unique class labels
219
- class_idx = names.index(CLASS_LABEL)
220
- u_cls = np.unique(data[:, class_idx])
221
- num_classes = len(u_cls)
222
-
223
218
  # assume class is a factor
224
- classwise_mi = np.empty((num_classes, num_factors))
225
- classwise_mi[:] = np.nan
219
+ num_classes = metadata.class_names.size
220
+ classwise_mi = np.full((num_classes, num_factors), np.nan, dtype=np.float32)
226
221
 
227
- # categorical variables, excluding class label
228
- cat_mask = np.concatenate((is_categorical[:class_idx], is_categorical[(class_idx + 1) :]), axis=0).astype(np.intp)
229
-
230
- tgt_bin = np.stack([data[:, class_idx] == cls for cls in u_cls]).T.astype(np.intp)
231
- names = [str(idx) for idx in range(num_classes)]
232
- ent_tgt_bin = entropy(tgt_bin, names, continuous_factor_bincounts)
222
+ # classwise targets
223
+ classes = np.unique(metadata.class_labels)
224
+ tgt_bin = data[:, 0][:, None] == classes
233
225
 
234
226
  # classification MI for discrete/categorical features
235
227
  for idx in range(num_classes):
236
- # tgt = class_data == cls
237
228
  # units: nat
238
229
  classwise_mi[idx, :] = mutual_info_classif(
239
230
  data,
240
231
  tgt_bin[:, idx],
241
- discrete_features=cat_mask, # type: ignore
232
+ discrete_features=is_discrete, # type: ignore
242
233
  n_neighbors=num_neighbors,
243
234
  random_state=0,
244
235
  )
245
236
 
246
- norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_all) + 1e-6
237
+ # Classwise normalization via entropy
238
+ classwise_bin_cnts = get_counts(tgt_bin)
239
+ ent_tgt_bin = sp.stats.entropy(classwise_bin_cnts, axis=0)
240
+ norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_factor) + 1e-6
247
241
  classwise = classwise_mi / norm_factor
248
242
 
249
- return BalanceOutput(balance, factors, classwise, unique_labels, list(metadata.keys()))
243
+ # Grabbing factor names for plotting function
244
+ factor_names = ["class"]
245
+ for name in metadata.discrete_factor_names:
246
+ if name in metadata.continuous_factor_names:
247
+ name = name + "-discrete"
248
+ factor_names.append(name)
249
+ for name in metadata.continuous_factor_names:
250
+ factor_names.append(name + "-continuous")
251
+
252
+ return BalanceOutput(balance, factors, classwise, factor_names, metadata.class_names)
dataeval/metrics/bias/coverage.py CHANGED
@@ -12,8 +12,8 @@ from numpy.typing import ArrayLike, NDArray
12
12
  from scipy.spatial.distance import pdist, squareform
13
13
 
14
14
  from dataeval.interop import to_numpy
15
- from dataeval.metrics.bias.metadata import coverage_plot
16
- from dataeval.output import OutputMetadata, set_metadata
15
+ from dataeval.metrics.bias.metadata_utils import coverage_plot
16
+ from dataeval.output import Output, set_metadata
17
17
  from dataeval.utils.shared import flatten
18
18
 
19
19
  with contextlib.suppress(ImportError):
@@ -21,7 +21,7 @@ with contextlib.suppress(ImportError):
21
21
 
22
22
 
23
23
  @dataclass(frozen=True)
24
- class CoverageOutput(OutputMetadata):
24
+ class CoverageOutput(Output):
25
25
  """
26
26
  Output class for :func:`coverage` :term:`bias<Bias>` metric
27
27
 
@@ -67,7 +67,7 @@ class CoverageOutput(OutputMetadata):
67
67
  return fig
68
68
 
69
69
 
70
- @set_metadata()
70
+ @set_metadata
71
71
  def coverage(
72
72
  embeddings: ArrayLike,
73
73
  radius_type: Literal["adaptive", "naive"] = "adaptive",