dataeval 0.73.1__py3-none-any.whl → 0.74.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,229 @@
+from __future__ import annotations
+
+__all__ = []
+
+import contextlib
+from typing import Any
+
+import numpy as np
+from numpy.typing import ArrayLike, NDArray
+
+from dataeval.interop import to_numpy
+
+with contextlib.suppress(ImportError):
+    from matplotlib.figure import Figure
+
+
+def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
+    """
+    Returns columnwise unique counts for discrete data.
+
+    Parameters
+    ----------
+    data : NDArray
+        Array containing integer values for metadata factors
+    min_num_bins : int | None, default None
+        Minimum number of bins for bincount, helps force consistency across runs
+
+    Returns
+    -------
+    NDArray[np.int_]
+        Bin counts per column of data.
+    """
+    max_value = data.max() + 1 if min_num_bins is None else min_num_bins
+    cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
+    for idx in range(data.shape[1]):
+        cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
+
+    return cnt_array
+
+
+def heatmap(
+    data: ArrayLike,
+    row_labels: list[str] | ArrayLike,
+    col_labels: list[str] | ArrayLike,
+    xlabel: str = "",
+    ylabel: str = "",
+    cbarlabel: str = "",
+) -> Figure:
+    """
+    Plots a formatted heatmap
+
+    Parameters
+    ----------
+    data : NDArray
+        Array containing numerical values for factors to plot
+    row_labels : ArrayLike
+        List/Array containing the labels for rows in the histogram
+    col_labels : ArrayLike
+        List/Array containing the labels for columns in the histogram
+    xlabel : str, default ""
+        X-axis label
+    ylabel : str, default ""
+        Y-axis label
+    cbarlabel : str, default ""
+        Label for the colorbar
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Formatted heatmap
+    """
+    import matplotlib.pyplot as plt
+    from matplotlib.ticker import FuncFormatter
+
+    np_data = to_numpy(data)
+    rows = row_labels if isinstance(row_labels, list) else to_numpy(row_labels)
+    cols = col_labels if isinstance(col_labels, list) else to_numpy(col_labels)
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+
+    # Plot the heatmap
+    im = ax.imshow(np_data, vmin=0, vmax=1.0)
+
+    # Create colorbar
+    cbar = fig.colorbar(im, shrink=0.5)
+    cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
+    cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
+    cbar.set_label(cbarlabel, loc="center")
+
+    # Show all ticks and label them with the respective list entries.
+    ax.set_xticks(np.arange(np_data.shape[1]), labels=cols)
+    ax.set_yticks(np.arange(np_data.shape[0]), labels=rows)
+
+    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
+    # Rotate the tick labels and set their alignment.
+    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+    # Turn spines off and create white grid.
+    ax.spines[:].set_visible(False)
+
+    ax.set_xticks(np.arange(np_data.shape[1] + 1) - 0.5, minor=True)
+    ax.set_yticks(np.arange(np_data.shape[0] + 1) - 0.5, minor=True)
+    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
+    ax.tick_params(which="minor", bottom=False, left=False)
+
+    if xlabel:
+        ax.set_xlabel(xlabel)
+    if ylabel:
+        ax.set_ylabel(ylabel)
+
+    valfmt = FuncFormatter(format_text)
+
+    # Normalize the threshold to the image's color range.
+    threshold = im.norm(1.0) / 2.0
+
+    # Set default alignment to center, but allow it to be
+    # overwritten by textkw.
+    kw = {"horizontalalignment": "center", "verticalalignment": "center"}
+
+    # Loop over the data and create a `Text` for each "pixel".
+    # Change the text's color depending on the data.
+    textcolors = ("white", "black")
+    texts = []
+    for i in range(np_data.shape[0]):
+        for j in range(np_data.shape[1]):
+            kw.update(color=textcolors[int(im.norm(np_data[i, j]) > threshold)])
+            text = im.axes.text(j, i, valfmt(np_data[i, j], None), **kw)  # type: ignore
+            texts.append(text)
+
+    fig.tight_layout()
+    return fig
+
+
+# Function to define how the text is displayed in the heatmap
+def format_text(*args: str) -> str:
+    """
+    Helper function to format text for heatmap()
+
+    Parameters
+    ----------
+    *args : tuple[str, str]
+        Text to be formatted. Second element is ignored, but is a
+        mandatory pass-through argument as per matplotlib.ticker.FuncFormatter
+
+    Returns
+    -------
+    str
+        Formatted text
+    """
+    x = args[0]
+    return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
+
+
+def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
+    """
+    Plots a formatted bar plot
+
+    Parameters
+    ----------
+    labels : NDArray
+        Array containing the labels for each bar
+    bar_heights : NDArray
+        Array containing the values for each bar
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Bar plot figure
+    """
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+
+    ax.bar(labels, bar_heights)
+    ax.set_xlabel("Factors")
+
+    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+    fig.tight_layout()
+    return fig
+
+
+def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
+    """
+    Creates a single plot of all of the provided images
+
+    Parameters
+    ----------
+    images : NDArray
+        Array containing only the desired images to plot
+
+    Returns
+    -------
+    matplotlib.figure.Figure
+        Plot of all provided images
+    """
+    import matplotlib.pyplot as plt
+
+    num_images = min(num_images, len(images))
+
+    if images.ndim == 4:
+        images = np.moveaxis(images, 1, -1)
+    elif images.ndim == 3:
+        images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
+    else:
+        raise ValueError(
+            f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
+        )
+
+    rows = int(np.ceil(num_images / 3))
+    fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
+
+    if rows == 1:
+        for j in range(3):
+            if j >= len(images):
+                continue
+            axs[j].imshow(images[j])
+            axs[j].axis("off")
+    else:
+        for i in range(rows):
+            for j in range(3):
+                i_j = i * 3 + j
+                if i_j >= len(images):
+                    continue
+                axs[i, j].imshow(images[i_j])
+                axs[i, j].axis("off")
+
+    fig.tight_layout()
+    return fig
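
The hunk above adds a private plotting module (the diff viewer does not show its path). A quick way to sanity-check the `get_counts` helper is to mirror its columnwise `np.bincount` logic on a small array; this is a minimal sketch with made-up toy data, not code from the package.

```python
import numpy as np

# Two discrete metadata factors for five samples (columns are factors).
data = np.array(
    [[0, 1],
     [1, 1],
     [1, 2],
     [2, 0],
     [1, 1]]
)

# Columnwise bincount, mirroring get_counts(data): one row of counts per bin value.
max_value = data.max() + 1
cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
for idx in range(data.shape[1]):
    cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)

print(cnt_array)
# [[1 1]
#  [3 3]
#  [1 1]]
```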
@@ -4,14 +4,15 @@ __all__ = ["ParityOutput", "parity", "label_parity"]
 
 import warnings
 from dataclasses import dataclass
-from typing import Any, Generic, Mapping, TypeVar
+from typing import Any, Generic, TypeVar
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
-from scipy.stats import chi2_contingency, chisquare
+from scipy.stats import chisquare
+from scipy.stats.contingency import chi2_contingency, crosstab
 
-from dataeval.interop import to_numpy
-from dataeval.metrics.bias.metadata import CLASS_LABEL, preprocess_metadata
+from dataeval.interop import as_numpy, to_numpy
+from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
 from dataeval.output import OutputMetadata, set_metadata
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
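
The import change above swaps the hand-rolled contingency-matrix construction for `scipy.stats.contingency.crosstab` (available since SciPy 1.7). A minimal sketch of how `crosstab` feeds `chi2_contingency`, using made-up factor and label arrays:

```python
import numpy as np
from scipy.stats.contingency import chi2_contingency, crosstab

factor = np.array([0, 0, 1, 1, 1, 2])
labels = np.array([0, 1, 0, 1, 1, 1])

# crosstab returns (elements, count): the unique values along each axis and
# the contingency table of their co-occurrence frequencies.
result = crosstab(factor, labels)
print(result.count)
# [[1 1]
#  [1 2]
#  [0 1]]

chi2, p, _, _ = chi2_contingency(result.count)
```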
@@ -37,97 +38,6 @@ class ParityOutput(Generic[TData], OutputMetadata):
     metadata_names: list[str] | None
 
 
-def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
-    """
-    Digitizes a list of values into a given number of bins.
-
-    Parameters
-    ----------
-    continuous_values : NDArray
-        The values to be digitized.
-    bins : int
-        The number of bins for the discrete values that continuous_values will be digitized into.
-    factor_name : str
-        The name of the factor to be digitized.
-
-    Returns
-    -------
-    NDArray[np.intp]
-        The digitized values
-    """
-
-    if not np.all([np.issubdtype(type(n), np.number) for n in continuous_values]):
-        raise TypeError(
-            f"Encountered a non-numeric value for factor {factor_name}, but the factor"
-            " was specified to be continuous. Ensure all occurrences of this factor are numeric types,"
-            f" or do not specify {factor_name} as a continuous factor."
-        )
-
-    _, bin_edges = np.histogram(continuous_values, bins=bins)
-    bin_edges[-1] = np.inf
-    bin_edges[0] = -np.inf
-    return np.digitize(continuous_values, bin_edges)
-
-
-def format_discretize_factors(
-    data: NDArray[Any],
-    names: list[str],
-    is_categorical: list[bool],
-    continuous_factor_bincounts: Mapping[str, int] | None,
-) -> dict[str, NDArray[Any]]:
-    """
-    Sets up the internal list of metadata factors.
-
-    Parameters
-    ----------
-    data : NDArray
-        The dataset factors, which are per-image attributes including class label and metadata.
-    names : list[str]
-        The class label
-    continuous_factor_bincounts : Mapping[str, int] or None
-        The factors in data_factors that have continuous values and the array of bin counts to
-        discretize values into. All factors are treated as having discrete values unless they
-        are specified as keys in this dictionary. Each element of this array must occur as a key
-        in data_factors.
-
-    Returns
-    -------
-    Dict[str, NDArray]
-        - Intrinsic per-image metadata information with the formatting that input data_factors uses.
-          Each key is a metadata factor, whose value is the discrete per-image factor values.
-    """
-
-    if continuous_factor_bincounts:
-        invalid_keys = set(continuous_factor_bincounts.keys()) - set(names)
-        if invalid_keys:
-            raise KeyError(
-                f"The continuous factor(s) {invalid_keys} do not exist in data_factors. Delete these "
-                "keys from `continuous_factor_names` or add corresponding entries to `data_factors`."
-            )
-
-    warn = []
-    metadata_factors = {}
-    for i, name in enumerate(names):
-        if name == CLASS_LABEL:
-            continue
-        if continuous_factor_bincounts and name in continuous_factor_bincounts:
-            metadata_factors[name] = digitize_factor_bins(data[:, i], continuous_factor_bincounts[name], name)
-        elif not is_categorical[i]:
-            warn.append(name)
-            metadata_factors[name] = data[:, i]
-        else:
-            metadata_factors[name] = data[:, i]
-
-    if warn:
-        warnings.warn(
-            f"The following factors appear to be continuous but did not have the desired number of bins specified: \n\
-            {warn}",
-            UserWarning,
-        )
-
-    return metadata_factors
-
-
 def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
     """
     Normalize the expected label distribution to match the total number of labels in the observed distribution.
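
For reference, the removed `digitize_factor_bins` widened the outer histogram edges to ±inf so extreme values always land in the first or last bin. This standalone sketch reproduces that binning trick on toy values (binning now lives in the preprocessing step instead):

```python
import numpy as np

continuous_values = np.array([1.0, 2.5, 3.7, 8.2, 9.9])

# Build edges with np.histogram, then open the outer bins so min/max samples
# (and anything beyond them) still receive a valid bin id.
_, bin_edges = np.histogram(continuous_values, bins=3)
bin_edges[0], bin_edges[-1] = -np.inf, np.inf

print(np.digitize(continuous_values, bin_edges))  # bin ids starting at 1
# [1 1 1 3 3]
```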
@@ -295,31 +205,19 @@ def label_parity(
 
 
 @set_metadata()
-def parity(
-    class_labels: ArrayLike,
-    metadata: Mapping[str, ArrayLike],
-    continuous_factor_bincounts: Mapping[str, int] | None = None,
-) -> ParityOutput[NDArray[np.float64]]:
+def parity(metadata: MetadataOutput) -> ParityOutput[NDArray[np.float64]]:
     """
-    Calculate chi-square statistics to assess the relationship between multiple factors
+    Calculate chi-square statistics to assess the linear relationship between multiple factors
     and class labels.
 
     This function computes the chi-square statistic for each metadata factor to determine if there is
-    a significant relationship between the factor values and class labels. The function handles both categorical
-    and discretized continuous factors.
+    a significant relationship between the factor values and class labels. The chi-square statistic is
+    only valid for linear relationships. If non-linear relationships exist, use `balance`.
 
     Parameters
     ----------
-    class_labels : ArrayLike
-        List of class labels for each image
-    metadata : Mapping[str, ArrayLike]
-        The dataset factors, which are per-image metadata attributes.
-        Each key of dataset_factors is a factor, whose value is the per-image factor values.
-    continuous_factor_bincounts : Mapping[str, int] or None, default None
-        A dictionary specifying the number of bins for discretizing the continuous factors.
-        The keys should correspond to the names of continuous factors in `metadata`,
-        and the values should be the number of bins to use for discretization.
-        If not provided, no discretization is applied.
+    metadata : MetadataOutput
+        Output after running `metadata_preprocessing`
 
     Returns
     -------
@@ -333,74 +231,66 @@ def parity(
     Warning
     If any cell in the contingency matrix has a value between 0 and 5, a warning is issued because this can
     lead to inaccurate chi-square calculations. It is recommended to ensure that each label co-occurs with
-    factor values either 0 times or at least 5 times. Alternatively, continuous-valued factors can be digitized
-    into fewer bins.
+    factor values either 0 times or at least 5 times.
 
     Note
     ----
-    - Each key of the ``continuous_factor_bincounts`` dictionary must occur as a key in data_factors.
     - A high score with a low p-value suggests that a metadata factor is strongly correlated with a class label.
     - The function creates a contingency matrix for each factor, where each entry represents the frequency of a
       specific factor value co-occurring with a particular class label.
     - Rows containing only zeros in the contingency matrix are removed before performing the chi-square test
       to prevent errors in the calculation.
 
+    See Also
+    --------
+    balance
+
     Examples
     --------
     Randomly creating some "continuous" and categorical variables using ``np.random.default_rng``
 
     >>> labels = np_random_gen.choice([0, 1, 2], (100))
-    >>> metadata = {
-    ...     "age": np_random_gen.choice([25, 30, 35, 45], (100)),
-    ...     "income": np_random_gen.choice([50000, 65000, 80000], (100)),
-    ...     "gender": np_random_gen.choice(["M", "F"], (100)),
-    ... }
+    >>> metadata_dict = [
+    ...     {
+    ...         "age": list(np_random_gen.choice([25, 30, 35, 45], (100))),
+    ...         "income": list(np_random_gen.choice([50000, 65000, 80000], (100))),
+    ...         "gender": list(np_random_gen.choice(["M", "F"], (100))),
+    ...     }
+    ... ]
     >>> continuous_factor_bincounts = {"age": 4, "income": 3}
-    >>> parity(labels, metadata, continuous_factor_bincounts)
+    >>> metadata = metadata_preprocessing(metadata_dict, labels, continuous_factor_bincounts)
+    >>> parity(metadata)
     ParityOutput(score=array([7.35731943, 5.46711299, 0.51506212]), p_value=array([0.28906231, 0.24263543, 0.77295762]), metadata_names=['age', 'income', 'gender'])
     """  # noqa: E501
-    if len(np.shape(class_labels)) > 1:
-        raise ValueError(
-            f"Got class labels with {len(np.shape(class_labels))}-dimensional",
-            f" shape {np.shape(class_labels)}, but expected a 1-dimensional array.",
-        )
-
-    data, names, is_categorical, _ = preprocess_metadata(class_labels, metadata)
-
-    factors = format_discretize_factors(data, names, is_categorical, continuous_factor_bincounts)
-
-    # unique class labels
-    class_idx = names.index(CLASS_LABEL)
-    u_cls = np.unique(data[:, class_idx])
-
-    chi_scores = np.zeros(len(factors))
-    p_values = np.zeros(len(factors))
+    chi_scores = np.zeros(metadata.discrete_data.shape[1])
+    p_values = np.zeros_like(chi_scores)
     not_enough_data = {}
-    for i, (current_factor_name, factor_values) in enumerate(factors.items()):
-        unique_factor_values = np.unique(factor_values)
-        contingency_matrix = np.zeros((len(unique_factor_values), u_cls.size))
+    for i, col_data in enumerate(metadata.discrete_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
         # at a data point with class c.
-
-        # TODO: Vectorize this nested for loop
-        for fi, factor_value in enumerate(unique_factor_values):
-            for label in u_cls:
-                with_both = np.bitwise_and((data[:, class_idx] == label), factor_values == factor_value)
-                contingency_matrix[fi, label] = np.sum(with_both)
-                if 0 < contingency_matrix[fi, label] < 5:
-                    if current_factor_name not in not_enough_data:
-                        not_enough_data[current_factor_name] = {}
-                    if factor_value not in not_enough_data[current_factor_name]:
-                        not_enough_data[current_factor_name][factor_value] = []
-                    not_enough_data[current_factor_name][factor_value].append(
-                        (label, int(contingency_matrix[fi, label]))
-                    )
+        results = crosstab(col_data, metadata.class_labels)
+        contingency_matrix = as_numpy(results.count)  # type: ignore
+
+        # Determines if any frequencies are too low
+        counts = np.nonzero(contingency_matrix < 5)
+        unique_factor_values = np.unique(col_data)
+        current_factor_name = metadata.discrete_factor_names[i]
+        for int_factor, int_class in zip(counts[0], counts[1]):
+            if contingency_matrix[int_factor, int_class] > 0:
+                factor_category = unique_factor_values[int_factor]
+                if current_factor_name not in not_enough_data:
+                    not_enough_data[current_factor_name] = {}
+                if factor_category not in not_enough_data[current_factor_name]:
+                    not_enough_data[current_factor_name][factor_category] = []
+                not_enough_data[current_factor_name][factor_category].append(
+                    (metadata.class_names[int_class], int(contingency_matrix[int_factor, int_class]))
+                )
 
         # This deletes rows containing only zeros,
         # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
         rowsums = np.sum(contingency_matrix, axis=1)
-        rowmask = np.where(rowsums)
+        rowmask = np.nonzero(rowsums)[0]
        contingency_matrix = contingency_matrix[rowmask]
 
         chi2, p, _, _ = chi2_contingency(contingency_matrix)
@@ -428,4 +318,4 @@ def parity(
         UserWarning,
     )
 
-    return ParityOutput(chi_scores, p_values, list(metadata.keys()))
+    return ParityOutput(chi_scores, p_values, metadata.discrete_factor_names)
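
The rewritten loop flags sparse contingency cells with `np.nonzero(contingency_matrix < 5)` instead of checking inside a nested loop. A small sketch of that check on a hypothetical 3x2 contingency matrix:

```python
import numpy as np

contingency_matrix = np.array([[12, 3], [0, 9], [7, 1]])

# Flag cells with 0 < count < 5, mirroring the sparse-cell check in parity();
# exact zeros are allowed, small positive counts are reported.
rows, cols = np.nonzero(contingency_matrix < 5)
for r, c in zip(rows, cols):
    if contingency_matrix[r, c] > 0:
        print(f"factor bin {r} x class {c}: count {contingency_matrix[r, c]}")
# factor bin 0 x class 1: count 3
# factor bin 2 x class 1: count 1
```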
dataeval/utils/gmm.py ADDED
@@ -0,0 +1,26 @@
+from dataclasses import dataclass
+from typing import Generic, TypeVar
+
+TGMMData = TypeVar("TGMMData")
+
+
+@dataclass
+class GaussianMixtureModelParams(Generic[TGMMData]):
+    """
+    phi : TGMMData
+        Mixture component distribution weights.
+    mu : TGMMData
+        Mixture means.
+    cov : TGMMData
+        Mixture covariance.
+    L : TGMMData
+        Cholesky decomposition of `cov`.
+    log_det_cov : TGMMData
+        Log of the determinant of `cov`.
+    """
+
+    phi: TGMMData
+    mu: TGMMData
+    cov: TGMMData
+    L: TGMMData
+    log_det_cov: TGMMData
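
The new `dataeval/utils/gmm.py` makes the parameter container generic over the array type, so any backend's tensor type can specialize it. A sketch instantiating it with NumPy arrays and assumed toy values for a 2-component, 2-dimensional mixture:

```python
import numpy as np
from dataeval.utils.gmm import GaussianMixtureModelParams

# Toy parameters (assumptions for illustration, not values from the package).
cov = np.stack([np.eye(2), 2 * np.eye(2)])  # one covariance per component
L = np.linalg.cholesky(cov)                 # batched Cholesky factors

params = GaussianMixtureModelParams(
    phi=np.array([0.5, 0.5]),               # component weights
    mu=np.zeros((2, 2)),                    # component means
    cov=cov,
    L=L,
    log_det_cov=np.log(np.linalg.det(cov)),
)
print(params.phi)  # [0.5 0.5]
```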
@@ -131,7 +131,9 @@ def _flatten_dict_inner(
     return items, size
 
 
-def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> dict[str, Any]:
+def _flatten_dict(
+    d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool
+) -> tuple[dict[str, Any], int]:
     """
     Flattens a dictionary and converts values to numeric values when possible.
 
@@ -165,7 +167,7 @@ def _flatten_dict(d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool) -> dict[str, Any]:
             output[k] = cv
         elif not isinstance(cv, list):
             output[k] = cv if not size else [cv] * size
-    return output
+    return output, size if size is not None else 1
 
 
 def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
@@ -188,7 +190,7 @@ def merge_metadata(
     ignore_lists: bool = False,
     fully_qualified: bool = False,
     as_numpy: bool = False,
-) -> dict[str, list[Any]] | dict[str, NDArray[Any]]:
+) -> tuple[dict[str, list[Any]] | dict[str, NDArray[Any]], NDArray[np.int_]]:
     """
     Merges a collection of metadata dictionaries into a single flattened dictionary of keys and values.
 
@@ -208,8 +210,10 @@
 
     Returns
     -------
-    dict[str, list[Any]] | dict[str, NDArray[Any]]
+    dict[str, list[Any]] or dict[str, NDArray[Any]]
         A single dictionary containing the flattened data as lists or NumPy arrays
+    NDArray[np.int_]
+        Array defining where individual images start, helpful when working with object detection metadata
 
     Note
     ----
@@ -217,9 +221,12 @@
 
     Example
     -------
-    >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3}, {"a": 2, "b": 4}], "source": "example"}]
-    >>> merge_metadata(list_metadata)
+    >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]
+    >>> reorganized_metadata, image_indicies = merge_metadata(list_metadata)
+    >>> reorganized_metadata
     {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
+    >>> image_indicies
+    array([0])
     """
     merged: dict[str, list[Any]] = {}
     isect: set[str] = set()
@@ -236,8 +243,11 @@
     else:
         dicts = list(metadata)
 
-    for d in dicts:
-        flattened = _flatten_dict(d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified)
+    image_repeats = np.zeros(len(dicts))
+    for i, d in enumerate(dicts):
+        flattened, image_repeats[i] = _flatten_dict(
+            d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified
+        )
         isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
         union = union.union(flattened.keys())
         for k, v in flattened.items():
@@ -248,6 +258,16 @@
 
     output: dict[str, Any] = {}
 
+    if image_repeats.sum() == image_repeats.size:
+        image_indicies = np.arange(image_repeats.size)
+    else:
+        image_ids = np.arange(image_repeats.size)
+        image_data = np.concatenate(
+            [np.repeat(image_ids[i], image_repeats[i]) for i in range(image_ids.size)], dtype=np.int_
+        )
+        _, image_unsorted = np.unique(image_data, return_index=True)
+        image_indicies = np.sort(image_unsorted)
+
     if keys:
         output["keys"] = np.array(keys) if as_numpy else keys
 
@@ -255,4 +275,4 @@
         cv = _convert_type(merged[k])
         output[k] = np.array(cv) if as_numpy else cv
 
-    return output
+    return output, image_indicies
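
With `_flatten_dict` now reporting how many targets each dictionary expanded to, `merge_metadata` derives `image_indicies` (the package's spelling) as the starting row of each image's targets. This sketch replays that computation on hypothetical per-image target counts:

```python
import numpy as np

# Assumed target counts: image 0 has 2 detections, image 1 has 1, image 2 has 3.
image_repeats = np.array([2, 1, 3])

# Expand to one image id per flattened row, then take each id's first offset.
image_ids = np.arange(image_repeats.size)
image_data = np.concatenate(
    [np.repeat(image_ids[i], image_repeats[i]) for i in range(image_ids.size)], dtype=np.int_
)
# image_data -> [0 0 1 2 2 2]
_, image_unsorted = np.unique(image_data, return_index=True)
image_indicies = np.sort(image_unsorted)

print(image_indicies)  # [0 2 3]: row offsets where each image's targets begin
```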
@@ -8,10 +8,11 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING
 
 import numpy as np
 
+from dataeval.utils.gmm import GaussianMixtureModelParams
 from dataeval.utils.lazy import lazyload
 
 if TYPE_CHECKING:
@@ -20,28 +21,7 @@ else:
     tf = lazyload("tensorflow")
 
 
-class GaussianMixtureModelParams(NamedTuple):
-    """
-    phi : tf.Tensor
-        Mixture component distribution weights.
-    mu : tf.Tensor
-        Mixture means.
-    cov : tf.Tensor
-        Mixture covariance.
-    L : tf.Tensor
-        Cholesky decomposition of `cov`.
-    log_det_cov : tf.Tensor
-        Log of the determinant of `cov`.
-    """
-
-    phi: tf.Tensor
-    mu: tf.Tensor
-    cov: tf.Tensor
-    L: tf.Tensor
-    log_det_cov: tf.Tensor
-
-
-def gmm_params(z: tf.Tensor, gamma: tf.Tensor) -> GaussianMixtureModelParams:
+def gmm_params(z: tf.Tensor, gamma: tf.Tensor) -> GaussianMixtureModelParams[tf.Tensor]:
     """
     Compute parameters of Gaussian Mixture Model.
 
@@ -81,7 +61,7 @@ def gmm_params(z: tf.Tensor, gamma: tf.Tensor) -> GaussianMixtureModelParams:
 
 def gmm_energy(
     z: tf.Tensor,
-    params: GaussianMixtureModelParams,
+    params: GaussianMixtureModelParams[tf.Tensor],
     return_mean: bool = True,
 ) -> tuple[tf.Tensor, tf.Tensor]:
     """