dataeval 0.73.1__py3-none-any.whl → 0.74.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. dataeval/__init__.py +3 -9
  2. dataeval/detectors/__init__.py +2 -10
  3. dataeval/detectors/drift/base.py +3 -3
  4. dataeval/detectors/drift/mmd.py +1 -1
  5. dataeval/detectors/drift/torch.py +1 -101
  6. dataeval/detectors/linters/clusterer.py +3 -3
  7. dataeval/detectors/linters/duplicates.py +4 -4
  8. dataeval/detectors/linters/outliers.py +4 -4
  9. dataeval/detectors/ood/__init__.py +9 -9
  10. dataeval/detectors/ood/{ae.py → ae_torch.py} +22 -27
  11. dataeval/detectors/ood/base.py +63 -113
  12. dataeval/detectors/ood/base_torch.py +109 -0
  13. dataeval/detectors/ood/metadata_ks_compare.py +52 -14
  14. dataeval/interop.py +1 -1
  15. dataeval/metrics/bias/__init__.py +3 -0
  16. dataeval/metrics/bias/balance.py +73 -70
  17. dataeval/metrics/bias/coverage.py +4 -4
  18. dataeval/metrics/bias/diversity.py +67 -136
  19. dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  20. dataeval/metrics/bias/metadata_utils.py +229 -0
  21. dataeval/metrics/bias/parity.py +51 -161
  22. dataeval/metrics/estimators/ber.py +3 -3
  23. dataeval/metrics/estimators/divergence.py +3 -3
  24. dataeval/metrics/estimators/uap.py +3 -3
  25. dataeval/metrics/stats/base.py +2 -2
  26. dataeval/metrics/stats/boxratiostats.py +1 -1
  27. dataeval/metrics/stats/datasetstats.py +6 -6
  28. dataeval/metrics/stats/dimensionstats.py +1 -1
  29. dataeval/metrics/stats/hashstats.py +1 -1
  30. dataeval/metrics/stats/labelstats.py +3 -3
  31. dataeval/metrics/stats/pixelstats.py +1 -1
  32. dataeval/metrics/stats/visualstats.py +1 -1
  33. dataeval/output.py +77 -53
  34. dataeval/utils/__init__.py +1 -7
  35. dataeval/utils/gmm.py +26 -0
  36. dataeval/utils/metadata.py +29 -9
  37. dataeval/utils/torch/gmm.py +98 -0
  38. dataeval/utils/torch/models.py +192 -0
  39. dataeval/utils/torch/trainer.py +84 -5
  40. dataeval/utils/torch/utils.py +107 -1
  41. dataeval/workflows/sufficiency.py +4 -4
  42. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -9
  43. dataeval-0.74.1.dist-info/RECORD +65 -0
  44. dataeval/detectors/ood/aegmm.py +0 -66
  45. dataeval/detectors/ood/llr.py +0 -302
  46. dataeval/detectors/ood/vae.py +0 -97
  47. dataeval/detectors/ood/vaegmm.py +0 -75
  48. dataeval/metrics/bias/metadata.py +0 -440
  49. dataeval/utils/lazy.py +0 -26
  50. dataeval/utils/tensorflow/__init__.py +0 -19
  51. dataeval/utils/tensorflow/_internal/gmm.py +0 -123
  52. dataeval/utils/tensorflow/_internal/loss.py +0 -121
  53. dataeval/utils/tensorflow/_internal/models.py +0 -1394
  54. dataeval/utils/tensorflow/_internal/trainer.py +0 -114
  55. dataeval/utils/tensorflow/_internal/utils.py +0 -256
  56. dataeval/utils/tensorflow/loss/__init__.py +0 -11
  57. dataeval-0.73.1.dist-info/RECORD +0 -73
  58. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
  59. {dataeval-0.73.1.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
--- /dev/null
+++ b/dataeval/metrics/bias/metadata_utils.py
@@ -0,0 +1,229 @@
+from __future__ import annotations
+
+__all__ = []
+
+import contextlib
+from typing import Any
+
+import numpy as np
+from numpy.typing import ArrayLike, NDArray
+
+from dataeval.interop import to_numpy
+
+with contextlib.suppress(ImportError):
+    from matplotlib.figure import Figure
+
+
+def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
+    """
+    Returns columnwise unique counts for discrete data.
+
+    Parameters
+    ----------
+    data : NDArray
+        Array containing integer values for metadata factors
+    min_num_bins : int | None, default None
+        Minimum number of bins for bincount, helps force consistency across runs
+
+    Returns
+    -------
+    NDArray[np.int_]
+        Bin counts per column of data.
+    """
+    max_value = data.max() + 1 if min_num_bins is None else min_num_bins
+    cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
+    for idx in range(data.shape[1]):
+        cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
+
+    return cnt_array
+
+
+def heatmap(
+    data: ArrayLike,
+    row_labels: list[str] | ArrayLike,
+    col_labels: list[str] | ArrayLike,
+    xlabel: str = "",
+    ylabel: str = "",
+    cbarlabel: str = "",
+) -> Figure:
+    """
+    Plots a formatted heatmap
+
+    Parameters
+    ----------
+    data : NDArray
+        Array containing numerical values for factors to plot
+    row_labels : ArrayLike
+        List/Array containing the labels for rows in the heatmap
+    col_labels : ArrayLike
+        List/Array containing the labels for columns in the heatmap
+    xlabel : str, default ""
+        X-axis label
+    ylabel : str, default ""
+        Y-axis label
+    cbarlabel : str, default ""
+        Label for the colorbar
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Formatted heatmap
+    """
+    import matplotlib.pyplot as plt
+    from matplotlib.ticker import FuncFormatter
+
+    np_data = to_numpy(data)
+    rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
+    cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+
+    # Plot the heatmap
+    im = ax.imshow(np_data, vmin=0, vmax=1.0)
+
+    # Create colorbar
+    cbar = fig.colorbar(im, shrink=0.5)
+    cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
+    cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
+    cbar.set_label(cbarlabel, loc="center")
+
+    # Show all ticks and label them with the respective list entries.
+    ax.set_xticks(np.arange(np_data.shape[1]), labels=cols)
+    ax.set_yticks(np.arange(np_data.shape[0]), labels=rows)
+
+    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
+    # Rotate the tick labels and set their alignment.
+    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+    # Turn spines off and create white grid.
+    ax.spines[:].set_visible(False)
+
+    ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
+    ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
+    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
+    ax.tick_params(which="minor", bottom=False, left=False)
+
+    if xlabel:
+        ax.set_xlabel(xlabel)
+    if ylabel:
+        ax.set_ylabel(ylabel)
+
+    valfmt = FuncFormatter(format_text)
+
+    # Normalize the threshold to the images color range.
+    threshold = im.norm(1.0) / 2.0
+
+    # Set default alignment to center, but allow it to be
+    # overwritten by textkw.
+    kw = {"horizontalalignment": "center", "verticalalignment": "center"}
+
+    # Loop over the data and create a `Text` for each "pixel".
+    # Change the text's color depending on the data.
+    textcolors = ("white", "black")
+    texts = []
+    for i in range(np_data.shape[0]):
+        for j in range(np_data.shape[1]):
+            kw.update(color=textcolors[int(im.norm(np_data[i, j]) > threshold)])
+            text = im.axes.text(j, i, valfmt(np_data[i, j], None), **kw)  # type: ignore
+            texts.append(text)
+
+    fig.tight_layout()
+    return fig
+
+
+# Function to define how the text is displayed in the heatmap
+def format_text(*args: str) -> str:
+    """
+    Helper function to format text for heatmap()
+
+    Parameters
+    ----------
+    *args : tuple[str, str]
+        Text to be formatted. Second element is ignored, but is a
+        mandatory pass-through argument as per matplotlib.ticker.FuncFormatter
+
+    Returns
+    -------
+    str
+        Formatted text
+    """
+    x = args[0]
+    return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
+
+
+def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
+    """
+    Plots a formatted bar plot
+
+    Parameters
+    ----------
+    labels : NDArray
+        Array containing the labels for each bar
+    bar_heights : NDArray
+        Array containing the values for each bar
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Bar plot figure
+    """
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+
+    ax.bar(labels, bar_heights)
+    ax.set_xlabel("Factors")
+
+    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+    fig.tight_layout()
+    return fig
+
+
+def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
+    """
+    Creates a single plot of all of the provided images
+
+    Parameters
+    ----------
+    images : NDArray
+        Array containing only the desired images to plot
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Plot of all provided images
+    """
+    import matplotlib.pyplot as plt
+
+    num_images = min(num_images, len(images))
+
+    if images.ndim == 4:
+        images = np.moveaxis(images, 1, -1)
+    elif images.ndim == 3:
+        images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
+    else:
+        raise ValueError(
+            f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
+        )
+
+    rows = int(np.ceil(num_images / 3))
+    fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
+
+    if rows == 1:
+        for j in range(3):
+            if j >= len(images):
+                continue
+            axs[j].imshow(images[j])
+            axs[j].axis("off")
+    else:
+        for i in range(rows):
+            for j in range(3):
+                i_j = i * 3 + j
+                if i_j >= len(images):
+                    continue
+                axs[i, j].imshow(images[i_j])
+                axs[i, j].axis("off")
+
+    fig.tight_layout()
+    return fig
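
For orientation, a short doctest-style sketch of the two small helpers added above, get_counts and format_text (illustrative only, not part of the diff; assumes numpy is installed and the module is importable from its new path, dataeval.metrics.bias.metadata_utils):

>>> import numpy as np
>>> from dataeval.metrics.bias.metadata_utils import format_text, get_counts
>>> data = np.array([[0, 1], [1, 1], [1, 0]])  # 3 samples x 2 discrete factors
>>> get_counts(data)  # per-column value counts: column 0 has one 0 and two 1s
array([[1, 1],
       [2, 2]])
>>> format_text(0.25, None)  # heatmap cell text drops the leading zero
'.25'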
--- a/dataeval/metrics/bias/parity.py
+++ b/dataeval/metrics/bias/parity.py
@@ -4,21 +4,22 @@ __all__ = ["ParityOutput", "parity", "label_parity"]
 
 import warnings
 from dataclasses import dataclass
-from typing import Any, Generic, Mapping, TypeVar
+from typing import Any, Generic, TypeVar
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
-from scipy.stats import chi2_contingency, chisquare
+from scipy.stats import chisquare
+from scipy.stats.contingency import chi2_contingency, crosstab
 
-from dataeval.interop import to_numpy
-from dataeval.metrics.bias.metadata import CLASS_LABEL, preprocess_metadata
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.interop import as_numpy, to_numpy
+from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
+from dataeval.output import Output, set_metadata
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
 
 @dataclass(frozen=True)
-class ParityOutput(Generic[TData], OutputMetadata):
+class ParityOutput(Generic[TData], Output):
     """
     Output class for :func:`parity` and :func:`label_parity` :term:`bias<Bias>` metrics
 
@@ -37,97 +38,6 @@ class ParityOutput(Generic[TData], OutputMetadata):
     metadata_names: list[str] | None
 
 
-def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
-    """
-    Digitizes a list of values into a given number of bins.
-
-    Parameters
-    ----------
-    continuous_values : NDArray
-        The values to be digitized.
-    bins : int
-        The number of bins for the discrete values that continuous_values will be digitized into.
-    factor_name : str
-        The name of the factor to be digitized.
-
-    Returns
-    -------
-    NDArray[np.intp]
-        The digitized values
-    """
-
-    if not np.all([np.issubdtype(type(n), np.number) for n in continuous_values]):
-        raise TypeError(
-            f"Encountered a non-numeric value for factor {factor_name}, but the factor"
-            " was specified to be continuous. Ensure all occurrences of this factor are numeric types,"
-            f" or do not specify {factor_name} as a continuous factor."
-        )
-
-    _, bin_edges = np.histogram(continuous_values, bins=bins)
-    bin_edges[-1] = np.inf
-    bin_edges[0] = -np.inf
-    return np.digitize(continuous_values, bin_edges)
-
-
-def format_discretize_factors(
-    data: NDArray[Any],
-    names: list[str],
-    is_categorical: list[bool],
-    continuous_factor_bincounts: Mapping[str, int] | None,
-) -> dict[str, NDArray[Any]]:
-    """
-    Sets up the internal list of metadata factors.
-
-    Parameters
-    ----------
-    data : NDArray
-        The dataset factors, which are per-image attributes including class label and metadata.
-    names : list[str]
-        The class label
-    continuous_factor_bincounts : Mapping[str, int] or None
-        The factors in data_factors that have continuous values and the array of bin counts to
-        discretize values into. All factors are treated as having discrete values unless they
-        are specified as keys in this dictionary. Each element of this array must occur as a key
-        in data_factors.
-
-    Returns
-    -------
-    Dict[str, NDArray]
-        - Intrinsic per-image metadata information with the formatting that input data_factors uses.
-          Each key is a metadata factor, whose value is the discrete per-image factor values.
-    """
-
-    if continuous_factor_bincounts:
-        invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
-        if invalid_keys:
-            raise KeyError(
-                f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
-                "keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
-            )
-
-    warn = []
-    metadata_factors = {}
-    for i, name in enumerate(names):
-        if name == CLASS_LABEL:
-            continue
-        if continuous_factor_bincounts and name in continuous_factor_bincounts:
-            metadata_factors[name] = digitize_factor_bins(data[:, i], continuous_factor_bincounts[name], name)
-        elif not is_categorical[i]:
-            warn.append(name)
-            metadata_factors[name] = data[:, i]
-        else:
-            metadata_factors[name] = data[:, i]
-
-    if warn:
-        warnings.warn(
-            f"The following factors appear to be continuous but did not have the desired number of bins specified: \n\
-            {warn}",
-            UserWarning,
-        )
-
-    return metadata_factors
-
-
 def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
     """
     Normalize the expected label distribution to match the total number of labels in the observed distribution.
@@ -206,7 +116,7 @@ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
         )
 
 
-@set_metadata()
+@set_metadata
 def label_parity(
     expected_labels: ArrayLike,
     observed_labels: ArrayLike,
@@ -294,32 +204,20 @@ def label_parity(
     return ParityOutput(cs, p, None)
 
 
-@set_metadata()
-def parity(
-    class_labels: ArrayLike,
-    metadata: Mapping[str, ArrayLike],
-    continuous_factor_bincounts: Mapping[str, int] | None = None,
-) -> ParityOutput[NDArray[np.float64]]:
+@set_metadata
+def parity(metadata: MetadataOutput) -> ParityOutput[NDArray[np.float64]]:
     """
-    Calculate chi-square statistics to assess the relationship between multiple factors
+    Calculate chi-square statistics to assess the linear relationship between multiple factors
     and class labels.
 
     This function computes the chi-square statistic for each metadata factor to determine if there is
-    a significant relationship between the factor values and class labels. The function handles both categorical
-    and discretized continuous factors.
+    a significant relationship between the factor values and class labels. The chi-square statistic is
+    only valid for linear relationships. If non-linear relationships exist, use `balance`.
 
     Parameters
     ----------
-    class_labels : ArrayLike
-        List of class labels for each image
-    metadata : Mapping[str, ArrayLike]
-        The dataset factors, which are per-image metadata attributes.
-        Each key of dataset_factors is a factor, whose value is the per-image factor values.
-    continuous_factor_bincounts : Mapping[str, int] or None, default None
-        A dictionary specifying the number of bins for discretizing the continuous factors.
-        The keys should correspond to the names of continuous factors in `metadata`,
-        and the values should be the number of bins to use for discretization.
-        If not provided, no discretization is applied.
+    metadata : MetadataOutput
+        Output after running `metadata_preprocessing`
 
     Returns
     -------
@@ -333,74 +231,66 @@ def parity(
     Warning
     If any cell in the contingency matrix has a value between 0 and 5, a warning is issued because this can
     lead to inaccurate chi-square calculations. It is recommended to ensure that each label co-occurs with
-    factor values either 0 times or at least 5 times. Alternatively, continuous-valued factors can be digitized
-    into fewer bins.
+    factor values either 0 times or at least 5 times.
 
     Note
     ----
-    - Each key of the ``continuous_factor_bincounts`` dictionary must occur as a key in data_factors.
     - A high score with a low p-value suggests that a metadata factor is strongly correlated with a class label.
     - The function creates a contingency matrix for each factor, where each entry represents the frequency of a
       specific factor value co-occurring with a particular class label.
    - Rows containing only zeros in the contingency matrix are removed before performing the chi-square test
      to prevent errors in the calculation.
 
+    See Also
+    --------
+    balance
+
     Examples
     --------
     Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
 
     >>> labels = np_random_gen.choice([0, 1, 2], (100))
-    >>> metadata = {
-    ...     "age": np_random_gen.choice([25, 30, 35, 45], (100)),
-    ...     "income": np_random_gen.choice([50000, 65000, 80000], (100)),
-    ...     "gender": np_random_gen.choice(["M", "F"], (100)),
-    ... }
+    >>> metadata_dict = [
+    ...     {
+    ...         "age": list(np_random_gen.choice([25, 30, 35, 45], (100))),
+    ...         "income": list(np_random_gen.choice([50000, 65000, 80000], (100))),
+    ...         "gender": list(np_random_gen.choice(["M", "F"], (100))),
+    ...     }
+    ... ]
    >>> continuous_factor_bincounts = {"age": 4, "income": 3}
-    >>> parity(labels, metadata, continuous_factor_bincounts)
+    >>> metadata = metadata_preprocessing(metadata_dict, labels, continuous_factor_bincounts)
+    >>> parity(metadata)
    ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
    """  # noqa: E501
-    if len(np.shape(class_labels)) > 1:
-        raise ValueError(
-            f"Got class labels with {len(np.shape(class_labels))}-dimensional",
-            f" shape {np.shape(class_labels)}, but expected a 1-dimensional array.",
-        )
-
-    data, names, is_categorical, _ = preprocess_metadata(class_labels, metadata)
-
-    factors = format_discretize_factors(data, names, is_categorical, continuous_factor_bincounts)
-
-    # unique class labels
-    class_idx = names.index(CLASS_LABEL)
-    u_cls = np.unique(data[:, class_idx])
-
-    chi_scores = np.zeros(len(factors))
-    p_values = np.zeros(len(factors))
+    chi_scores = np.zeros(metadata.discrete_data.shape[1])
+    p_values = np.zeros_like(chi_scores)
     not_enough_data = {}
-    for i, (current_factor_name, factor_values) in enumerate(factors.items()):
-        unique_factor_values = np.unique(factor_values)
-        contingency_matrix = np.zeros((len(unique_factor_values), u_cls.size))
+    for i, col_data in enumerate(metadata.discrete_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
         # at a data point with class c.
-
-        # TODO: Vectorize this nested for loop
-        for fi, factor_value in enumerate(unique_factor_values):
-            for label in u_cls:
-                with_both = np.bitwise_and((data[:, class_idx] == label), factor_values == factor_value)
-                contingency_matrix[fi, label] = np.sum(with_both)
-                if 0 < contingency_matrix[fi, label] < 5:
-                    if current_factor_name not in not_enough_data:
-                        not_enough_data[current_factor_name] = {}
-                    if factor_value not in not_enough_data[current_factor_name]:
-                        not_enough_data[current_factor_name][factor_value] = []
-                    not_enough_data[current_factor_name][factor_value].append(
-                        (label, int(contingency_matrix[fi, label]))
-                    )
+        results = crosstab(col_data, metadata.class_labels)
+        contingency_matrix = as_numpy(results.count)  # type: ignore
+
+        # Determines if any frequencies are too low
+        counts = np.nonzero(contingency_matrix < 5)
+        unique_factor_values = np.unique(col_data)
+        current_factor_name = metadata.discrete_factor_names[i]
+        for int_factor, int_class in zip(counts[0], counts[1]):
+            if contingency_matrix[int_factor, int_class] > 0:
+                factor_category = unique_factor_values[int_factor]
+                if current_factor_name not in not_enough_data:
+                    not_enough_data[current_factor_name] = {}
+                if factor_category not in not_enough_data[current_factor_name]:
+                    not_enough_data[current_factor_name][factor_category] = []
+                not_enough_data[current_factor_name][factor_category].append(
+                    (metadata.class_names[int_class], int(contingency_matrix[int_factor, int_class]))
+                )
 
         # This deletes rows containing only zeros,
         # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
        rowsums = np.sum(contingency_matrix, axis=1)
-        rowmask = np.where(rowsums)
+        rowmask = np.nonzero(rowsums)[0]
        contingency_matrix = contingency_matrix[rowmask]
 
        chi2, p, _, _ = chi2_contingency(contingency_matrix)
@@ -428,4 +318,4 @@ def parity(
             UserWarning,
         )
 
-    return ParityOutput(chi_scores, p_values, list(metadata.keys()))
+    return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names)
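
The rewrite above replaces the hand-built contingency matrix and the TODO-flagged nested loop with scipy. A minimal standalone sketch of the two scipy calls that parity() now relies on, using made-up factor and class arrays (not dataeval code):

import numpy as np
from scipy.stats.contingency import chi2_contingency, crosstab

factor = np.array([0, 0, 1, 1, 2, 2, 2, 0])   # one discretized metadata factor
classes = np.array([0, 1, 0, 1, 0, 1, 0, 0])  # class label per sample

# crosstab builds the factor-by-class frequency table that the old code
# assembled by hand; .count is the contingency matrix.
table = crosstab(factor, classes).count

# Drop all-zero rows first, as parity() does, because chi2_contingency
# raises an error on rows that sum to zero.
table = table[np.nonzero(table.sum(axis=1))[0]]

chi2, p, dof, expected = chi2_contingency(table)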
--- a/dataeval/metrics/estimators/ber.py
+++ b/dataeval/metrics/estimators/ber.py
@@ -20,12 +20,12 @@ from scipy.sparse import coo_matrix
 from scipy.stats import mode
 
 from dataeval.interop import as_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.shared import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
 
 
 @dataclass(frozen=True)
-class BEROutput(OutputMetadata):
+class BEROutput(Output):
     """
     Output class for :func:`ber` estimator metric
 
@@ -114,7 +114,7 @@ def knn_lowerbound(value: float, classes: int, k: int) -> float:
     return ((classes - 1) / classes) * (1 - np.sqrt(max(0, 1 - ((classes / (classes - 1)) * value))))
 
 
-@set_metadata()
+@set_metadata
 def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
     """
     An estimator for Multi-class :term:`Bayes error rate<Bayes Error Rate (BER)>` using FR or KNN test statistic basis
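
A change recurring throughout these hunks is @set_metadata() becoming @set_metadata, which implies the decorator now supports bare application. A generic sketch of that dual-form decorator pattern (hypothetical illustration only, not the actual dataeval.output implementation):

from functools import wraps

def set_metadata(fn=None, *, state=None):
    # Hypothetical sketch of a decorator usable both bare (@set_metadata)
    # and parameterized (@set_metadata(state=[...])).
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            # ... attach execution metadata to the output object here ...
            return result
        return wrapper
    # Bare form: Python passes the decorated function as the first argument.
    return decorator(fn) if callable(fn) else decorator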
--- a/dataeval/metrics/estimators/divergence.py
+++ b/dataeval/metrics/estimators/divergence.py
@@ -14,12 +14,12 @@ import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval.interop import as_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 from dataeval.utils.shared import compute_neighbors, get_method, minimum_spanning_tree
 
 
 @dataclass(frozen=True)
-class DivergenceOutput(OutputMetadata):
+class DivergenceOutput(Output):
     """
     Output class for :func:`divergence` estimator metric
 
@@ -78,7 +78,7 @@ def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
     return errors
 
 
-@set_metadata()
+@set_metadata
 def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST"] = "FNN") -> DivergenceOutput:
     """
     Calculates the :term:`divergence` and any errors between the datasets
--- a/dataeval/metrics/estimators/uap.py
+++ b/dataeval/metrics/estimators/uap.py
@@ -14,11 +14,11 @@ from numpy.typing import ArrayLike
 from sklearn.metrics import average_precision_score
 
 from dataeval.interop import as_numpy
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 
 @dataclass(frozen=True)
-class UAPOutput(OutputMetadata):
+class UAPOutput(Output):
     """
     Output class for :func:`uap` estimator metric
 
@@ -31,7 +31,7 @@ class UAPOutput(OutputMetadata):
     uap: float
 
 
-@set_metadata()
+@set_metadata
 def uap(labels: ArrayLike, scores: ArrayLike) -> UAPOutput:
     """
     FR Test Statistic based estimate of the empirical mean precision for
--- a/dataeval/metrics/stats/base.py
+++ b/dataeval/metrics/stats/base.py
@@ -15,7 +15,7 @@ import tqdm
 from numpy.typing import ArrayLike, NDArray
 
 from dataeval.interop import to_numpy_iter
-from dataeval.output import OutputMetadata
+from dataeval.output import Output
 from dataeval.utils.image import normalize_image_shape, rescale
 
 DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
@@ -65,7 +65,7 @@ class SourceIndex(NamedTuple):
 
 
 @dataclass(frozen=True)
-class BaseStatsOutput(OutputMetadata):
+class BaseStatsOutput(Output):
     """
     Attributes
     ----------
--- a/dataeval/metrics/stats/boxratiostats.py
+++ b/dataeval/metrics/stats/boxratiostats.py
@@ -96,7 +96,7 @@ def calculate_ratios(key: str, box_stats: BaseStatsOutput, img_stats: BaseStatsO
     return out_stats
 
 
-@set_metadata()
+@set_metadata
 def boxratiostats(
     boxstats: TStatOutput,
     imgstats: TStatOutput,
--- a/dataeval/metrics/stats/datasetstats.py
+++ b/dataeval/metrics/stats/datasetstats.py
@@ -15,11 +15,11 @@ from dataeval.metrics.stats.dimensionstats import (
 from dataeval.metrics.stats.labelstats import LabelStatsOutput, labelstats
 from dataeval.metrics.stats.pixelstats import PixelStatsOutput, PixelStatsProcessor
 from dataeval.metrics.stats.visualstats import VisualStatsOutput, VisualStatsProcessor
-from dataeval.output import OutputMetadata, set_metadata
+from dataeval.output import Output, set_metadata
 
 
 @dataclass(frozen=True)
-class DatasetStatsOutput(OutputMetadata):
+class DatasetStatsOutput(Output):
     """
     Output class for :func:`datasetstats` stats metric
 
@@ -41,7 +41,7 @@ class DatasetStatsOutput(OutputMetadata):
     visualstats: VisualStatsOutput
     labelstats: LabelStatsOutput | None = None
 
-    def _outputs(self) -> list[OutputMetadata]:
+    def _outputs(self) -> list[Output]:
         return [s for s in (self.dimensionstats, self.pixelstats, self.visualstats, self.labelstats) if s is not None]
 
     def dict(self) -> dict[str, Any]:
@@ -54,7 +54,7 @@ class DatasetStatsOutput(OutputMetadata):
 
 
 @dataclass(frozen=True)
-class ChannelStatsOutput(OutputMetadata):
+class ChannelStatsOutput(Output):
     """
     Output class for :func:`channelstats` stats metric
 
@@ -84,7 +84,7 @@ class ChannelStatsOutput(OutputMetadata):
         raise ValueError("All StatsOutput classes must contain the same number of image sources.")
 
 
-@set_metadata()
+@set_metadata
 def datasetstats(
     images: Iterable[ArrayLike],
     bboxes: Iterable[ArrayLike] | None = None,
@@ -131,7 +131,7 @@ def datasetstats(
     return DatasetStatsOutput(*outputs, labelstats=labelstats(labels) if labels else None)  # type: ignore
 
 
-@set_metadata()
+@set_metadata
 def channelstats(
     images: Iterable[ArrayLike],
     bboxes: Iterable[ArrayLike] | None = None,
--- a/dataeval/metrics/stats/dimensionstats.py
+++ b/dataeval/metrics/stats/dimensionstats.py
@@ -73,7 +73,7 @@ class DimensionStatsProcessor(StatsProcessor[DimensionStatsOutput]):
     }
 
 
-@set_metadata()
+@set_metadata
 def dimensionstats(
     images: Iterable[ArrayLike],
     bboxes: Iterable[ArrayLike] | None = None,