dataeval-0.73.0-py3-none-any.whl → dataeval-0.74.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. dataeval/__init__.py +3 -3
  2. dataeval/detectors/__init__.py +1 -1
  3. dataeval/detectors/drift/__init__.py +1 -1
  4. dataeval/detectors/drift/base.py +2 -2
  5. dataeval/detectors/drift/torch.py +1 -101
  6. dataeval/detectors/linters/clusterer.py +1 -1
  7. dataeval/detectors/ood/__init__.py +11 -4
  8. dataeval/detectors/ood/ae.py +2 -1
  9. dataeval/detectors/ood/ae_torch.py +70 -0
  10. dataeval/detectors/ood/aegmm.py +4 -3
  11. dataeval/detectors/ood/base.py +58 -108
  12. dataeval/detectors/ood/base_tf.py +109 -0
  13. dataeval/detectors/ood/base_torch.py +109 -0
  14. dataeval/detectors/ood/llr.py +2 -2
  15. dataeval/detectors/ood/metadata_ks_compare.py +53 -14
  16. dataeval/detectors/ood/vae.py +3 -2
  17. dataeval/detectors/ood/vaegmm.py +5 -4
  18. dataeval/metrics/bias/__init__.py +3 -0
  19. dataeval/metrics/bias/balance.py +77 -64
  20. dataeval/metrics/bias/coverage.py +12 -12
  21. dataeval/metrics/bias/diversity.py +74 -114
  22. dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  23. dataeval/metrics/bias/metadata_utils.py +229 -0
  24. dataeval/metrics/bias/parity.py +54 -158
  25. dataeval/utils/__init__.py +2 -2
  26. dataeval/utils/gmm.py +26 -0
  27. dataeval/utils/metadata.py +29 -9
  28. dataeval/utils/shared.py +1 -1
  29. dataeval/utils/split_dataset.py +12 -6
  30. dataeval/utils/tensorflow/_internal/gmm.py +4 -24
  31. dataeval/utils/torch/datasets.py +2 -2
  32. dataeval/utils/torch/gmm.py +98 -0
  33. dataeval/utils/torch/models.py +192 -0
  34. dataeval/utils/torch/trainer.py +84 -5
  35. dataeval/utils/torch/utils.py +107 -1
  36. dataeval/workflows/__init__.py +1 -1
  37. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/METADATA +1 -2
  38. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/RECORD +40 -34
  39. dataeval/metrics/bias/metadata.py +0 -358
  40. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/LICENSE.txt +0 -0
  41. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,109 @@
