dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +68 -11
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +8 -64
  5. dataeval/detectors/drift/_mmd.py +12 -38
  6. dataeval/detectors/drift/_torch.py +7 -7
  7. dataeval/detectors/drift/_uncertainty.py +6 -5
  8. dataeval/detectors/drift/updates.py +20 -3
  9. dataeval/detectors/linters/__init__.py +3 -2
  10. dataeval/detectors/linters/duplicates.py +14 -46
  11. dataeval/detectors/linters/outliers.py +25 -159
  12. dataeval/detectors/ood/__init__.py +1 -1
  13. dataeval/detectors/ood/ae.py +6 -5
  14. dataeval/detectors/ood/base.py +2 -2
  15. dataeval/detectors/ood/metadata_ood_mi.py +4 -6
  16. dataeval/detectors/ood/mixin.py +3 -4
  17. dataeval/detectors/ood/vae.py +3 -2
  18. dataeval/metadata/__init__.py +2 -1
  19. dataeval/metadata/_distance.py +134 -0
  20. dataeval/metadata/_ood.py +30 -49
  21. dataeval/metadata/_utils.py +44 -0
  22. dataeval/metrics/bias/__init__.py +5 -4
  23. dataeval/metrics/bias/_balance.py +17 -149
  24. dataeval/metrics/bias/_coverage.py +4 -106
  25. dataeval/metrics/bias/_diversity.py +12 -107
  26. dataeval/metrics/bias/_parity.py +7 -71
  27. dataeval/metrics/estimators/__init__.py +5 -4
  28. dataeval/metrics/estimators/_ber.py +2 -20
  29. dataeval/metrics/estimators/_clusterer.py +1 -61
  30. dataeval/metrics/estimators/_divergence.py +2 -19
  31. dataeval/metrics/estimators/_uap.py +2 -16
  32. dataeval/metrics/stats/__init__.py +15 -12
  33. dataeval/metrics/stats/_base.py +41 -128
  34. dataeval/metrics/stats/_boxratiostats.py +13 -13
  35. dataeval/metrics/stats/_dimensionstats.py +17 -58
  36. dataeval/metrics/stats/_hashstats.py +19 -35
  37. dataeval/metrics/stats/_imagestats.py +94 -0
  38. dataeval/metrics/stats/_labelstats.py +42 -121
  39. dataeval/metrics/stats/_pixelstats.py +19 -51
  40. dataeval/metrics/stats/_visualstats.py +19 -51
  41. dataeval/outputs/__init__.py +57 -0
  42. dataeval/outputs/_base.py +182 -0
  43. dataeval/outputs/_bias.py +381 -0
  44. dataeval/outputs/_drift.py +83 -0
  45. dataeval/outputs/_estimators.py +114 -0
  46. dataeval/outputs/_linters.py +186 -0
  47. dataeval/outputs/_metadata.py +54 -0
  48. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  49. dataeval/outputs/_stats.py +393 -0
  50. dataeval/outputs/_utils.py +44 -0
  51. dataeval/outputs/_workflows.py +364 -0
  52. dataeval/typing.py +187 -7
  53. dataeval/utils/_method.py +1 -5
  54. dataeval/utils/_plot.py +2 -2
  55. dataeval/utils/data/__init__.py +5 -1
  56. dataeval/utils/data/_dataset.py +217 -0
  57. dataeval/utils/data/_embeddings.py +12 -14
  58. dataeval/utils/data/_images.py +30 -27
  59. dataeval/utils/data/_metadata.py +28 -11
  60. dataeval/utils/data/_selection.py +25 -22
  61. dataeval/utils/data/_split.py +5 -29
  62. dataeval/utils/data/_targets.py +14 -2
  63. dataeval/utils/data/datasets/_base.py +5 -5
  64. dataeval/utils/data/datasets/_cifar10.py +1 -1
  65. dataeval/utils/data/datasets/_milco.py +1 -1
  66. dataeval/utils/data/datasets/_mnist.py +1 -1
  67. dataeval/utils/data/datasets/_ships.py +1 -1
  68. dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
  69. dataeval/utils/data/datasets/_voc.py +1 -1
  70. dataeval/utils/data/selections/_classfilter.py +4 -5
  71. dataeval/utils/data/selections/_indices.py +2 -2
  72. dataeval/utils/data/selections/_limit.py +2 -2
  73. dataeval/utils/data/selections/_reverse.py +2 -2
  74. dataeval/utils/data/selections/_shuffle.py +2 -2
  75. dataeval/utils/torch/_internal.py +5 -5
  76. dataeval/utils/torch/trainer.py +8 -8
  77. dataeval/workflows/__init__.py +2 -1
  78. dataeval/workflows/sufficiency.py +6 -342
  79. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
  80. dataeval-0.82.1.dist-info/RECORD +105 -0
  81. dataeval/_output.py +0 -137
  82. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  83. dataeval/metrics/stats/_datasetstats.py +0 -198
  84. dataeval-0.81.0.dist-info/RECORD +0 -94
  85. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
  86. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
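
Note: the bulk of this release consolidates result classes into a new dataeval.outputs package (items 41-51) while the metric functions stay where they were. A minimal migration sketch in Python; only MostDeviatedFactorsOutput and OODOutput are confirmed dataeval.outputs re-exports in the hunks below, and the re-export of the bias outputs through dataeval.metrics.bias is taken from the __init__.py hunk further down:

    # 0.81.0: output classes were defined beside their metrics
    # from dataeval.metrics.bias._balance import BalanceOutput

    # 0.82.1: output classes are collected under dataeval.outputs
    from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput

    # The old public path still resolves via re-export from dataeval.outputs._bias
    from dataeval.metrics.bias import BalanceOutput
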
dataeval/metadata/_ood.py CHANGED
@@ -7,51 +7,12 @@ import warnings
  import numpy as np
  from numpy.typing import NDArray

- from dataeval.detectors.ood import OODOutput
+ from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
+ from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput
+ from dataeval.outputs._base import set_metadata
  from dataeval.utils.data import Metadata


- def _validate_keys(keys1: list[str], keys2: list[str]) -> None:
-     """
-     Raises error when two lists are not equivalent including ordering
-
-     Parameters
-     ----------
-     keys1 : list of strings
-         List of strings to compare
-     keys2 : list of strings
-         List of strings to compare
-
-     Raises
-     ------
-     ValueError
-         If lists do not have the same values, value counts, or ordering
-     """
-
-     if keys1 != keys2:
-         raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")
-
-
- def _validate_factors_and_data(factors: list[str], data: NDArray) -> None:
-     """
-     Raises error when the number of factors and number of rows do not match
-
-     Parameters
-     ----------
-     factors : list of strings
-         List of factor names of size N
-     data : NDArray
-         Array of values with shape (M, N)
-
-     Raises
-     ------
-     ValueError
-         If the length of factors does not equal the length of the transposed data
-     """
-     if len(factors) != len(data.T):
-         raise ValueError(f"Factors and data have mismatched lengths. Got {len(factors)} and {len(data.T)}")
-
-
  def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[str], list[NDArray], list[NDArray]]:
      """
      Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
@@ -93,7 +54,7 @@ def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[

      # Validate and attach discrete data
      if metadata_1.discrete_factor_names:
-         _validate_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
+         _compare_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
          _validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)

          factor_names.extend(metadata_1.discrete_factor_names)
@@ -102,7 +63,7 @@ def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[

      # Validate and attach continuous data
      if metadata_1.continuous_factor_names:
-         _validate_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
+         _compare_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
          _validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)

          factor_names.extend(metadata_1.continuous_factor_names)
@@ -159,11 +120,12 @@ def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
      return np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale))  # (S_t, F)


+ @set_metadata
  def most_deviated_factors(
      metadata_1: Metadata,
      metadata_2: Metadata,
      ood: OODOutput,
- ) -> list[tuple[str, float]]:
+ ) -> MostDeviatedFactorsOutput:
      """
      Determines greatest deviation in metadata features per out of distribution sample in metadata_2.

@@ -189,13 +151,30 @@
         and have equivalent factor names and lengths
      2. The flag at index `i` in :attr:`.OODOutput.is_ood` must correspond
         directly to sample `i` of `metadata_2` being out-of-distribution from `metadata_1`
+
+     Examples
+     --------
+
+     >>> from dataeval.detectors.ood import OODOutput
+
+     All samples are out-of-distribution
+
+     >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
+     >>> most_deviated_factors(metadata1, metadata2, is_ood)
+     MostDeviatedFactorsOutput([('time', 2.0), ('time', 2.592), ('time', 3.51)])
+
+     If there are no out-of-distribution samples, an empty output is returned
+
+     >>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
+     >>> most_deviated_factors(metadata1, metadata2, is_ood)
+     MostDeviatedFactorsOutput([])
      """

      ood_mask: NDArray[np.bool] = ood.is_ood

      # No metadata correlated with out of distribution data
      if not any(ood_mask):
-         return []
+         return MostDeviatedFactorsOutput([])

      # Combines reference and test factor names and data if exists and match exactly
      # shape -> (samples, factors)
@@ -204,6 +183,7 @@
          metadata_2=metadata_2,
      )

+     # Stack discrete and continuous factors as separate factors. Must have equal sample counts
      metadata_ref = np.hstack(md_1) if md_1 else np.array([])
      metadata_tst = np.hstack(md_2) if md_2 else np.array([])

@@ -212,7 +192,7 @@
              f"At least 3 reference metadata samples are needed, got {len(metadata_ref)}",
              UserWarning,
          )
-         return []
+         return MostDeviatedFactorsOutput([])

      if len(metadata_tst) != len(ood_mask):
          raise ValueError(
@@ -226,7 +206,7 @@
      deviations = _calc_median_deviations(metadata_ref, metadata_tst)

      # Get most impactful factor deviation of each sample for ood samples only
-     deviation = np.max(deviations, axis=1)[ood_mask]
+     deviation = np.max(deviations, axis=1)[ood_mask].astype(np.float16)

      # Get indices of most impactful factors for ood samples only
      max_factors = np.argmax(deviations, axis=1)[ood_mask]
@@ -235,4 +215,5 @@
      most_ood_factors = np.array(factor_names)[max_factors].tolist()

      # List of tuples matching the factor name with its deviation
-     return [(factor, dev.item()) for factor, dev in zip(most_ood_factors, deviation)]
+
+     return MostDeviatedFactorsOutput([(factor, dev) for factor, dev in zip(most_ood_factors, deviation)])
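
Note: for callers the visible change in this file is the return type. A before/after sketch, assuming metadata_1, metadata_2, and ood are prepared Metadata and OODOutput objects as in the doctest above; whether the empty MostDeviatedFactorsOutput is falsy like the old empty list is not shown in this diff:

    # 0.81.0: a bare list of (factor_name, deviation) tuples
    # pairs: list[tuple[str, float]] = most_deviated_factors(metadata_1, metadata_2, ood)

    # 0.82.1: the same pairs wrapped in an Output class, with deviations
    # stored as np.float16 rather than Python floats
    out = most_deviated_factors(metadata_1, metadata_2, ood)  # MostDeviatedFactorsOutput
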
dataeval/metadata/_utils.py ADDED
@@ -0,0 +1,44 @@
+ __all__ = []
+
+ from numpy.typing import NDArray
+
+
+ def _compare_keys(keys1: list[str], keys2: list[str]) -> None:
+     """
+     Raises error when two lists are not equivalent including ordering
+
+     Parameters
+     ----------
+     keys1 : list of strings
+         List of strings to compare
+     keys2 : list of strings
+         List of strings to compare
+
+     Raises
+     ------
+     ValueError
+         If lists do not have the same values, value counts, or ordering
+     """
+
+     if keys1 != keys2:
+         raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")
+
+
+ def _validate_factors_and_data(factors: list[str], data: NDArray) -> None:
+     """
+     Raises error when the number of factors and number of rows do not match
+
+     Parameters
+     ----------
+     factors : list of strings
+         List of factor names of size N
+     data : NDArray
+         Array of values with shape (M, N)
+
+     Raises
+     ------
+     ValueError
+         If the length of factors does not equal the length of the transposed data
+     """
+     if len(factors) != len(data.T):
+         raise ValueError(f"Factors and data have mismatched lengths. Got {len(factors)} and {len(data.T)}")
dataeval/metrics/bias/__init__.py CHANGED
@@ -16,7 +16,8 @@ __all__ = [
      "parity",
  ]

- from dataeval.metrics.bias._balance import BalanceOutput, balance
- from dataeval.metrics.bias._coverage import CoverageOutput, coverage
- from dataeval.metrics.bias._diversity import DiversityOutput, diversity
- from dataeval.metrics.bias._parity import LabelParityOutput, ParityOutput, label_parity, parity
+ from dataeval.metrics.bias._balance import balance
+ from dataeval.metrics.bias._coverage import coverage
+ from dataeval.metrics.bias._diversity import diversity
+ from dataeval.metrics.bias._parity import label_parity, parity
+ from dataeval.outputs._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
dataeval/metrics/bias/_balance.py CHANGED
@@ -2,150 +2,18 @@ from __future__ import annotations

  __all__ = []

- import contextlib
  import warnings
- from dataclasses import dataclass
- from typing import Any, Literal, overload

  import numpy as np
  import scipy as sp
- from numpy.typing import NDArray
  from sklearn.feature_selection import mutual_info_classif, mutual_info_regression

- from dataeval._output import Output, set_metadata
+ from dataeval.config import get_seed
+ from dataeval.outputs import BalanceOutput
+ from dataeval.outputs._base import set_metadata
  from dataeval.utils._bin import get_counts
- from dataeval.utils._plot import heatmap
  from dataeval.utils.data import Metadata

- with contextlib.suppress(ImportError):
-     from matplotlib.figure import Figure
-
-
- @dataclass(frozen=True)
- class BalanceOutput(Output):
-     """
-     Output class for :func:`.balance` :term:`bias<Bias>` metric.
-
-     Attributes
-     ----------
-     balance : NDArray[np.float64]
-         Estimate of mutual information between metadata factors and class label
-     factors : NDArray[np.float64]
-         Estimate of inter/intra-factor mutual information
-     classwise : NDArray[np.float64]
-         Estimate of mutual information between metadata factors and individual class labels
-     factor_names : list[str]
-         Names of each metadata factor
-     class_names : list[str]
-         List of the class labels present in the dataset
-     """
-
-     balance: NDArray[np.float64]
-     factors: NDArray[np.float64]
-     classwise: NDArray[np.float64]
-     factor_names: list[str]
-     class_names: list[str]
-
-     @overload
-     def _by_factor_type(
-         self,
-         attr: Literal["factor_names"],
-         factor_type: Literal["discrete", "continuous", "both"],
-     ) -> list[str]: ...
-
-     @overload
-     def _by_factor_type(
-         self,
-         attr: Literal["balance", "factors", "classwise"],
-         factor_type: Literal["discrete", "continuous", "both"],
-     ) -> NDArray[np.float64]: ...
-
-     def _by_factor_type(
-         self,
-         attr: Literal["balance", "factors", "classwise", "factor_names"],
-         factor_type: Literal["discrete", "continuous", "both"],
-     ) -> NDArray[np.float64] | list[str]:
-         # if not filtering by factor_type then just return the requested attribute without mask
-         if factor_type == "both":
-             return getattr(self, attr)
-
-         # create the mask for the selected factor_type
-         mask_lambda = (
-             (lambda x: "-continuous" not in x) if factor_type == "discrete" else (lambda x: "-discrete" not in x)
-         )
-
-         # return the masked attribute
-         if attr == "factor_names":
-             return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
-         else:
-             factor_type_mask = [mask_lambda(x) for x in self.factor_names]
-             if attr == "factors":
-                 return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
-             elif attr == "balance":
-                 return self.balance[factor_type_mask]
-             elif attr == "classwise":
-                 return self.classwise[:, factor_type_mask]
-
-     def plot(
-         self,
-         row_labels: list[Any] | NDArray[Any] | None = None,
-         col_labels: list[Any] | NDArray[Any] | None = None,
-         plot_classwise: bool = False,
-         factor_type: Literal["discrete", "continuous", "both"] = "discrete",
-     ) -> Figure:
-         """
-         Plot a heatmap of balance information
-
-         Parameters
-         ----------
-         row_labels : ArrayLike or None, default None
-             List/Array containing the labels for rows in the histogram
-         col_labels : ArrayLike or None, default None
-             List/Array containing the labels for columns in the histogram
-         plot_classwise : bool, default False
-             Whether to plot per-class balance instead of global balance
-         factor_type : "discrete", "continuous", or "both", default "discrete"
-             Whether to plot discretized values, continuous values, or to include both
-         """
-         if plot_classwise:
-             if row_labels is None:
-                 row_labels = self.class_names
-             if col_labels is None:
-                 col_labels = self._by_factor_type("factor_names", factor_type)
-
-             fig = heatmap(
-                 self._by_factor_type("classwise", factor_type),
-                 row_labels,
-                 col_labels,
-                 xlabel="Factors",
-                 ylabel="Class",
-                 cbarlabel="Normalized Mutual Information",
-             )
-         else:
-             # Combine balance and factors results
-             data = np.concatenate(
-                 [
-                     self._by_factor_type("balance", factor_type)[np.newaxis, 1:],
-                     self._by_factor_type("factors", factor_type),
-                 ],
-                 axis=0,
-             )
-             # Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
-             mask = np.triu(data + 1, k=0) < 1
-             # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
-             heat_data = np.where(mask, np.nan, data)[:-1]
-             # Creating label array for heat map axes
-             heat_labels = self._by_factor_type("factor_names", factor_type)
-
-             if row_labels is None:
-                 row_labels = heat_labels[:-1]
-             if col_labels is None:
-                 col_labels = heat_labels[1:]
-
-             fig = heatmap(heat_data, row_labels, col_labels, cbarlabel="Normalized Mutual Information")
-
-         return fig
-

  def _validate_num_neighbors(num_neighbors: int) -> int:
      if not isinstance(num_neighbors, (int, float)):
@@ -200,25 +68,22 @@ def balance(

      >>> bal = balance(metadata)
      >>> bal.balance
-     array([0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
-            0.        ])
+     array([1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ])

      Return intra/interfactor balance (mutual information)

      >>> bal.factors
-     array([[0.99999935, 0.31360499, 0.26925848, 0.85201924, 0.36653548],
-            [0.31360499, 0.99999856, 0.09725766, 0.15836905, 1.98031993],
-            [0.26925848, 0.09725766, 0.99999846, 0.03713108, 0.01544656],
-            [0.85201924, 0.15836905, 0.03713108, 0.47450653, 0.25509664],
-            [0.36653548, 1.98031993, 0.01544656, 0.25509664, 1.06260686]])
+     array([[1.   , 0.314, 0.269, 0.852, 0.367],
+            [0.314, 1.   , 0.097, 0.158, 1.98 ],
+            [0.269, 0.097, 1.   , 0.037, 0.015],
+            [0.852, 0.158, 0.037, 0.475, 0.255],
+            [0.367, 1.98 , 0.015, 0.255, 1.063]])

      Return classwise balance (mutual information) of factors with individual class_labels

      >>> bal.classwise
-     array([[0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
-             0.        ],
-            [0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
-             0.        ]])
+     array([[1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ],
+            [1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ]])


      See Also
@@ -227,6 +92,9 @@
      sklearn.feature_selection.mutual_info_regression
      sklearn.metrics.mutual_info_score
      """
+     if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
+         raise ValueError("No factors found in provided metadata.")
+
      num_neighbors = _validate_num_neighbors(num_neighbors)

      num_factors = metadata.total_num_factors
@@ -246,7 +114,7 @@
                  data[:, idx],
                  discrete_features=is_discrete,  # type: ignore
                  n_neighbors=num_neighbors,
-                 random_state=0,
+                 random_state=get_seed(),
              )
          else:
              mi[idx, :] = mutual_info_classif(
@@ -254,7 +122,7 @@
                  data[:, idx],
                  discrete_features=is_discrete,  # type: ignore
                  n_neighbors=num_neighbors,
-                 random_state=0,
+                 random_state=get_seed(),
              )

      # Normalization via entropy
@@ -283,7 +151,7 @@
                  tgt_bin[:, idx],
                  discrete_features=is_discrete,  # type: ignore
                  n_neighbors=num_neighbors,
-                 random_state=0,
+                 random_state=get_seed(),
              )

      # Classwise normalization via entropy
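
Note: the three random_state changes above mean balance() now honors DataEval's global seed instead of a hard-coded 0. A reproducibility sketch; set_seed is an assumed counterpart to get_seed (this diff shows only the get_seed import, with config.py growing by 68 lines):

    from dataeval.config import set_seed  # assumed setter paired with get_seed(); not shown in this diff
    from dataeval.metrics.bias import balance

    set_seed(0)              # pin the global seed to reproduce the old fixed-seed behavior
    bal = balance(metadata)  # metadata: a prepared dataeval.utils.data.Metadata object
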
dataeval/metrics/bias/_coverage.py CHANGED
@@ -2,118 +2,16 @@ from __future__ import annotations

  __all__ = []

- import contextlib
  import math
- from dataclasses import dataclass
- from typing import Any, Literal
+ from typing import Literal

  import numpy as np
- from numpy.typing import NDArray
  from scipy.spatial.distance import pdist, squareform

- from dataeval._output import Output, set_metadata
+ from dataeval.outputs import CoverageOutput
+ from dataeval.outputs._base import set_metadata
  from dataeval.typing import ArrayLike
- from dataeval.utils._array import ensure_embeddings, flatten, to_numpy
-
- with contextlib.suppress(ImportError):
-     from matplotlib.figure import Figure
-
-
- def _plot(images: NDArray[Any], num_images: int) -> Figure:
-     """
-     Creates a single plot of all of the provided images
-
-     Parameters
-     ----------
-     images : NDArray
-         Array containing only the desired images to plot
-
-     Returns
-     -------
-     matplotlib.figure.Figure
-         Plot of all provided images
-     """
-     import matplotlib.pyplot as plt
-
-     num_images = min(num_images, len(images))
-
-     if images.ndim == 4:
-         images = np.moveaxis(images, 1, -1)
-     elif images.ndim == 3:
-         images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
-     else:
-         raise ValueError(
-             f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
-         )
-
-     rows = int(np.ceil(num_images / 3))
-     fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
-
-     if rows == 1:
-         for j in range(3):
-             if j >= len(images):
-                 continue
-             axs[j].imshow(images[j])
-             axs[j].axis("off")
-     else:
-         for i in range(rows):
-             for j in range(3):
-                 i_j = i * 3 + j
-                 if i_j >= len(images):
-                     continue
-                 axs[i, j].imshow(images[i_j])
-                 axs[i, j].axis("off")
-
-     fig.tight_layout()
-     return fig
-
-
- @dataclass(frozen=True)
- class CoverageOutput(Output):
-     """
-     Output class for :func:`.coverage` :term:`bias<Bias>` metric.
-
-     Attributes
-     ----------
-     uncovered_indices : NDArray[np.intp]
-         Array of uncovered indices
-     critical_value_radii : NDArray[np.float64]
-         Array of critical value radii
-     coverage_radius : float
-         Radius for :term:`coverage<Coverage>`
-     """
-
-     uncovered_indices: NDArray[np.intp]
-     critical_value_radii: NDArray[np.float64]
-     coverage_radius: float
-
-     def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
-         """
-         Plot the top k images together for visualization
-
-         Parameters
-         ----------
-         images : ArrayLike
-             Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
-         top_k : int, default 6
-             Number of images to plot (plotting assumes groups of 3)
-
-         Returns
-         -------
-         matplotlib.figure.Figure
-         """
-
-         # Determine which images to plot
-         highest_uncovered_indices = self.uncovered_indices[:top_k]
-
-         # Grab the images
-         images = to_numpy(images)
-         selected_images = images[highest_uncovered_indices]
-
-         # Plot the images
-         fig = _plot(selected_images, top_k)
-
-         return fig
+ from dataeval.utils._array import ensure_embeddings, flatten


  @set_metadata
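
Note: the plotting code removed here is relocated, not deleted; per item 43 of the file list, CoverageOutput and its plot helper now live in dataeval/outputs/_bias.py. A usage sketch, assuming the method moved with the signature shown in the removed lines:

    from dataeval.metrics.bias import coverage  # CoverageOutput is re-exported from dataeval.outputs._bias

    result = coverage(embeddings)       # embeddings: placeholder ArrayLike of image embeddings
    fig = result.plot(images, top_k=6)  # images: original (N, C, H, W) or (N, H, W) arrays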