dataeval 0.75.0__py3-none-any.whl → 0.76.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. dataeval/__init__.py +3 -3
  2. dataeval/detectors/drift/base.py +2 -2
  3. dataeval/detectors/drift/ks.py +2 -1
  4. dataeval/detectors/drift/mmd.py +3 -2
  5. dataeval/detectors/drift/uncertainty.py +2 -2
  6. dataeval/detectors/drift/updates.py +1 -1
  7. dataeval/detectors/linters/clusterer.py +3 -2
  8. dataeval/detectors/linters/duplicates.py +4 -4
  9. dataeval/detectors/linters/outliers.py +96 -3
  10. dataeval/detectors/ood/__init__.py +1 -1
  11. dataeval/detectors/ood/base.py +1 -17
  12. dataeval/detectors/ood/output.py +1 -1
  13. dataeval/interop.py +1 -1
  14. dataeval/metrics/__init__.py +1 -1
  15. dataeval/metrics/bias/__init__.py +1 -1
  16. dataeval/metrics/bias/balance.py +3 -3
  17. dataeval/metrics/bias/coverage.py +1 -1
  18. dataeval/metrics/bias/diversity.py +14 -10
  19. dataeval/metrics/bias/parity.py +7 -9
  20. dataeval/metrics/estimators/ber.py +4 -3
  21. dataeval/metrics/estimators/divergence.py +3 -3
  22. dataeval/metrics/estimators/uap.py +3 -3
  23. dataeval/metrics/stats/__init__.py +1 -1
  24. dataeval/metrics/stats/base.py +24 -8
  25. dataeval/metrics/stats/boxratiostats.py +5 -5
  26. dataeval/metrics/stats/datasetstats.py +39 -6
  27. dataeval/metrics/stats/dimensionstats.py +4 -4
  28. dataeval/metrics/stats/hashstats.py +2 -2
  29. dataeval/metrics/stats/labelstats.py +89 -6
  30. dataeval/metrics/stats/pixelstats.py +7 -5
  31. dataeval/metrics/stats/visualstats.py +6 -4
  32. dataeval/output.py +23 -14
  33. dataeval/utils/__init__.py +2 -2
  34. dataeval/utils/dataset/read.py +1 -1
  35. dataeval/utils/dataset/split.py +1 -1
  36. dataeval/utils/metadata.py +255 -110
  37. dataeval/utils/plot.py +129 -6
  38. dataeval/workflows/sufficiency.py +2 -2
  39. {dataeval-0.75.0.dist-info → dataeval-0.76.1.dist-info}/LICENSE.txt +2 -2
  40. {dataeval-0.75.0.dist-info → dataeval-0.76.1.dist-info}/METADATA +57 -30
  41. dataeval-0.76.1.dist-info/RECORD +67 -0
  42. dataeval-0.75.0.dist-info/RECORD +0 -67
  43. {dataeval-0.75.0.dist-info → dataeval-0.76.1.dist-info}/WHEEL +0 -0
dataeval/utils/metadata.py
@@ -1,14 +1,15 @@
  """
- Metadata related utility functions that help organize raw metadata into :class:`Metadata` objects
- for use within `DataEval`.
+ Metadata related utility functions that help organize raw metadata into \
+ :class:`Metadata` objects for use within `DataEval`.
  """

  from __future__ import annotations

- __all__ = ["Metadata", "preprocess", "merge"]
+ __all__ = ["Metadata", "preprocess", "merge", "flatten"]

  import warnings
  from dataclasses import dataclass
+ from enum import Enum
  from typing import Any, Iterable, Literal, Mapping, TypeVar, overload

  import numpy as np
@@ -18,9 +19,15 @@ from scipy.stats import wasserstein_distance as wd
  from dataeval.interop import as_numpy, to_numpy
  from dataeval.output import Output, set_metadata

- TNum = TypeVar("TNum", int, float)
  DISCRETE_MIN_WD = 0.054
  CONTINUOUS_MIN_SAMPLE_SIZE = 20
+ DEFAULT_IMAGE_INDEX_KEY = "_image_index"
+
+
+ class DropReason(Enum):
+     INCONSISTENT_KEY = "inconsistent_key"
+     INCONSISTENT_SIZE = "inconsistent_size"
+     NESTED_LIST = "nested_list"


  T = TypeVar("T")
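The new `DropReason` enum standardizes the reasons a metadata key can be discarded during flattening and merging. Downstream, the helpers in this diff report the reasons as sorted string values keyed by the flattened key name; a minimal sketch of that reporting shape (the sample key is illustrative):

```python
from enum import Enum


class DropReason(Enum):
    INCONSISTENT_KEY = "inconsistent_key"
    INCONSISTENT_SIZE = "inconsistent_size"
    NESTED_LIST = "nested_list"


# Reporting shape used by the flatten/merge helpers: flattened key -> sorted reason strings.
reasons = {"target_c": {DropReason.INCONSISTENT_KEY}}
print({k: sorted(r.value for r in v) for k, v in reasons.items()})
# {'target_c': ['inconsistent_key']}
```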
@@ -42,8 +49,8 @@ def _convert_type(data: str) -> int | float | str: ...

  def _convert_type(data: list[str] | str) -> list[int] | list[float] | list[str] | int | float | str:
      """
-     Converts a value or a list of values to the simplest form possible, in preferred order of `int`,
-     `float`, or `string`.
+     Converts a value or a list of values to the simplest form possible,
+     in preferred order of `int`, `float`, or `string`.

      Parameters
      ----------
@@ -100,8 +107,16 @@ def _get_key_indices(keys: Iterable[tuple[str, ...]]) -> dict[tuple[str, ...], i
      return indices


+ def _sorted_drop_reasons(d: dict[str, set[DropReason]]) -> dict[str, list[str]]:
+     return {k: sorted({vv.value for vv in v}) for k, v in sorted(d.items(), key=lambda item: item[1])}
+
+
  def _flatten_dict_inner(
-     d: Mapping[str, Any], parent_keys: tuple[str, ...], size: int | None = None, nested: bool = False
+     d: Mapping[str, Any],
+     dropped: dict[tuple[str, ...], set[DropReason]],
+     parent_keys: tuple[str, ...],
+     size: int | None = None,
+     nested: bool = False,
  ) -> tuple[dict[tuple[str, ...], Any], int | None]:
      """
      Recursive internal function for flattening a dictionary.
@@ -110,6 +125,8 @@ def _flatten_dict_inner(
      ----------
      d : dict[str, Any]
          Dictionary to flatten
+     dropped: set[tuple[str, ...]]
+         Reference to set of dropped keys from the dictionary
      parent_keys : tuple[str, ...]
          Parent keys to the current dictionary being flattened
      size : int or None, default None
@@ -120,35 +137,62 @@ def _flatten_dict_inner(
      Returns
      -------
      tuple[dict[tuple[str, ...], Any], int | None]
-         - [0]: Dictionary of flattened values with the keys reformatted as a hierarchical tuple of strings
+         - [0]: Dictionary of flattened values with the keys reformatted as a
+           hierarchical tuple of strings
          - [1]: Size, if any, of the current list of values
      """
      items: dict[tuple[str, ...], Any] = {}
      for k, v in d.items():
          new_keys: tuple[str, ...] = parent_keys + (k,)
          if isinstance(v, dict):
-             fd, size = _flatten_dict_inner(v, new_keys, size=size, nested=nested)
+             fd, size = _flatten_dict_inner(v, dropped, new_keys, size=size, nested=nested)
              items.update(fd)
          elif isinstance(v, (list, tuple)):
-             if not nested and (size is None or size == len(v)):
+             if nested:
+                 dropped.setdefault(parent_keys + (k,), set()).add(DropReason.NESTED_LIST)
+             elif size is not None and size != len(v):
+                 dropped.setdefault(parent_keys + (k,), set()).add(DropReason.INCONSISTENT_SIZE)
+             else:
                  size = len(v)
                  if all(isinstance(i, dict) for i in v):
                      for sub_dict in v:
-                         fd, size = _flatten_dict_inner(sub_dict, new_keys, size=size, nested=True)
+                         fd, size = _flatten_dict_inner(sub_dict, dropped, new_keys, size=size, nested=True)
                          for fk, fv in fd.items():
                              items.setdefault(fk, []).append(fv)
                  else:
                      items[new_keys] = v
-             else:
-                 warnings.warn(f"Dropping nested list found in '{parent_keys + (k, )}'.")
          else:
              items[new_keys] = v
      return items, size


- def _flatten_dict(
-     d: Mapping[str, Any], sep: str, ignore_lists: bool, fully_qualified: bool
- ) -> tuple[dict[str, Any], int]:
+ @overload
+ def flatten(
+     d: Mapping[str, Any],
+     return_dropped: Literal[True],
+     sep: str = "_",
+     ignore_lists: bool = False,
+     fully_qualified: bool = False,
+ ) -> tuple[dict[str, Any], int, dict[str, list[str]]]: ...
+
+
+ @overload
+ def flatten(
+     d: Mapping[str, Any],
+     return_dropped: Literal[False] = False,
+     sep: str = "_",
+     ignore_lists: bool = False,
+     fully_qualified: bool = False,
+ ) -> tuple[dict[str, Any], int]: ...
+
+
+ def flatten(
+     d: Mapping[str, Any],
+     return_dropped: bool = False,
+     sep: str = "_",
+     ignore_lists: bool = False,
+     fully_qualified: bool = False,
+ ):
      """
      Flattens a dictionary and converts values to numeric values when possible.

@@ -156,33 +200,53 @@ def _flatten_dict(
      ----------
      d : dict[str, Any]
          Dictionary to flatten
-     sep : str
+     return_dropped: bool, default False
+         Option to return a dictionary of dropped keys and the reason(s) for dropping
+     sep : str, default "_"
          String separator to use when concatenating key names
-     ignore_lists : bool
+     ignore_lists : bool, default False
          Option to skip expanding lists within metadata
-     fully_qualified : bool
-         Option to return dictionary keys full qualified instead of minimized
+     fully_qualified : bool, default False
+         Option to return dictionary keys fully qualified instead of reduced

      Returns
      -------
      dict[str, Any]
-         A flattened dictionary
+         Dictionary of flattened values with the keys reformatted as a hierarchical tuple of strings
+     int
+         Size of the values in the flattened dictionary
+     dict[str, list[str]], Optional
+         Dictionary containing dropped keys and reason(s) for dropping
      """
-     expanded, size = _flatten_dict_inner(d, parent_keys=(), nested=ignore_lists)
+     dropped_inner: dict[tuple[str, ...], set[DropReason]] = {}
+     expanded, size = _flatten_dict_inner(d, dropped=dropped_inner, parent_keys=(), nested=ignore_lists)

      output = {}
-     if fully_qualified:
-         expanded = {sep.join(k): v for k, v in expanded.items()}
-     else:
-         keys = _get_key_indices(expanded)
-         expanded = {sep.join(k[keys[k] :]): v for k, v in expanded.items()}
      for k, v in expanded.items():
          cv = _convert_type(v)
-         if isinstance(cv, list) and len(cv) == size:
-             output[k] = cv
+         if isinstance(cv, list):
+             if len(cv) == size:
+                 output[k] = cv
+             else:
+                 dropped_inner.setdefault(k, set()).add(DropReason.INCONSISTENT_KEY)
          elif not isinstance(cv, list):
              output[k] = cv if not size else [cv] * size
-     return output, size if size is not None else 1
+
+     if fully_qualified:
+         output = {sep.join(k): v for k, v in output.items()}
+     else:
+         keys = _get_key_indices(output)
+         output = {sep.join(k[keys[k] :]): v for k, v in output.items()}
+
+     size = size if size is not None else 1
+     dropped = {sep.join(k): v for k, v in dropped_inner.items()}
+
+     if return_dropped:
+         return output, size, _sorted_drop_reasons(dropped)
+     else:
+         if dropped:
+             warnings.warn(f"Metadata keys {list(dropped)} were dropped.")
+         return output, size


  def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
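For orientation, a sketch of how the newly exported `flatten` might be used; the expected values are inferred from the signature above and the `merge` docstring example further down, not quoted from the package documentation:

```python
from dataeval.utils.metadata import flatten

d = {"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}

# Default call: keys that cannot be kept are reported via a warning.
flat, size = flatten(d)

# Opting in to drop reporting returns the reasons as a third value instead of warning.
flat, size, dropped = flatten(d, return_dropped=True)
# Illustrative result shapes:
#   flat    -> {"common": [1, 1], "a": [1, 2], "b": [3, 4], "source": ["example", "example"]}
#   size    -> 2
#   dropped -> {"target_c": ["inconsistent_key"]}
```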
@@ -200,48 +264,75 @@ def _is_metadata_dict_of_dicts(metadata: Mapping) -> bool:
      return set(metadata[keys[0]]) == set(metadata[keys[1]])


+ @overload
+ def merge(
+     metadata: Iterable[Mapping[str, Any]],
+     return_dropped: Literal[True],
+     ignore_lists: bool = False,
+     fully_qualified: bool = False,
+     return_numpy: bool = False,
+ ) -> tuple[dict[str, list[Any]] | dict[str, NDArray[Any]], dict[str, list[str]]]: ...
+
+
+ @overload
+ def merge(
+     metadata: Iterable[Mapping[str, Any]],
+     return_dropped: Literal[False] = False,
+     ignore_lists: bool = False,
+     fully_qualified: bool = False,
+     return_numpy: bool = False,
+ ) -> dict[str, list[Any]] | dict[str, NDArray[Any]]: ...
+
+
  def merge(
      metadata: Iterable[Mapping[str, Any]],
+     return_dropped: bool = False,
      ignore_lists: bool = False,
      fully_qualified: bool = False,
-     as_numpy: bool = False,
- ) -> tuple[dict[str, list[Any]] | dict[str, NDArray[Any]], NDArray[np.int_]]:
+     return_numpy: bool = False,
+ ):
      """
-     Merges a collection of metadata dictionaries into a single flattened dictionary of keys and values.
+     Merges a collection of metadata dictionaries into a single flattened
+     dictionary of keys and values.

-     Nested dictionaries are flattened, and lists are expanded. Nested lists are dropped as the
-     expanding into multiple hierarchical trees is not supported.
+     Nested dictionaries are flattened, and lists are expanded. Nested lists are
+     dropped as the expanding into multiple hierarchical trees is not supported.
+     The function adds an internal "_image_index" key to the metadata dictionary
+     for consumption by the preprocess function.

      Parameters
      ----------
      metadata : Iterable[Mapping[str, Any]]
          Iterable collection of metadata dictionaries to flatten and merge
+     return_dropped: bool, default False
+         Option to return a dictionary of dropped keys and the reason(s) for dropping
      ignore_lists : bool, default False
          Option to skip expanding lists within metadata
      fully_qualified : bool, default False
          Option to return dictionary keys full qualified instead of minimized
-     as_numpy : bool, default False
+     return_numpy : bool, default False
          Option to return results as lists or NumPy arrays

      Returns
      -------
-     dict[str, list[Any]] or dict[str, NDArray[Any]]
+     dict[str, list[Any]] | dict[str, NDArray[Any]]
          A single dictionary containing the flattened data as lists or NumPy arrays
-     NDArray[np.int_]
-         Array defining where individual images start, helpful when working with object detection metadata
+     dict[str, list[str]], Optional
+         Dictionary containing dropped keys and reason(s) for dropping

      Note
      ----
-     Nested lists of values and inconsistent keys are dropped in the merged metadata dictionary
+     Nested lists of values and inconsistent keys are dropped in the merged
+     metadata dictionary

      Example
      -------
      >>> list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]
-     >>> reorganized_metadata, image_indicies = merge(list_metadata)
+     >>> reorganized_metadata, dropped_keys = merge(list_metadata, return_dropped=True)
      >>> reorganized_metadata
-     {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example']}
-     >>> image_indicies
-     array([0])
+     {'common': [1, 1], 'a': [1, 2], 'b': [3, 4], 'source': ['example', 'example'], '_image_index': [0, 0]}
+     >>> dropped_keys
+     {'target_c': ['inconsistent_key']}
      """
      merged: dict[str, list[Any]] = {}
      isect: set[str] = set()
@@ -258,56 +349,71 @@ def merge(
      else:
          dicts = list(metadata)

-     image_repeats = np.zeros(len(dicts))
+     image_repeats = np.zeros(len(dicts), dtype=np.int_)
+     dropped: dict[str, set[DropReason]] = {}
      for i, d in enumerate(dicts):
-         flattened, image_repeats[i] = _flatten_dict(
-             d, sep="_", ignore_lists=ignore_lists, fully_qualified=fully_qualified
+         flattened, image_repeats[i], dropped_inner = flatten(
+             d,
+             return_dropped=True,
+             ignore_lists=ignore_lists,
+             fully_qualified=fully_qualified,
          )
          isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
-         union = union.union(flattened.keys())
+         union.update(flattened.keys())
+         for k, v in dropped_inner.items():
+             dropped.setdefault(k, set()).update({DropReason(vv) for vv in v})
          for k, v in flattened.items():
              merged.setdefault(k, []).extend(flattened[k]) if isinstance(v, list) else merged.setdefault(k, []).append(v)

-     if len(union) > len(isect):
-         warnings.warn(f"Inconsistent metadata keys found. Dropping {union - isect} from metadata.")
-
-     output: dict[str, Any] = {}
+     for k in union - isect:
+         dropped.setdefault(k, set()).add(DropReason.INCONSISTENT_KEY)

      if image_repeats.sum() == image_repeats.size:
-         image_indicies = np.arange(image_repeats.size)
+         image_indices = np.arange(image_repeats.size)
      else:
          image_ids = np.arange(image_repeats.size)
          image_data = np.concatenate(
              [np.repeat(image_ids[i], image_repeats[i]) for i in range(image_ids.size)], dtype=np.int_
          )
-         _, image_unsorted = np.unique(image_data, return_index=True)
-         image_indicies = np.sort(image_unsorted)
+         _, image_unsorted = np.unique(image_data, return_inverse=True)
+         image_indices = np.sort(image_unsorted)
+
+     output: dict[str, Any] = {}

      if keys:
-         output["keys"] = np.array(keys) if as_numpy else keys
+         output["keys"] = np.array(keys) if return_numpy else keys

      for k in (key for key in merged if key in isect):
          cv = _convert_type(merged[k])
-         output[k] = np.array(cv) if as_numpy else cv
+         output[k] = np.array(cv) if return_numpy else cv
+     output[DEFAULT_IMAGE_INDEX_KEY] = np.array(image_indices) if return_numpy else list(image_indices)

-     return output, image_indicies
+     if return_dropped:
+         return output, _sorted_drop_reasons(dropped)
+     else:
+         if dropped:
+             warnings.warn(f"Metadata keys {list(dropped)} were dropped.")
+         return output


  @dataclass(frozen=True)
  class Metadata(Output):
      """
-     Dataclass containing binned metadata from the :func:`preprocess` function
+     Dataclass containing binned metadata from the :func:`preprocess` function.

      Attributes
      ----------
      discrete_factor_names : list[str]
-         List containing factor names for the original data that was discrete and the binned continuous data
+         List containing factor names for the original data that was discrete and
+         the binned continuous data
      discrete_data : NDArray[np.int]
-         Array containing values for the original data that was discrete and the binned continuous data
+         Array containing values for the original data that was discrete and the
+         binned continuous data
      continuous_factor_names : list[str]
          List containing factor names for the original continuous data
      continuous_data : NDArray[np.int or np.double] | None
-         Array containing values for the original continuous data or None if there was no continuous data
+         Array containing values for the original continuous data or None if there
+         was no continuous data
      class_labels : NDArray[np.int]
          Numerical class labels for the images/objects
      class_names : NDArray[Any]
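Because `merge` no longer returns the image index array as a second value, callers upgrading from 0.75.x follow a different pattern; a hedged sketch of the calling conventions implied by the overloads and return logic above:

```python
from dataeval.utils.metadata import merge

list_metadata = [{"common": 1, "target": [{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4}], "source": "example"}]

# 0.75.x: merged, image_indices = merge(list_metadata)
# 0.76.x: a single dictionary comes back, with "_image_index" folded into it.
merged = merge(list_metadata)

# With return_dropped=True the dropped keys and reasons are returned instead of warned.
merged, dropped = merge(list_metadata, return_dropped=True)

# return_numpy (formerly as_numpy) converts the value lists to NumPy arrays.
merged_arrays = merge(list_metadata, return_numpy=True)
```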
@@ -327,11 +433,12 @@ class Metadata(Output):

  @set_metadata
  def preprocess(
-     raw_metadata: Iterable[Mapping[str, Any]],
+     metadata: dict[str, list[Any]] | dict[str, NDArray[Any]],
      class_labels: ArrayLike | str,
-     continuous_factor_bins: Mapping[str, int | list[tuple[TNum, TNum]]] | None = None,
+     continuous_factor_bins: Mapping[str, int | Iterable[float]] | None = None,
      auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
      exclude: Iterable[str] | None = None,
+     image_index_key: str = "_image_index",
  ) -> Metadata:
      """
      Restructures the metadata to be in the correct format for the bias functions.
@@ -343,46 +450,84 @@ def preprocess(

      Parameters
      ----------
-     raw_metadata : Iterable[Mapping[str, Any]]
-         Iterable collection of metadata dictionaries to flatten and merge.
+     metadata : dict[str, list[Any] | NDArray[Any]]
+         A flat dictionary which contains all of the metadata on a per image (classification)
+         or per object (object detection) basis. Length of lists/array should match the length
+         of the label list/array.
      class_labels : ArrayLike or string
-         If arraylike, expects the labels for each image (image classification) or each object (object detection).
-         If the labels are included in the metadata dictionary, pass in the key value.
-     continuous_factor_bins : Mapping[str, int] or Mapping[str, list[tuple[TNum, TNum]]] or None, default None
-         User provided dictionary specifying how to bin the continuous metadata factors
+         If arraylike, expects the labels for each image (image classification)
+         or each object (object detection). If the labels are included in the
+         metadata dictionary, pass in the key value.
+     continuous_factor_bins : Mapping[str, int or Iterable[float]] or None, default None
+         User provided dictionary specifying how to bin the continuous metadata
+         factors where the value is either an int to represent the number of bins,
+         or a list of floats representing the edges for each bin.
      auto_bin_method : "uniform_width" or "uniform_count" or "clusters", default "uniform_width"
-         Method by which the function will automatically bin continuous metadata factors. It is recommended
-         that the user provide the bins through the `continuous_factor_bins`.
+         Method by which the function will automatically bin continuous metadata factors.
+         It is recommended that the user provide the bins through the `continuous_factor_bins`.
      exclude : Iterable[str] or None, default None
          User provided collection of metadata keys to exclude when processing metadata.
+     image_index_key : str, default "_image_index"
+         User provided metadata key which maps the metadata entry to the source image.

      Returns
      -------
      Metadata
          Output class containing the binned metadata
+
+     See Also
+     --------
+     merge
      """
-     # Transform metadata into single, flattened dictionary
-     metadata, image_repeats = merge(raw_metadata)
+     # Check that metadata is a single, flattened dictionary with uniform array lengths
+     check_length = -1
+     for k, v in metadata.items():
+         if not isinstance(v, (list, tuple, np.ndarray)):
+             raise TypeError(
+                 "Metadata dictionary needs to be a single dictionary whose values "
+                 "are arraylike containing the metadata on a per image or per object basis."
+             )
+         else:
+             if check_length == -1:
+                 check_length = len(v)
+             else:
+                 if check_length != len(v):
+                     raise ValueError(
+                         "The lists/arrays in the metadata dict have varying lengths. "
+                         "Preprocess needs them to be uniform in length."
+                     )
+
+     # Grab continuous factors if supplied
+     continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else None

      # Drop any excluded metadata keys
-     if exclude:
-         for k in list(metadata):
-             if k in exclude:
-                 metadata.pop(k)
+     for k in exclude or ():
+         metadata.pop(k, None)
+         if continuous_factor_bins:
+             continuous_factor_bins.pop(k, None)

-     # Get the class label array in numeric form
+     # Get the class label array in numeric form and check its dimensions
      class_array = as_numpy(metadata.pop(class_labels)) if isinstance(class_labels, str) else as_numpy(class_labels)
      if class_array.ndim > 1:
          raise ValueError(
              f"Got class labels with {class_array.ndim}-dimensional "
              f"shape {class_array.shape}, but expected a 1-dimensional array."
          )
+     # Check if the label array is the same length as the metadata arrays
+     elif len(class_array) != check_length:
+         raise ValueError(
+             f"The length of the label array {len(class_array)} is not the same as "
+             f"the length of the metadata arrays {check_length}."
+         )
      if not np.issubdtype(class_array.dtype, np.int_):
          unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
      else:
          numerical_labels = class_array
          unique_classes = np.unique(class_array)

+     # Determine if _image_index is given
+     image_indices = as_numpy(metadata[image_index_key]) if image_index_key in metadata else np.arange(check_length)
+
      # Bin according to user supplied bins
      continuous_metadata = {}
      discrete_metadata = {}
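Taken together with the `merge` changes, the intended flow is now merge first, then hand the single flattened dictionary to `preprocess`. A hedged sketch of that pipeline; the metadata keys, label key, and bin count below are invented for illustration:

```python
from dataeval.utils.metadata import merge, preprocess

# Hypothetical per-image metadata, e.g. collected from a dataset loader.
raw_metadata = [
    {"class": 0, "altitude": 125.3, "sensor": "A"},
    {"class": 1, "altitude": 130.1, "sensor": "B"},
    {"class": 0, "altitude": 127.8, "sensor": "A"},
]

# Step 1: flatten and merge into one dictionary (adds the "_image_index" key).
metadata = merge(raw_metadata)

# Step 2: bin the merged dictionary; "class" names the label key and
# "altitude" is binned into two user-specified bins.
binned = preprocess(metadata, class_labels="class", continuous_factor_bins={"altitude": 2})
```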
@@ -394,8 +539,8 @@ def preprocess(
                  "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
                  "or add corresponding entries to the `metadata` dictionary."
              )
-         for factor, grouping in continuous_factor_bins.items():
-             discrete_metadata[factor] = _user_defined_bin(metadata[factor], grouping)
+         for factor, bins in continuous_factor_bins.items():
+             discrete_metadata[factor] = _digitize_data(metadata[factor], bins)
              continuous_metadata[factor] = metadata[factor]

      # Determine category of the rest of the keys
@@ -403,7 +548,7 @@ def preprocess(
      for key in remaining_keys:
          data = to_numpy(metadata[key])
          if np.issubdtype(data.dtype, np.number):
-             result = _is_continuous(data, image_repeats)
+             result = _is_continuous(data, image_indices)
              if result:
                  continuous_metadata[key] = data
              unique_samples, ordinal_data = np.unique(data, return_inverse=True)
@@ -417,11 +562,11 @@ def preprocess(
                      "bins using the continuous_factor_bins parameter.",
                      UserWarning,
                  )
-                 discrete_metadata[key] = _binning_function(data, auto_bin_method)
+                 discrete_metadata[key] = _bin_data(data, auto_bin_method)
          else:
              _, discrete_metadata[key] = np.unique(data, return_inverse=True)

-     # splitting out the dictionaries into the keys and values
+     # Split out the dictionaries into the keys and values
      discrete_factor_names = list(discrete_metadata.keys())
      discrete_data = np.stack(list(discrete_metadata.values()), axis=-1)
      continuous_factor_names = list(continuous_metadata.keys())
@@ -439,7 +584,7 @@ def preprocess(
      )


- def _user_defined_bin(data: list[Any] | NDArray[Any], binning: int | list[tuple[TNum, TNum]]) -> NDArray[np.intp]:
+ def _digitize_data(data: list[Any] | NDArray[Any], bins: int | Iterable[float]) -> NDArray[np.intp]:
      """
      Digitizes a list of values into a given number of bins.

@@ -447,8 +592,8 @@ def _user_defined_bin(data: list[Any] | NDArray[Any], binning: int | list[tuple[
      ----------
      data : list | NDArray
          The values to be digitized.
-     binning : int | list[tuple[TNum, TNum]]
-         The number of bins for the discrete values that data will be digitized into.
+     bins : int | Iterable[float]
+         The number of bins or list of bin edges for the discrete values that data will be digitized into.

      Returns
      -------
@@ -461,16 +606,16 @@ def _user_defined_bin(data: list[Any] | NDArray[Any], binning: int | list[tuple[
              "Encountered a data value with non-numeric type when digitizing a factor. "
              "Ensure all occurrences of continuous factors are numeric types."
          )
-     if type(binning) is int:
-         _, bin_edges = np.histogram(data, bins=binning)
+     if isinstance(bins, int):
+         _, bin_edges = np.histogram(data, bins=bins)
          bin_edges[-1] = np.inf
          bin_edges[0] = -np.inf
      else:
-         bin_edges = binning
+         bin_edges = list(bins)
      return np.digitize(data, bin_edges)


- def _binning_function(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
+ def _bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
      """
      Bins continuous data through either equal width bins, equal amounts in each bin, or by clusters.
      """
@@ -482,26 +627,26 @@ def _binning_function(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
          )
          bin_method = "uniform_width"

-     if bin_method != "clusters":
-         counts, bin_edges = np.histogram(data, bins="auto")
-         n_bins = counts.size
-         if counts[counts > 0].min() < 10:
-             for _ in range(20):
-                 n_bins -= 1
-                 counts, bin_edges = np.histogram(data, bins=n_bins)
-                 if counts[counts > 0].min() >= 10 or n_bins < 2:
-                     break
+     # if bin_method != "clusters":  # restore this when clusters bin_method is available
+     counts, bin_edges = np.histogram(data, bins="auto")
+     n_bins = counts.size
+     if counts[counts > 0].min() < 10:
+         counter = 20
+         while counts[counts > 0].min() < 10 and n_bins >= 2 and counter > 0:
+             counter -= 1
+             n_bins -= 1
+             counts, bin_edges = np.histogram(data, bins=n_bins)

-         if bin_method == "uniform_count":
-             quantiles = np.linspace(0, 100, n_bins + 1)
-             bin_edges = np.asarray(np.percentile(data, quantiles))
+     if bin_method == "uniform_count":
+         quantiles = np.linspace(0, 100, n_bins + 1)
+         bin_edges = np.asarray(np.percentile(data, quantiles))

      bin_edges[0] = -np.inf  # type: ignore # until the clusters speed up is merged
      bin_edges[-1] = np.inf  # type: ignore # and the _binning_by_clusters can be uncommented
      return np.digitize(data, bin_edges)  # type: ignore


- def _is_continuous(data: NDArray[np.number], image_indicies: NDArray[np.number]) -> bool:
+ def _is_continuous(data: NDArray[np.number], image_indices: NDArray[np.number]) -> bool:
      """
      Determines whether the data is continuous or discrete using the Wasserstein distance.
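The rewritten binning block above replaces the fixed `for` loop with a bounded `while`: starting from NumPy's "auto" bin estimate, the bin count is walked down until every non-empty bin holds at least 10 samples, with a lower guard on the bin count and a cap of 20 iterations. A self-contained sketch of that behavior on synthetic data (threshold values copied from the diff):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=150)

counts, bin_edges = np.histogram(data, bins="auto")
n_bins = counts.size
if counts[counts > 0].min() < 10:
    counter = 20
    while counts[counts > 0].min() < 10 and n_bins >= 2 and counter > 0:
        counter -= 1
        n_bins -= 1
        counts, bin_edges = np.histogram(data, bins=n_bins)

print(n_bins, counts[counts > 0].min())  # fewer bins, each non-empty bin better populated
```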
 
@@ -520,11 +665,11 @@ def _is_continuous(data: NDArray[np.number], image_indicies: NDArray[np.number])
      measured from a uniform distribution is greater or less than 0.054, respectively.
      """
      # Check if the metadata is image specific
-     _, data_indicies_unsorted = np.unique(data, return_index=True)
-     if data_indicies_unsorted.size == image_indicies.size:
-         data_indicies = np.sort(data_indicies_unsorted)
-         if (data_indicies == image_indicies).all():
-             data = data[data_indicies]
+     _, data_indices_unsorted = np.unique(data, return_index=True)
+     if data_indices_unsorted.size == image_indices.size:
+         data_indices = np.sort(data_indices_unsorted)
+         if (data_indices == image_indices).all():
+             data = data[data_indices]

      # OLD METHOD
      # uvals = np.unique(data)
@@ -572,7 +717,7 @@ def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArr

      Returns
      -------
-     NDArray[np.int_]
+     NDArray[np.int]
          Bin counts per column of data.
      """
      max_value = data.max() + 1 if min_num_bins is None else min_num_bins