1
+ """
2
+ Source code derived from Alibi-Detect 0.11.4
3
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
4
+
5
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
+ Licensed under Apache Software License (Apache 2.0)
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Callable, cast
12
+
13
+ import torch
14
+ from numpy.typing import ArrayLike
15
+
16
+ from dataeval.detectors.drift.torch import get_device
17
+ from dataeval.detectors.ood.base import OODBaseMixin, OODFitMixin, OODGMMMixin
18
+ from dataeval.interop import to_numpy
19
+ from dataeval.utils.torch.gmm import gmm_params
20
+ from dataeval.utils.torch.trainer import trainer
21
+
22
+
23
+ class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
24
+ def __init__(self, model: torch.nn.Module, device: str | torch.device | None = None) -> None:
25
+ self.device: torch.device = get_device(device)
26
+ super().__init__(model)
27
+
28
+ def fit(
29
+ self,
30
+ x_ref: ArrayLike,
31
+ threshold_perc: float,
32
+ loss_fn: Callable[..., torch.nn.Module] | None,
33
+ optimizer: torch.optim.Optimizer | None,
34
+ epochs: int,
35
+ batch_size: int,
36
+ verbose: bool,
37
+ ) -> None:
38
+ """
39
+ Train the model and infer the threshold value.
40
+
41
+ Parameters
42
+ ----------
43
+ x_ref : ArrayLike
44
+ Training data.
45
+ threshold_perc : float, default 100.0
46
+ Percentage of reference data that is normal.
47
+ loss_fn : Callable | None, default None
48
+ Loss function used for training.
49
+ optimizer : Optimizer, default keras.optimizers.Adam
50
+ Optimizer used for training.
51
+ epochs : int, default 20
52
+ Number of training epochs.
53
+ batch_size : int, default 64
54
+ Batch size used for training.
55
+ verbose : bool, default True
56
+ Whether to print training progress.
57
+ """
58
+
59
+ # Train the model
60
+ trainer(
61
+ model=self.model,
62
+ x_train=to_numpy(x_ref),
63
+ y_train=None,
64
+ loss_fn=loss_fn,
65
+ optimizer=optimizer,
66
+ preprocess_fn=None,
67
+ epochs=epochs,
68
+ batch_size=batch_size,
69
+ device=self.device,
70
+ verbose=verbose,
71
+ )
72
+
73
+ # Infer the threshold values
74
+ self._ref_score = self.score(x_ref, batch_size)
75
+ self._threshold_perc = threshold_perc
76
+
77
+
78
+ class OODBaseGMM(OODBase, OODGMMMixin[torch.Tensor]):
79
+ def fit(
80
+ self,
81
+ x_ref: ArrayLike,
82
+ threshold_perc: float,
83
+ loss_fn: Callable[..., torch.nn.Module] | None,
84
+ optimizer: torch.optim.Optimizer | None,
85
+ epochs: int,
86
+ batch_size: int,
87
+ verbose: bool,
88
+ ) -> None:
89
+ # Train the model
90
+ trainer(
91
+ model=self.model,
92
+ x_train=to_numpy(x_ref),
93
+ y_train=None,
94
+ loss_fn=loss_fn,
95
+ optimizer=optimizer,
96
+ preprocess_fn=None,
97
+ epochs=epochs,
98
+ batch_size=batch_size,
99
+ device=self.device,
100
+ verbose=verbose,
101
+ )
102
+
103
+ # Calculate the GMM parameters
104
+ _, z, gamma = cast(tuple[torch.Tensor, torch.Tensor, torch.Tensor], self.model(x_ref))
105
+ self._gmm_params = gmm_params(z, gamma)
106
+
107
+ # Infer the threshold values
108
+ self._ref_score = self.score(x_ref, batch_size)
109
+ self._threshold_perc = threshold_perc
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Callable
16
16
  import numpy as np
17
17
  from numpy.typing import ArrayLike, NDArray
18
18
 
19
- from dataeval.detectors.ood.base import OODBase, OODScoreOutput
19
+ from dataeval.detectors.ood.base import OODBaseMixin, OODScoreOutput
20
20
  from dataeval.interop import to_numpy
21
21
  from dataeval.utils.lazy import lazyload
22
22
  from dataeval.utils.tensorflow._internal.trainer import trainer
@@ -96,7 +96,7 @@ def _mutate_categorical(
96
96
  return tf.cast(X, tf.float32) # type: ignore
97
97
 
98
98
 
99
- class OOD_LLR(OODBase):
99
+ class OOD_LLR(OODBaseMixin[tf_models.PixelCNN]):
100
100
  """
101
101
  Likelihood Ratios based outlier detector.
102
102
 
@@ -2,17 +2,45 @@ from __future__ import annotations
2
2
 
3
3
  import numbers
4
4
  import warnings
5
- from typing import Any, Mapping
5
+ from dataclasses import dataclass
6
+ from typing import Any, Mapping, NamedTuple
6
7
 
7
8
  import numpy as np
8
9
  from numpy.typing import NDArray
9
10
  from scipy.stats import iqr, ks_2samp
10
11
  from scipy.stats import wasserstein_distance as emd
11
12
 
13
+ from dataeval.output import OutputMetadata, set_metadata
12
14
 
