dataeval 0.73.0__py3-none-any.whl → 0.74.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. dataeval/__init__.py +3 -3
  2. dataeval/detectors/__init__.py +1 -1
  3. dataeval/detectors/drift/__init__.py +1 -1
  4. dataeval/detectors/drift/base.py +2 -2
  5. dataeval/detectors/drift/torch.py +1 -101
  6. dataeval/detectors/linters/clusterer.py +1 -1
  7. dataeval/detectors/ood/__init__.py +11 -4
  8. dataeval/detectors/ood/ae.py +2 -1
  9. dataeval/detectors/ood/ae_torch.py +70 -0
  10. dataeval/detectors/ood/aegmm.py +4 -3
  11. dataeval/detectors/ood/base.py +58 -108
  12. dataeval/detectors/ood/base_tf.py +109 -0
  13. dataeval/detectors/ood/base_torch.py +109 -0
  14. dataeval/detectors/ood/llr.py +2 -2
  15. dataeval/detectors/ood/metadata_ks_compare.py +53 -14
  16. dataeval/detectors/ood/vae.py +3 -2
  17. dataeval/detectors/ood/vaegmm.py +5 -4
  18. dataeval/metrics/bias/__init__.py +3 -0
  19. dataeval/metrics/bias/balance.py +77 -64
  20. dataeval/metrics/bias/coverage.py +12 -12
  21. dataeval/metrics/bias/diversity.py +74 -114
  22. dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  23. dataeval/metrics/bias/metadata_utils.py +229 -0
  24. dataeval/metrics/bias/parity.py +54 -158
  25. dataeval/utils/__init__.py +2 -2
  26. dataeval/utils/gmm.py +26 -0
  27. dataeval/utils/metadata.py +29 -9
  28. dataeval/utils/shared.py +1 -1
  29. dataeval/utils/split_dataset.py +12 -6
  30. dataeval/utils/tensorflow/_internal/gmm.py +4 -24
  31. dataeval/utils/torch/datasets.py +2 -2
  32. dataeval/utils/torch/gmm.py +98 -0
  33. dataeval/utils/torch/models.py +192 -0
  34. dataeval/utils/torch/trainer.py +84 -5
  35. dataeval/utils/torch/utils.py +107 -1
  36. dataeval/workflows/__init__.py +1 -1
  37. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/METADATA +1 -2
  38. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/RECORD +40 -34
  39. dataeval/metrics/bias/metadata.py +0 -358
  40. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/LICENSE.txt +0 -0
  41. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/WHEEL +0 -0
dataeval/metrics/bias/metadata_utils.py ADDED
@@ -0,0 +1,229 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import contextlib
+ from typing import Any
+
+ import numpy as np
+ from numpy.typing import ArrayLike, NDArray
+
+ from dataeval.interop import to_numpy
+
+ with contextlib.suppress(ImportError):
+     from matplotlib.figure import Figure
+
+
+ def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
+     """
+     Returns columnwise unique counts for discrete data.
+
+     Parameters
+     ----------
+     data : NDArray
+         Array containing integer values for metadata factors
+     min_num_bins : int | None, default None
+         Minimum number of bins for bincount, helps force consistency across runs
+
+     Returns
+     -------
+     NDArray[np.int_]
+         Bin counts per column of data.
+     """
+     max_value = data.max() + 1 if min_num_bins is None else min_num_bins
+     cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
+     for idx in range(data.shape[1]):
+         cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
+
+     return cnt_array
+
+
+ def heatmap(
+     data: ArrayLike,
+     row_labels: list[str] | ArrayLike,
+     col_labels: list[str] | ArrayLike,
+     xlabel: str = "",
+     ylabel: str = "",
+     cbarlabel: str = "",
+ ) -> Figure:
+     """
+     Plots a formatted heatmap
+
+     Parameters
+     ----------
+     data : NDArray
+         Array containing numerical values for factors to plot
+     row_labels : ArrayLike
+         List/Array containing the labels for rows in the histogram
+     col_labels : ArrayLike
+         List/Array containing the labels for columns in the histogram
+     xlabel : str, default ""
+         X-axis label
+     ylabel : str, default ""
+         Y-axis label
+     cbarlabel : str, default ""
+         Label for the colorbar
+
+     Returns
+     -------
+     matplotlib.figure.Figure
+         Formatted heatmap
+     """
+     import matplotlib.pyplot as plt
+     from matplotlib.ticker import FuncFormatter
+
+     np_data = to_numpy(data)
+     rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
+     cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
+
+     fig, ax = plt.subplots(figsize=(10, 10))
+
+     # Plot the heatmap
+     im = ax.imshow(np_data, vmin=0, vmax=1.0)
+
+     # Create colorbar
+     cbar = fig.colorbar(im, shrink=0.5)
+     cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
+     cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
+     cbar.set_label(cbarlabel, loc="center")
+
+     # Show all ticks and label them with the respective list entries.
+     ax.set_xticks(np.arange(np_data.shape[1]), labels=cols)
+     ax.set_yticks(np.arange(np_data.shape[0]), labels=rows)
+
+     ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
+     # Rotate the tick labels and set their alignment.
+     plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+     # Turn spines off and create white grid.
+     ax.spines[:].set_visible(False)
+
+     ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
+     ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
+     ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
+     ax.tick_params(which="minor", bottom=False, left=False)
+
+     if xlabel:
+         ax.set_xlabel(xlabel)
+     if ylabel:
+         ax.set_ylabel(ylabel)
+
+     valfmt = FuncFormatter(format_text)
+
+     # Normalize the threshold to the images color range.
+     threshold = im.norm(1.0) / 2.0
+
+     # Set default alignment to center, but allow it to be
+     # overwritten by textkw.
+     kw = {"horizontalalignment": "center", "verticalalignment": "center"}
+
+     # Loop over the data and create a `Text` for each "pixel".
+     # Change the text's color depending on the data.
+     textcolors = ("white", "black")
+     texts = []
+     for i in range(np_data.shape[0]):
+         for j in range(np_data.shape[1]):
+             kw.update(color=textcolors[int(im.norm(np_data[i, j]) > threshold)])
+             text = im.axes.text(j, i, valfmt(np_data[i, j], None), **kw) # type: ignore
+             texts.append(text)
+
+     fig.tight_layout()
+     return fig
+
+
+ # Function to define how the text is displayed in the heatmap
+ def format_text(*args: str) -> str:
+     """
+     Helper function to format text for heatmap()
+
+     Parameters
+     ----------
+     *args : tuple[str, str]
+         Text to be formatted. Second element is ignored, but is a
+         mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
+
+     Returns
+     -------
+     str
+         Formatted text
+     """
+     x = args[0]
+     return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
+
+
+ def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
+     """
+     Plots a formatted bar plot
+
+     Parameters
+     ----------
+     labels : NDArray
+         Array containing the labels for each bar
+     bar_heights : NDArray
+         Array containing the values for each bar
+
+     Returns
+     -------
+     matplotlib.figure.Figure
+         Bar plot figure
+     """
+     import matplotlib.pyplot as plt
+
+     fig, ax = plt.subplots(figsize=(10, 10))
+
+     ax.bar(labels, bar_heights)
+     ax.set_xlabel("Factors")
+
+     plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+     fig.tight_layout()
+     return fig
+
+
+ def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
+     """
+     Creates a single plot of all of the provided images
+
+     Parameters
+     ----------
+     images : NDArray
+         Array containing only the desired images to plot
+
+     Returns
+     -------
+     matplotlib.figure.Figure
+         Plot of all provided images
+     """
+     import matplotlib.pyplot as plt
+
+     num_images = min(num_images, len(images))
+
+     if images.ndim == 4:
+         images = np.moveaxis(images, 1, -1)
+     elif images.ndim == 3:
+         images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
+     else:
+         raise ValueError(
+             f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
+         )
+
+     rows = int(np.ceil(num_images / 3))
+     fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
+
+     if rows == 1:
+         for j in range(3):
+             if j >= len(images):
+                 continue
+             axs[j].imshow(images[j])
+             axs[j].axis("off")
+     else:
+         for i in range(rows):
+             for j in range(3):
+                 i_j = i * 3 + j
+                 if i_j >= len(images):
+                     continue
+                 axs[i, j].imshow(images[i_j])
+                 axs[i, j].axis("off")
+
+     fig.tight_layout()
+     return fig
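
Note: these plotting helpers are internal (the module sets __all__ = []) and consolidate figure code that previously lived in dataeval/metrics/bias/metadata.py. A minimal usage sketch, assuming the module is importable as dataeval.metrics.bias.metadata_utils and matplotlib is installed; the data is synthetic and only exercises the signatures shown above:

    import numpy as np

    from dataeval.metrics.bias import metadata_utils as mu  # import path assumed from the file list

    rng = np.random.default_rng(0)

    # heatmap expects values in [0, 1] plus row/column labels of matching lengths
    scores = rng.uniform(0, 1, size=(3, 4))
    fig1 = mu.heatmap(scores, ["a", "b", "c"], ["w", "x", "y", "z"], cbarlabel="score")

    # diversity_bar_plot takes parallel arrays of bar labels and bar heights
    fig2 = mu.diversity_bar_plot(np.array(["age", "income", "gender"]), rng.uniform(0, 1, 3))

    # coverage_plot accepts an (N, C, H, W) or (N, H, W) image stack
    images = rng.uniform(0, 1, size=(6, 3, 16, 16))
    fig3 = mu.coverage_plot(images, num_images=6)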
dataeval/metrics/bias/parity.py CHANGED
@@ -4,14 +4,15 @@ __all__ = ["ParityOutput", "parity", "label_parity"]

  import warnings
  from dataclasses import dataclass
- from typing import Any, Generic, Mapping, TypeVar
+ from typing import Any, Generic, TypeVar

  import numpy as np
  from numpy.typing import ArrayLike, NDArray
- from scipy.stats import chi2_contingency, chisquare
+ from scipy.stats import chisquare
+ from scipy.stats.contingency import chi2_contingency, crosstab

- from dataeval.interop import to_numpy
- from dataeval.metrics.bias.metadata import CLASS_LABEL, preprocess_metadata
+ from dataeval.interop import as_numpy, to_numpy
+ from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
  from dataeval.output import OutputMetadata, set_metadata

  TData = TypeVar("TData", np.float64, NDArray[np.float64])
@@ -28,7 +29,7 @@ class ParityOutput(Generic[TData], OutputMetadata):
          chi-squared score(s) of the test
      p_value : np.float64 | NDArray[np.float64]
          p-value(s) of the test
-     metadata_names: list[str] | None
+     metadata_names : list[str] | None
          Names of each metadata factor
      """

@@ -37,92 +38,6 @@ class ParityOutput(Generic[TData], OutputMetadata):
      metadata_names: list[str] | None


- def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
-     """
-     Digitizes a list of values into a given number of bins.
-
-     Parameters
-     ----------
-     continuous_values: NDArray
-         The values to be digitized.
-     bins: int
-         The number of bins for the discrete values that continuous_values will be digitized into.
-     factor_name: str
-         The name of the factor to be digitized.
-
-     Returns
-     -------
-     NDArray
-         The digitized values
-     """
-
-     if not np.all([np.issubdtype(type(n), np.number) for n in continuous_values]):
-         raise TypeError(
-             f"Encountered a non-numeric value for factor {factor_name}, but the factor"
-             " was specified to be continuous. Ensure all occurrences of this factor are numeric types,"
-             f" or do not specify {factor_name} as a continuous factor."
-         )
-
-     _, bin_edges = np.histogram(continuous_values, bins=bins)
-     bin_edges[-1] = np.inf
-     bin_edges[0] = -np.inf
-     return np.digitize(continuous_values, bin_edges)
-
-
- def format_discretize_factors(
-     data: NDArray[Any], names: list[str], is_categorical: list[bool], continuous_factor_bincounts: Mapping[str, int]
- ) -> dict[str, NDArray[Any]]:
-     """
-     Sets up the internal list of metadata factors.
-
-     Parameters
-     ----------
-     data_factors: Dict[str, NDArray]
-         The dataset factors, which are per-image attributes including class label and metadata.
-         Each key of dataset_factors is a factor, whose value is the per-image factor values.
-     continuous_factor_bincounts : Dict[str, int]
-         The factors in data_factors that have continuous values and the array of bin counts to
-         discretize values into. All factors are treated as having discrete values unless they
-         are specified as keys in this dictionary. Each element of this array must occur as a key
-         in data_factors.
-
-     Returns
-     -------
-     Dict[str, NDArray]
-         - Intrinsic per-image metadata information with the formatting that input data_factors uses.
-           Each key is a metadata factor, whose value is the discrete per-image factor values.
-     """
-
-     invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
-     if invalid_keys:
-         raise KeyError(
-             f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
-             "keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
-         )
-
-     warn = []
-     metadata_factors = {}
-     for i, name in enumerate(names):
-         if name == CLASS_LABEL:
-             continue
-         if name in continuous_factor_bincounts:
-             metadata_factors[name] = digitize_factor_bins(data[:, i], continuous_factor_bincounts[name], name)
-         elif not is_categorical[i]:
-             warn.append(name)
-             metadata_factors[name] = data[:, i]
-         else:
-             metadata_factors[name] = data[:, i]
-
-     if warn:
-         warnings.warn(
-             f"The following factors appear to be continuous but did not have the desired number of bins specified: \n\
-     {warn}",
-             UserWarning,
-         )
-
-     return metadata_factors
-
-
  def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
@@ -132,14 +47,14 @@ def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[

      Parameters
      ----------
-     expected_dist : np.ndarray
+     expected_dist : NDArray
          The expected label distribution. This array represents the anticipated distribution of labels.
-     observed_dist : np.ndarray
+     observed_dist : NDArray
          The observed label distribution. This array represents the actual distribution of labels in the dataset.

      Returns
      -------
-     np.ndarray
+     NDArray
          The normalized expected distribution, scaled to have the same sum as the observed distribution.

      Raises
@@ -179,6 +94,8 @@ def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
      ----------
      label_dist : NDArray
          Array representing label distributions
+     label_name : str
+         String representing label name

      Raises
      ------
@@ -219,7 +136,7 @@ def label_parity(
          List of class labels in the expected dataset
      observed_labels : ArrayLike
          List of class labels in the observed dataset
-     num_classes : int | None, default None
+     num_classes : int or None, default None
          The number of unique classes in the datasets. If not provided, the function will infer it
          from the set of unique labels in expected_labels and observed_labels

@@ -288,31 +205,19 @@ def label_parity(


  @set_metadata()
- def parity(
-     class_labels: ArrayLike,
-     metadata: Mapping[str, ArrayLike],
-     continuous_factor_bincounts: Mapping[str, int] | None = None,
- ) -> ParityOutput[NDArray[np.float64]]:
+ def parity(metadata: MetadataOutput) -> ParityOutput[NDArray[np.float64]]:
      """
-     Calculate chi-square statistics to assess the relationship between multiple factors
+     Calculate chi-square statistics to assess the linear relationship between multiple factors
      and class labels.

      This function computes the chi-square statistic for each metadata factor to determine if there is
-     a significant relationship between the factor values and class labels. The function handles both categorical
-     and discretized continuous factors.
+     a significant relationship between the factor values and class labels. The chi-square statistic is
+     only valid for linear relationships. If non-linear relationships exist, use `balance`.

      Parameters
      ----------
-     class_labels: ArrayLike
-         List of class labels for each image
-     metadata: Mapping[str, ArrayLike]
-         The dataset factors, which are per-image metadata attributes.
-         Each key of dataset_factors is a factor, whose value is the per-image factor values.
-     continuous_factor_bincounts : Mapping[str, int] | None, default None
-         A dictionary specifying the number of bins for discretizing the continuous factors.
-         The keys should correspond to the names of continuous factors in `metadata`,
-         and the values should be the number of bins to use for discretization.
-         If not provided, no discretization is applied.
+     metadata : MetadataOutput
+         Output after running `metadata_preprocessing`

      Returns
      -------
@@ -326,75 +231,66 @@ def parity(
      Warning
          If any cell in the contingency matrix has a value between 0 and 5, a warning is issued because this can
          lead to inaccurate chi-square calculations. It is recommended to ensure that each label co-occurs with
-         factor values either 0 times or at least 5 times. Alternatively, continuous-valued factors can be digitized
-         into fewer bins.
+         factor values either 0 times or at least 5 times.

      Note
      ----
-     - Each key of the ``continuous_factor_bincounts`` dictionary must occur as a key in data_factors.
      - A high score with a low p-value suggests that a metadata factor is strongly correlated with a class label.
      - The function creates a contingency matrix for each factor, where each entry represents the frequency of a
        specific factor value co-occurring with a particular class label.
      - Rows containing only zeros in the contingency matrix are removed before performing the chi-square test
        to prevent errors in the calculation.

+     See Also
+     --------
+     balance
+
      Examples
      --------
      Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``

      >>> labels = np_random_gen.choice([0, 1, 2], (100))
-     >>> metadata = {
-     ...     "age": np_random_gen.choice([25, 30, 35, 45], (100)),
-     ...     "income": np_random_gen.choice([50000, 65000, 80000], (100)),
-     ...     "gender": np_random_gen.choice(["M", "F"], (100)),
-     ... }
+     >>> metadata_dict = [
+     ...     {
+     ...         "age": list(np_random_gen.choice([25, 30, 35, 45], (100))),
+     ...         "income": list(np_random_gen.choice([50000, 65000, 80000], (100))),
+     ...         "gender": list(np_random_gen.choice(["M", "F"], (100))),
+     ...     }
+     ... ]
      >>> continuous_factor_bincounts = {"age": 4, "income": 3}
-     >>> parity(labels, metadata, continuous_factor_bincounts)
+     >>> metadata = metadata_preprocessing(metadata_dict, labels, continuous_factor_bincounts)
+     >>> parity(metadata)
      ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
      """ # noqa: E501
-     if len(np.shape(class_labels)) > 1:
-         raise ValueError(
-             f"Got class labels with {len(np.shape(class_labels))}-dimensional",
-             f" shape {np.shape(class_labels)}, but expected a 1-dimensional array.",
-         )
-
-     data, names, is_categorical, _ = preprocess_metadata(class_labels, metadata)
-     continuous_factor_bincounts = continuous_factor_bincounts if continuous_factor_bincounts else {}
-
-     factors = format_discretize_factors(data, names, is_categorical, continuous_factor_bincounts)
-
-     # unique class labels
-     class_idx = names.index(CLASS_LABEL)
-     u_cls = np.unique(data[:, class_idx])
-
-     chi_scores = np.zeros(len(factors))
-     p_values = np.zeros(len(factors))
+     chi_scores = np.zeros(metadata.discrete_data.shape[1])
+     p_values = np.zeros_like(chi_scores)
      not_enough_data = {}
-     for i, (current_factor_name, factor_values) in enumerate(factors.items()):
-         unique_factor_values = np.unique(factor_values)
-         contingency_matrix = np.zeros((len(unique_factor_values), u_cls.size))
+     for i, col_data in enumerate(metadata.discrete_data.T):
          # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
         # at a data point with class c.
-
-         # TODO: Vectorize this nested for loop
-         for fi, factor_value in enumerate(unique_factor_values):
-             for label in u_cls:
-                 with_both = np.bitwise_and((data[:, class_idx] == label), factor_values == factor_value)
-                 contingency_matrix[fi, label] = np.sum(with_both)
-                 if 0 < contingency_matrix[fi, label] < 5:
-                     if current_factor_name not in not_enough_data:
-                         not_enough_data[current_factor_name] = {}
-                     if factor_value not in not_enough_data[current_factor_name]:
-                         not_enough_data[current_factor_name][factor_value] = []
-                     not_enough_data[current_factor_name][factor_value].append(
-                         (label, int(contingency_matrix[fi, label]))
-                     )
+         results = crosstab(col_data, metadata.class_labels)
+         contingency_matrix = as_numpy(results.count) # type: ignore
+
+         # Determines if any frequencies are too low
+         counts = np.nonzero(contingency_matrix < 5)
+         unique_factor_values = np.unique(col_data)
+         current_factor_name = metadata.discrete_factor_names[i]
+         for int_factor, int_class in zip(counts[0], counts[1]):
+             if contingency_matrix[int_factor, int_class] > 0:
+                 factor_category = unique_factor_values[int_factor]
+                 if current_factor_name not in not_enough_data:
+                     not_enough_data[current_factor_name] = {}
+                 if factor_category not in not_enough_data[current_factor_name]:
+                     not_enough_data[current_factor_name][factor_category] = []
+                 not_enough_data[current_factor_name][factor_category].append(
+                     (metadata.class_names[int_class], int(contingency_matrix[int_factor, int_class]))
+                 )

          # This deletes rows containing only zeros,
          # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
          rowsums = np.sum(contingency_matrix, axis=1)
-         rowmask = np.where(rowsums)
+         rowmask = np.nonzero(rowsums)[0]
          contingency_matrix = contingency_matrix[rowmask]

          chi2, p, _, _ = chi2_contingency(contingency_matrix)
@@ -422,4 +318,4 @@ def parity(
              UserWarning,
          )

-     return ParityOutput(chi_scores, p_values, list(metadata.keys()))
+     return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names)
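
The net effect of this change is that parity no longer accepts raw labels and factor dictionaries; it consumes the MetadataOutput produced by the new metadata_preprocessing step, and the contingency tables are built with scipy.stats.contingency.crosstab instead of a nested Python loop. A sketch of the new call pattern, mirroring the docstring example above; the public import locations are assumed, not shown in this diff:

    import numpy as np

    from dataeval.metrics.bias import metadata_preprocessing, parity  # assumed exports

    rng = np.random.default_rng(175)
    labels = rng.choice([0, 1, 2], (100))
    metadata_dict = [
        {
            "age": list(rng.choice([25, 30, 35, 45], (100))),
            "income": list(rng.choice([50000, 65000, 80000], (100))),
            "gender": list(rng.choice(["M", "F"], (100))),
        }
    ]
    continuous_factor_bincounts = {"age": 4, "income": 3}

    # Discretize and reorganize the metadata once, then feed the result to the bias metrics.
    metadata = metadata_preprocessing(metadata_dict, labels, continuous_factor_bincounts)
    result = parity(metadata)  # ParityOutput with per-factor chi-square scores and p-values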
dataeval/utils/__init__.py CHANGED
@@ -10,12 +10,12 @@ from dataeval.utils.split_dataset import split_dataset

  __all__ = ["split_dataset", "merge_metadata"]

- if _IS_TORCH_AVAILABLE:  # pragma: no cover
+ if _IS_TORCH_AVAILABLE:
      from dataeval.utils import torch

      __all__ += ["torch"]

- if _IS_TENSORFLOW_AVAILABLE:  # pragma: no cover
+ if _IS_TENSORFLOW_AVAILABLE:
      from dataeval.utils import tensorflow

      __all__ += ["tensorflow"]
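
Dropping the # pragma: no cover markers leaves the behavior unchanged: the torch and tensorflow helper submodules are only re-exported when the corresponding framework is installed. A small sketch of how a caller can check for them, based only on the __all__ handling shown above:

    import dataeval.utils

    # The optional submodules appear in __all__ only when torch / tensorflow import successfully.
    if "torch" in dataeval.utils.__all__:
        from dataeval.utils import torch as torch_utils  # noqa: F401
    else:
        print("dataeval was installed without the torch extra")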
dataeval/utils/gmm.py ADDED
@@ -0,0 +1,26 @@
+ from dataclasses import dataclass
+ from typing import Generic, TypeVar
+
+ TGMMData = TypeVar("TGMMData")
+
+
+ @dataclass
+ class GaussianMixtureModelParams(Generic[TGMMData]):
+     """
+     phi : TGMMData
+         Mixture component distribution weights.
+     mu : TGMMData
+         Mixture means.
+     cov : TGMMData
+         Mixture covariance.
+     L : TGMMData
+         Cholesky decomposition of `cov`.
+     log_det_cov : TGMMData
+         Log of the determinant of `cov`.
+     """
+
+     phi: TGMMData
+     mu: TGMMData
+     cov: TGMMData
+     L: TGMMData
+     log_det_cov: TGMMData
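
GaussianMixtureModelParams is a framework-agnostic container shared by the TensorFlow GMM code and the new PyTorch port (dataeval/utils/tensorflow/_internal/gmm.py and dataeval/utils/torch/gmm.py in the file list). A sketch of what an instance holds, using NumPy arrays purely for illustration since TGMMData is an unconstrained TypeVar:

    import numpy as np

    from dataeval.utils.gmm import GaussianMixtureModelParams

    n_components, n_features = 3, 4
    cov = np.stack([np.eye(n_features)] * n_components)  # one covariance matrix per component

    params = GaussianMixtureModelParams(
        phi=np.full(n_components, 1 / n_components),  # mixture weights summing to 1
        mu=np.zeros((n_components, n_features)),      # component means
        cov=cov,
        L=np.linalg.cholesky(cov),                    # Cholesky factor of each covariance
        log_det_cov=np.zeros(n_components),           # log|cov| per component (0 for identity)
    )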
dataeval/utils/metadata.py CHANGED
@@ -131,7 +131,9 @@ def _flatten_dict_inner(
      return items, size


- def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> dict[str, Any]:
+ def _flatten_dict(
+     d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool
+ ) -> tuple[dict[str, Any], int]:
      """
      Flattens a dictionary and converts values to numeric values when possible.

@@ -165,7 +167,7 @@ def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qual
              output[k] = cv
          elif not isinstance(cv, list):
              output[k] = cv if not size else [cv] * size
-     return output
+     return output, size if size is not None else 1


  def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
@@ -188,7 +190,7 @@ def merge_metadata(
      ignore_lists: bool = False,
      fully_qualified: bool = False,
      as_numpy: bool = False,
- ) -> dict[str, list[Any]] | dict[str, NDArray[Any]]:
+ ) -> tuple[dict[str, list[Any]] | dict[str, NDArray[Any]], NDArray[np.int_]]:
      """
      Merges a collection of metadata dictionaries into a single flattened dictionary of keys and values.

@@ -208,8 +210,10 @@

      Returns
      -------
-     dict[str, list[Any]] | dict[str, NDArray[Any]]
+     dict[str, list[Any]] or dict[str, NDArray[Any]]
          A single dictionary containing the flattened data as lists or NumPy arrays
+     NDArray[np.int_]
+         Array defining where individual images start, helpful when working with object detection metadata

      Note
      ----
@@ -217,9 +221,12 @@

      Example
      -------
-     >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3}, {"a": 2, "b": 4}], "source": "example"}]
-     >>> merge_metadata(list_metadata)
+     >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]
+     >>> reorganized_metadata, image_indicies = merge_metadata(list_metadata)
+     >>> reorganized_metadata
      {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
+     >>> image_indicies
+     array([0])
      """
      merged: dict[str, list[Any]] = {}
      isect: set[str] = set()
@@ -236,8 +243,11 @@
      else:
          dicts = list(metadata)

-     for d in dicts:
-         flattened = _flatten_dict(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
+     image_repeats = np.zeros(len(dicts))
+     for i, d in enumerate(dicts):
+         flattened, image_repeats[i] = _flatten_dict(
+             d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified
+         )
          isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
          union = union.union(flattened.keys())
          for k, v in flattened.items():
@@ -248,6 +258,16 @@

      output: dict[str, Any] = {}

+     if image_repeats.sum() == image_repeats.size:
+         image_indicies = np.arange(image_repeats.size)
+     else:
+         image_ids = np.arange(image_repeats.size)
+         image_data = np.concatenate(
+             [np.repeat(image_ids[i], image_repeats[i]) for i in range(image_ids.size)], dtype=np.int_
+         )
+         _, image_unsorted = np.unique(image_data, return_index=True)
+         image_indicies = np.sort(image_unsorted)
+
      if keys:
          output["keys"] = np.array(keys) if as_numpy else keys

@@ -255,4 +275,4 @@
          cv = _convert_type(merged[k])
          output[k] = np.array(cv) if as_numpy else cv

-     return output
+     return output, image_indicies
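
Because merge_metadata now returns a pair instead of a single dictionary, existing callers need a small unpacking change. A sketch following the docstring example above (the object-detection case where one image carries several targets); merge_metadata is re-exported from dataeval.utils per the __init__.py diff:

    from dataeval.utils import merge_metadata

    list_metadata = [
        {"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"},
    ]

    # 0.73.0: merged = merge_metadata(list_metadata)
    # 0.74.0: the second value marks where each image's targets begin in the flattened lists.
    merged, image_indicies = merge_metadata(list_metadata)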
dataeval/utils/shared.py CHANGED
@@ -95,7 +95,7 @@ def get_classes_counts(labels: NDArray[np.int_]) -> tuple[int, int]:
      M = len(classes)
      if M < 2:
          raise ValueError("Label vector contains less than 2 classes!")
-     N = np.sum(counts).astype(int)
+     N = int(np.sum(counts))
      return M, N

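
The get_classes_counts tweak swaps np.sum(counts).astype(int) for int(np.sum(counts)), so N comes back as a built-in int rather than a NumPy scalar. A tiny illustration, independent of dataeval:

    import numpy as np

    counts = np.array([3, 5])
    n_old = np.sum(counts).astype(int)  # NumPy integer scalar (np.int64 on most platforms)
    n_new = int(np.sum(counts))         # plain Python int
    print(type(n_old), type(n_new))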