15
+
16
+ class MetadataKSResult(NamedTuple):
17
+ statistic: float
18
+ statistic_location: float
19
+ shift_magnitude: float
20
+ pvalue: float
21
+
22
+
23
+ @dataclass(frozen=True)
24
+ class KSOutput(OutputMetadata):
25
+ """
26
+ Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
27
+
28
+ Attributes
29
+ ----------
30
+ mdc : dict[str, dict[str, float]]
31
+ dict keyed by metadata feature names. Each value contains four floats, which are the KS statistic itself, its
32
+ location within the range of the reference metadata, the shift of new metadata relative to reference, the
33
+ p-value from the KS two-sample test.
34
+
35
+ """
36
+
37
+ mdc: dict[str, MetadataKSResult]
38
+
39
+
40
+ @set_metadata()
13
41
  def meta_distribution_compare(
14
42
  md0: Mapping[str, list[Any] | NDArray[Any]], md1: Mapping[str, list[Any] | NDArray[Any]]
15
- ) -> dict[str, dict[str, float]]:
43
+ ) -> KSOutput:
16
44
  """Measures the featurewise distance between two metadata distributions, and computes a p-value to evaluate its
17
45
  significance.
18
46
 
@@ -43,27 +71,29 @@ def meta_distribution_compare(
43
71
  >>> import numpy
44
72
  >>> md0 = {"time": [1.2, 3.4, 5.6], "altitude": [235, 6789, 101112]}
45
73
  >>> md1 = {"time": [7.8, 9.10, 11.12], "altitude": [532, 9876, 211101]}
46
- >>> md_out = meta_distribution_compare(md0, md1)
74
+ >>> md_out = meta_distribution_compare(md0, md1).mdc
47
75
  >>> for k, v in md_out.items():
48
76
  >>> print(k)
49
77
  >>> for kv in v:
50
78
  >>> print("\t", f"{kv}: {v[kv]:.3f}")
51
79
  time
52
- statistic_location: 0.444
53
- shift_magnitude: 2.700
54
- pvalue: 0.000
80
+ statistic: 1.000
81
+ statistic_location: 0.444
82
+ shift_magnitude: 2.700
83
+ pvalue: 0.000
55
84
  altitude
56
- statistic_location: 0.478
57
- shift_magnitude: 0.749
58
- pvalue: 0.944
85
+ statistic: 0.333
86
+ statistic_location: 0.478
87
+ shift_magnitude: 0.749
88
+ pvalue: 0.944
59
89
  """
60
90
 
61
91
  if (metadata_keys := md0.keys()) != md1.keys():
62
92
  raise ValueError(f"Both sets of metadata keys must be identical: {list(md0)}, {list(md1)}")
63
93
 
64
- mdc_dict = {} # output dict
94
+ mdc = {} # output dict
65
95
  for k in metadata_keys:
66
- mdc_dict.update({k: {}})
96
+ mdc.update({k: {}})
67
97
 
68
98
  x0, x1 = list(md0[k]), list(md1[k])
69
99
 
@@ -81,7 +111,9 @@ def meta_distribution_compare(
81
111
 
82
112
  xmin, xmax = min(allx), max(allx)
83
113
  if xmin == xmax: # only one value in this feature, so fill in the obvious results for feature k
84
- mdc_dict[k].update({"statistic_location": 0.0, "shift_magnitude": 0.0, "pvalue": 1.0})
114
+ mdc[k] = MetadataKSResult(
115
+ **{"statistic": 0.0, "statistic_location": 0.0, "shift_magnitude": 0.0, "pvalue": 1.0}
116
+ )
85
117
  continue
86
118
 
87
119
  ks_result = ks_2samp(x0, x1, method="asymp")
@@ -94,6 +126,13 @@ def meta_distribution_compare(
94
126
 
95
127
  drift = emd(x0, x1) / dX
96
128
 
97
- mdc_dict[k].update({"statistic_location": loc, "shift_magnitude": drift, "pvalue": ks_result.pvalue}) # pyright: ignore
129
+ mdc[k] = MetadataKSResult(
130
+ **{
131
+ "statistic": ks_result.statistic, # pyright: ignore
132
+ "statistic_location": loc,
133
+ "shift_magnitude": drift,
134
+ "pvalue": ks_result.pvalue, # pyright: ignore
135
+ }
136
+ )
98
137
 
99
- return mdc_dict
138
+ return KSOutput(mdc)
@@ -15,7 +15,8 @@ from typing import TYPE_CHECKING, Callable
15
15
  import numpy as np
16
16
  from numpy.typing import ArrayLike
17
17
 
18
- from dataeval.detectors.ood.base import OODBase, OODScoreOutput
18
+ from dataeval.detectors.ood.base import OODScoreOutput
19
+ from dataeval.detectors.ood.base_tf import OODBase
19
20
  from dataeval.interop import to_numpy
20
21
  from dataeval.utils.lazy import lazyload
21
22
  from dataeval.utils.tensorflow._internal.loss import Elbo
@@ -67,7 +68,7 @@ class OOD_VAE(OODBase):
67
68
  self,
68
69
  x_ref: ArrayLike,
69
70
  threshold_perc: float = 100.0,
70
- loss_fn: Callable[..., tf.Tensor] = Elbo(0.05),
71
+ loss_fn: Callable[..., tf.Tensor] | None = Elbo(0.05),
71
72
  optimizer: keras.optimizers.Optimizer | None = None,
72
73
  epochs: int = 20,
73
74
  batch_size: int = 64,
@@ -15,7 +15,8 @@ from typing import TYPE_CHECKING, Callable
15
15
  import numpy as np
16
16
  from numpy.typing import ArrayLike
17
17
 
18
- from dataeval.detectors.ood.base import OODGMMBase, OODScoreOutput
18
+ from dataeval.detectors.ood.base import OODScoreOutput
19
+ from dataeval.detectors.ood.base_tf import OODBaseGMM
19
20
  from dataeval.interop import to_numpy
20
21
  from dataeval.utils.lazy import lazyload
21
22
  from dataeval.utils.tensorflow._internal.gmm import gmm_energy
@@ -33,7 +34,7 @@ else:
33
34
  tf_models = lazyload("dataeval.utils.tensorflow._internal.models")
34
35
 
35
36
 
36
- class OOD_VAEGMM(OODGMMBase):
37
+ class OOD_VAEGMM(OODBaseGMM):
37
38
  """
38
39
  VAE with Gaussian Mixture Model based outlier detector.
39
40
 
@@ -53,7 +54,7 @@ class OOD_VAEGMM(OODGMMBase):
53
54
  self,
54
55
  x_ref: ArrayLike,
55
56
  threshold_perc: float = 100.0,
56
- loss_fn: Callable[..., tf.Tensor] = LossGMM(elbo=Elbo(0.05)),
57
+ loss_fn: Callable[..., tf.Tensor] | None = LossGMM(elbo=Elbo(0.05)),
57
58
  optimizer: keras.optimizers.Optimizer | None = None,
58
59
  epochs: int = 20,
59
60
  batch_size: int = 64,
@@ -69,7 +70,7 @@ class OOD_VAEGMM(OODGMMBase):
69
70
  _, z, _ = predict_batch(X_samples, self.model, batch_size=batch_size)
70
71
 
71
72
  # compute average energy for samples
72
- energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
73
+ energy, _ = gmm_energy(z, self._gmm_params, return_mean=False)
73
74
  energy_samples = energy.numpy().reshape((-1, self.samples)) # type: ignore
74
75
  iscore = np.mean(energy_samples, axis=-1)
75
76
  return OODScoreOutput(iscore)
@@ -6,6 +6,7 @@ representation which may impact model performance.
6
6
  from dataeval.metrics.bias.balance import BalanceOutput, balance
7
7
  from dataeval.metrics.bias.coverage import CoverageOutput, coverage
8
8
  from dataeval.metrics.bias.diversity import DiversityOutput, diversity
9
+ from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput, metadata_preprocessing
9
10
  from dataeval.metrics.bias.parity import ParityOutput, label_parity, parity
10
11
 
11
12
  __all__ = [
@@ -14,8 +15,10 @@ __all__ = [
14
15
  "diversity",
15
16
  "label_parity",
16
17
  "parity",
18
+ "metadata_preprocessing",
17
19
  "BalanceOutput",
18
20
  "CoverageOutput",
19
21
  "DiversityOutput",
20
22
  "ParityOutput",
23
+ "MetadataOutput",
21
24
  ]
@@ -5,13 +5,15 @@ __all__ = ["BalanceOutput", "balance"]
5
5
  import contextlib
6
6
  import warnings
7
7
  from dataclasses import dataclass
8
- from typing import Any, Mapping
8
+ from typing import Any
9
9
 
10
10
  import numpy as np
11
- from numpy.typing import ArrayLike, NDArray
11
+ import scipy as sp
12
+ from numpy.typing import NDArray
12
13
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
13
14
 
14
- from dataeval.metrics.bias.metadata import entropy, heatmap, preprocess_metadata
15
+ from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
16
+ from dataeval.metrics.bias.metadata_utils import get_counts, heatmap
15
17
  from dataeval.output import OutputMetadata, set_metadata
16
18
 
17
19
  with contextlib.suppress(ImportError):
@@ -31,17 +33,17 @@ class BalanceOutput(OutputMetadata):
31
33
  Estimate of inter/intra-factor mutual information
32
34
  classwise : NDArray[np.float64]
33
35
  Estimate of mutual information between metadata factors and individual class labels
34
- class_list: NDArray
35
- Array of the class labels present in the dataset
36
- metadata_names: list[str]
36
+ factor_names : list[str]
37
37
  Names of each metadata factor
38
+ class_list : NDArray
39
+ Array of the class labels present in the dataset
38
40
  """
39
41
 
40
42
  balance: NDArray[np.float64]
41
43
  factors: NDArray[np.float64]
42
44
  classwise: NDArray[np.float64]
45
+ factor_names: list[str]
43
46
  class_list: NDArray[Any]
44
- metadata_names: list[str]
45
47
 
46
48
  def plot(
47
49
  self,
@@ -54,9 +56,9 @@ class BalanceOutput(OutputMetadata):
54
56
 
55
57
  Parameters
56
58
  ----------
57
- row_labels : ArrayLike | None, default None
59
+ row_labels : ArrayLike or None, default None
58
60
  List/Array containing the labels for rows in the histogram
59
- col_labels : ArrayLike | None, default None
61
+ col_labels : ArrayLike or None, default None
60
62
  List/Array containing the labels for columns in the histogram
61
63
  plot_classwise : bool, default False
62
64
  Whether to plot per-class balance instead of global balance
@@ -65,7 +67,7 @@ class BalanceOutput(OutputMetadata):
65
67
  if row_labels is None:
66
68
  row_labels = self.class_list
67
69
  if col_labels is None:
68
- col_labels = np.concatenate((["class"], self.metadata_names))
70
+ col_labels = self.factor_names
69
71
 
70
72
  fig = heatmap(
71
73
  self.classwise,
@@ -83,7 +85,7 @@ class BalanceOutput(OutputMetadata):
83
85
  # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
84
86
  heat_data = np.where(mask, np.nan, data)[:-1]
85
87
  # Creating label array for heat map axes
86
- heat_labels = np.concatenate((["class"], self.metadata_names))
88
+ heat_labels = self.factor_names
87
89
 
88
90
  if row_labels is None:
89
91
  row_labels = heat_labels[:-1]
@@ -95,7 +97,7 @@ class BalanceOutput(OutputMetadata):
95
97
  return fig
96
98
 
97
99
 
98
- def validate_num_neighbors(num_neighbors: int) -> int:
100
+ def _validate_num_neighbors(num_neighbors: int) -> int:
99
101
  if not isinstance(num_neighbors, (int, float)):
100
102
  raise TypeError(
101
103
  f"Variable {num_neighbors} is not real-valued numeric type."
@@ -116,19 +118,17 @@ def validate_num_neighbors(num_neighbors: int) -> int:
116
118
 
117
119
 
118
120
  @set_metadata("dataeval.metrics")
119
- def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neighbors: int = 5) -> BalanceOutput:
121
+ def balance(
122
+ metadata: MetadataOutput,
123
+ num_neighbors: int = 5,
124
+ ) -> BalanceOutput:
120
125
  """
121
126
  Mutual information (MI) between factors (class label, metadata, label/image properties)
122
127
 
123
128
  Parameters
124
129
  ----------
125
- class_labels: ArrayLike
126
- List of class labels for each image
127
- metadata: Mapping[str, ArrayLike]
128
- Dict of lists of metadata factors for each image
129
- num_neighbors: int, default 5
130
- Number of nearest neighbors to use for computing MI between discrete
131
- and continuous variables.
130
+ metadata : MetadataOutput
131
+ Output after running `metadata_preprocessing`
132
132
 
133
133
  Returns
134
134
  -------
@@ -140,30 +140,34 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
140
140
  ----
141
141
  We use `mutual_info_classif` from sklearn since class label is categorical.
142
142
  `mutual_info_classif` outputs are consistent up to O(1e-4) and depend on a random
143
- seed. MI is computed differently for categorical and continuous variables, and
144
- we attempt to infer whether a variable is categorical by the fraction of unique
145
- values in the dataset.
143
+ seed. MI is computed differently for categorical and continuous variables.
146
144
 
147
145
  Example
148
146
  -------
149
147
  Return balance (mutual information) of factors with class_labels
150
148
 
151
- >>> bal = balance(class_labels, metadata)
149
+ >>> bal = balance(metadata)
152
150
  >>> bal.balance
153
- array([0.99999822, 0.13363788, 0.04505382, 0.02994455])
151
+ array([0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0. ,
152
+ 0. ])
154
153
 
155
154
  Return intra/interfactor balance (mutual information)
156
155
 
157
156
  >>> bal.factors
158
- array([[0.99999843, 0.04133555, 0.09725766],
159
- [0.04133555, 0.08433558, 0.1301489 ],
160
- [0.09725766, 0.1301489 , 0.99999856]])
157
+ array([[0.99999935, 0.31360499, 0.26925848, 0.85201924, 0.36653548],
158
+ [0.31360499, 0.99999856, 0.09725766, 0.15836905, 1.98031993],
159
+ [0.26925848, 0.09725766, 0.99999846, 0.03713108, 0.01544656],
160
+ [0.85201924, 0.15836905, 0.03713108, 0.47450653, 0.25509664],
161
+ [0.36653548, 1.98031993, 0.01544656, 0.25509664, 1.06260686]])
161
162
 
162
163
  Return classwise balance (mutual information) of factors with individual class_labels
163
164
 
164
165
  >>> bal.classwise
165
- array([[0.99999822, 0.13363788, 0. , 0. ],
166
- [0.99999822, 0.13363788, 0. , 0. ]])
166
+ array([[0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0. ,
167
+ 0. ],
168
+ [0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0. ,
169
+ 0. ]])
170
+
167
171
 
168
172
  See Also
169
173
  --------
@@ -171,69 +175,78 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
171
175
  sklearn.feature_selection.mutual_info_regression
172
176
  sklearn.metrics.mutual_info_score
173
177
  """
174
- num_neighbors = validate_num_neighbors(num_neighbors)
175
- data, names, is_categorical, unique_labels = preprocess_metadata(class_labels, metadata)
176
- num_factors = len(names)
177
- mi = np.empty((num_factors, num_factors))
178
- mi[:] = np.nan
178
+ num_neighbors = _validate_num_neighbors(num_neighbors)
179
+
180
+ num_factors = metadata.total_num_factors
181
+ is_discrete = [True] * (len(metadata.discrete_factor_names) + 1) + [False] * len(metadata.continuous_factor_names)
182
+ mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
183
+ data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
184
+ discretized_data = data
185
+ if metadata.continuous_data is not None:
186
+ data = np.hstack((data, metadata.continuous_data))
187
+ discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
188
+ discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
179
189
 
180
190
  for idx in range(num_factors):
181
- tgt = data[:, idx].astype(int)
182
-
183
- if is_categorical[idx]:
184
- mi[idx, :] = mutual_info_classif(
191
+ if idx >= len(metadata.discrete_factor_names) + 1:
192
+ mi[idx, :] = mutual_info_regression(
185
193
  data,
186
- tgt,
187
- discrete_features=is_categorical, # type: ignore
194
+ data[:, idx],
195
+ discrete_features=is_discrete, # type: ignore
188
196
  n_neighbors=num_neighbors,
189
197
  random_state=0,
190
198
  )
191
199
  else:
192
- mi[idx, :] = mutual_info_regression(
200
+ mi[idx, :] = mutual_info_classif(
193
201
  data,
194
- tgt,
195
- discrete_features=is_categorical, # type: ignore
202
+ data[:, idx],
203
+ discrete_features=is_discrete, # type: ignore
196
204
  n_neighbors=num_neighbors,
197
205
  random_state=0,
198
206
  )
199
207
 
200
- ent_all = entropy(data, names, is_categorical, normalized=False)
201
- norm_factor = 0.5 * np.add.outer(ent_all, ent_all) + 1e-6
208
+ # Normalization via entropy
209
+ bin_cnts = get_counts(discretized_data)
210
+ ent_factor = sp.stats.entropy(bin_cnts, axis=0)
211
+ norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + 1e-6
212
+
202
213
  # in principle MI should be symmetric, but it is not in practice.
203
214
  nmi = 0.5 * (mi + mi.T) / norm_factor
204
215
  balance = nmi[0]
205
216
  factors = nmi[1:, 1:]
206
217
 
207
- # unique class labels
208
- class_idx = names.index("class_label")
209
- u_cls = np.unique(data[:, class_idx])
210
- num_classes = len(u_cls)
211
-
212
218
  # assume class is a factor
213
- classwise_mi = np.empty((num_classes, num_factors))
214
- classwise_mi[:] = np.nan
215
-
216
- # categorical variables, excluding class label
217
- cat_mask = np.concatenate((is_categorical[:class_idx], is_categorical[(class_idx + 1) :]), axis=0).astype(int)
219
+ num_classes = metadata.class_names.size
220
+ classwise_mi = np.full((num_classes, num_factors), np.nan, dtype=np.float32)
218
221
 
219
- tgt_bin = np.stack([data[:, class_idx] == cls for cls in u_cls]).T.astype(int)
220
- ent_tgt_bin = entropy(
221
- tgt_bin, names=[str(idx) for idx in range(num_classes)], is_categorical=[True for idx in range(num_classes)]
222
- )
222
+ # classwise targets
223
+ classes = np.unique(metadata.class_labels)
224
+ tgt_bin = data[:, 0][:, None] == classes
223
225
 
224
226
  # classification MI for discrete/categorical features
225
227
  for idx in range(num_classes):
226
- # tgt = class_data == cls
227
228
  # units: nat
228
229
  classwise_mi[idx, :] = mutual_info_classif(
229
230
  data,
230
231
  tgt_bin[:, idx],
231
- discrete_features=cat_mask, # type: ignore
232
+ discrete_features=is_discrete, # type: ignore
232
233
  n_neighbors=num_neighbors,
233
234
  random_state=0,
234
235
  )
235
236
 
236
- norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_all) + 1e-6
237
+ # Classwise normalization via entropy
238
+ classwise_bin_cnts = get_counts(tgt_bin)
239
+ ent_tgt_bin = sp.stats.entropy(classwise_bin_cnts, axis=0)
240
+ norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_factor) + 1e-6
237
241
  classwise = classwise_mi / norm_factor
238
242
 
239
- return BalanceOutput(balance, factors, classwise, unique_labels, list(metadata.keys()))
243
+ # Grabbing factor names for plotting function
244
+ factor_names = ["class"]
245
+ for name in metadata.discrete_factor_names:
246
+ if name in metadata.continuous_factor_names:
247
+ name = name + "-discrete"
248
+ factor_names.append(name)
249
+ for name in metadata.continuous_factor_names:
250
+ factor_names.append(name + "-continuous")
251
+
252
+ return BalanceOutput(balance, factors, classwise, factor_names, metadata.class_names)
@@ -5,14 +5,14 @@ __all__ = ["CoverageOutput", "coverage"]
5
5
  import contextlib
6
6
  import math
7
7
  from dataclasses import dataclass
8
- from typing import Any, Literal
8
+ from typing import Literal
9
9
 
10
10
  import numpy as np
11
11
  from numpy.typing import ArrayLike, NDArray
12
12
  from scipy.spatial.distance import pdist, squareform
13
13
 
14
14
  from dataeval.interop import to_numpy
15
- from dataeval.metrics.bias.metadata import coverage_plot
15
+ from dataeval.metrics.bias.metadata_utils import coverage_plot
16
16
  from dataeval.output import OutputMetadata, set_metadata
17
17
  from dataeval.utils.shared import flatten
18
18
 
@@ -27,9 +27,9 @@ class CoverageOutput(OutputMetadata):
27
27
 
28
28
  Attributes
29
29
  ----------
30
- indices : NDArray
30
+ indices : NDArray[np.intp]
31
31
  Array of uncovered indices
32
- radii : NDArray
32
+ radii : NDArray[np.float64]
33
33
  Array of critical value radii
34
34
  critical_value : float
35
35
  Radius for :term:`coverage<Coverage>`
@@ -39,11 +39,7 @@ class CoverageOutput(OutputMetadata):
39
39
  radii: NDArray[np.float64]
40
40
  critical_value: float
41
41
 
42
- def plot(
43
- self,
44
- images: NDArray[Any],
45
- top_k: int = 6,
46
- ) -> Figure:
42
+ def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
47
43
  """
48
44
  Plot the top k images together for visualization
49
45
 
@@ -53,6 +49,10 @@ class CoverageOutput(OutputMetadata):
53
49
  Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
54
50
  top_k : int, default 6
55
51
  Number of images to plot (plotting assumes groups of 3)
52
+
53
+ Returns
54
+ -------
55
+ matplotlib.figure.Figure
56
56
  """
57
57
  # Determine which images to plot
58
58
  highest_uncovered_indices = self.indices[:top_k]
@@ -82,12 +82,12 @@ def coverage(
82
82
  embeddings : ArrayLike, shape - (N, P)
83
83
  A dataset in an ArrayLike format.
84
84
  Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
85
- radius_type : Literal["adaptive", "naive"], default "adaptive"
85
+ radius_type : {"adaptive", "naive"}, default "adaptive"
86
86
  The function used to determine radius.
87
- k: int, default 20
87
+ k : int, default 20
88
88
  Number of observations required in order to be covered.
89
89
  [1] suggests that a minimum of 20-50 samples is necessary.
90
- percent: float, default 0.01
90
+ percent : float, default 0.01
91
91
  Percent of observations to be considered uncovered. Only applies to adaptive radius.
92
92
 
93
93
  Returns