dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +23 -14
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +51 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.1.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
@@ -1,27 +1,19 @@
 """
-Metadata related utility functions that help organize raw metadata into \
-:class:`Metadata` objects for use within `DataEval`.
+Utility functions that help organize raw metadata.
 """

 from __future__ import annotations

-__all__ = ["Metadata", "preprocess", "merge", "flatten"]
+__all__ = ["merge", "flatten"]

 import warnings
-from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Iterable, Literal, Mapping, TypeVar, overload
+from typing import Any, Iterable, Literal, Mapping, Sequence, overload

 import numpy as np
-from numpy.typing import ArrayLike, NDArray
-from scipy.stats import wasserstein_distance as wd
+from numpy.typing import NDArray

-from dataeval.interop import as_numpy, to_numpy
-from dataeval.output import Output, set_metadata
-
-DISCRETE_MIN_WD = 0.054
-CONTINUOUS_MIN_SAMPLE_SIZE = 20
-DEFAULT_IMAGE_INDEX_KEY = "_image_index"
+_TYPE_MAP = {int: 0, float: 1, str: 2}


 class DropReason(Enum):
@@ -30,26 +22,15 @@ class DropReason(Enum):
     NESTED_LIST = "nested_list"


-T = TypeVar("T")
-
-
-def _try_cast(v: Any, t: type[T]) -> T | None:
-    """Casts a value to a type or returns None if unable"""
-    try:
-        return t(v)  # type: ignore
-    except (TypeError, ValueError):
-        return None
-
-
 @overload
-def _convert_type(data: list[str]) -> list[int] | list[float] | list[str]: ...
+def _simplify_type(data: list[str]) -> list[int] | list[float] | list[str]: ...
 @overload
-def _convert_type(data: str) -> int | float | str: ...
+def _simplify_type(data: str) -> int | float | str: ...


-def _convert_type(data: list[str] | str) -> list[int] | list[float] | list[str] | int | float | str:
+def _simplify_type(data: list[str] | str) -> list[int] | list[float] | list[str] | int | float | str:
     """
-    Converts a value or a list of values to the simplest form possible,
+    Simplifies a value or a list of values to the simplest form possible,
     in preferred order of `int`, `float`, or `string`.

     Parameters
@@ -92,7 +75,7 @@ def _get_key_indices(keys: Iterable[tuple[str, ...]]) -> dict[tuple[str, ...], i
     dict[tuple[str, ...], int]
         Mapping of tuple keys to starting index
     """
-    indices = {k: -1 for k in keys}
+    indices = dict.fromkeys(keys, -1)
     ks = list(keys)
     while len(ks) > 0:
         seen: dict[tuple[str, ...], list[tuple[str, ...]]] = {}
@@ -144,6 +127,8 @@ def _flatten_dict_inner(
     items: dict[tuple[str, ...], Any] = {}
     for k, v in d.items():
         new_keys: tuple[str, ...] = parent_keys + (k,)
+        if isinstance(v, np.ndarray):
+            v = v.tolist()
         if isinstance(v, dict):
             fd, size = _flatten_dict_inner(v, dropped, new_keys, size=size, nested=nested)
             items.update(fd)
@@ -223,7 +208,7 @@ def flatten(

     output = {}
     for k, v in expanded.items():
-        cv = _convert_type(v)
+        cv = _simplify_type(v)
         if isinstance(cv, list):
             if len(cv) == size:
                 output[k] = cv
@@ -245,7 +230,8 @@ def flatten(
         return output, size, _sorted_drop_reasons(dropped)
     else:
         if dropped:
-            warnings.warn(f"Metadata keys {list(dropped)} were dropped.")
+            dropped_items = "\n".join([f" {k}: {v}" for k, v in _sorted_drop_reasons(dropped).items()])
+            warnings.warn(f"Metadata entries were dropped:\n{dropped_items}")
         return output, size


@@ -271,6 +257,8 @@ def merge(
     ignore_lists: bool = False,
     fully_qualified: bool = False,
     return_numpy: bool = False,
+    targets_per_image: Sequence[int] | None = None,
+    image_index_key: str = "_image_index",
 ) -> tuple[dict[str, list[Any]] | dict[str, NDArray[Any]], dict[str, list[str]]]: ...


@@ -281,6 +269,8 @@ def merge(
     ignore_lists: bool = False,
     fully_qualified: bool = False,
     return_numpy: bool = False,
+    targets_per_image: Sequence[int] | None = None,
+    image_index_key: str = "_image_index",
 ) -> dict[str, list[Any]] | dict[str, NDArray[Any]]: ...


@@ -290,6 +280,8 @@ def merge(
     ignore_lists: bool = False,
     fully_qualified: bool = False,
     return_numpy: bool = False,
+    targets_per_image: Sequence[int] | None = None,
+    image_index_key: str = "_image_index",
 ):
     """
     Merges a collection of metadata dictionaries into a single flattened
@@ -298,7 +290,7 @@ def merge(
     Nested dictionaries are flattened, and lists are expanded. Nested lists are
     dropped as the expanding into multiple hierarchical trees is not supported.
     The function adds an internal "_image_index" key to the metadata dictionary
-    for consumption by the preprocess function.
+    used by the `Metadata` class.

     Parameters
     ----------
@@ -312,6 +304,10 @@ def merge(
         Option to return dictionary keys full qualified instead of minimized
     return_numpy : bool, default False
         Option to return results as lists or NumPy arrays
+    targets_per_image : Sequence[int] or None, default None
+        Number of targets for each image metadata entry
+    image_index_key : str, default "_image_index"
+        User provided metadata key which maps the metadata entry to the source image.

     Returns
     -------
@@ -349,15 +345,24 @@ def merge(
     else:
         dicts = list(metadata)

+    if targets_per_image is not None and len(dicts) != len(targets_per_image):
+        raise ValueError("Number of targets per image must be equal to number of metadata entries.")
+
     image_repeats = np.zeros(len(dicts), dtype=np.int_)
     dropped: dict[str, set[DropReason]] = {}
     for i, d in enumerate(dicts):
         flattened, image_repeats[i], dropped_inner = flatten(
-            d,
-            return_dropped=True,
-            ignore_lists=ignore_lists,
-            fully_qualified=fully_qualified,
+            d, return_dropped=True, ignore_lists=ignore_lists, fully_qualified=fully_qualified
         )
+        if targets_per_image is not None:
+            # check for mismatch in targets per image and force ignore_lists
+            if not ignore_lists and targets_per_image[i] != image_repeats[i]:
+                flattened, image_repeats[i], dropped_inner = flatten(
+                    d, return_dropped=True, ignore_lists=True, fully_qualified=fully_qualified
+                )
+            if targets_per_image[i] != image_repeats[i]:
+                flattened = {k: [v] * targets_per_image[i] for k, v in flattened.items()}
+                image_repeats[i] = targets_per_image[i]
         isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
         union.update(flattened.keys())
         for k, v in dropped_inner.items():
@@ -384,345 +389,15 @@ def merge(
     output["keys"] = np.array(keys) if return_numpy else keys

     for k in (key for key in merged if key in isect):
-        cv = _convert_type(merged[k])
+        cv = _simplify_type(merged[k])
         output[k] = np.array(cv) if return_numpy else cv
-    output[DEFAULT_IMAGE_INDEX_KEY] = np.array(image_indices) if return_numpy else list(image_indices)
+    if image_index_key not in output:
+        output[image_index_key] = image_indices if return_numpy else image_indices.tolist()

     if return_dropped:
         return output, _sorted_drop_reasons(dropped)
     else:
         if dropped:
-            warnings.warn(f"Metadata keys {list(dropped)} were dropped.")
+            dropped_items = "\n".join([f" {k}: {v}" for k, v in _sorted_drop_reasons(dropped).items()])
+            warnings.warn(f"Metadata entries were dropped:\n{dropped_items}")
         return output
-
-
-@dataclass(frozen=True)
-class Metadata(Output):
-    """
-    Dataclass containing binned metadata from the :func:`preprocess` function.
-
-    Attributes
-    ----------
-    discrete_factor_names : list[str]
-        List containing factor names for the original data that was discrete and
-        the binned continuous data
-    discrete_data : NDArray[np.int]
-        Array containing values for the original data that was discrete and the
-        binned continuous data
-    continuous_factor_names : list[str]
-        List containing factor names for the original continuous data
-    continuous_data : NDArray[np.int or np.double] | None
-        Array containing values for the original continuous data or None if there
-        was no continuous data
-    class_labels : NDArray[np.int]
-        Numerical class labels for the images/objects
-    class_names : NDArray[Any]
-        Array of unique class names (for use with plotting)
-    total_num_factors : int
-        Sum of discrete_factor_names and continuous_factor_names plus 1 for class
-    """
-
-    discrete_factor_names: list[str]
-    discrete_data: NDArray[np.int_]
-    continuous_factor_names: list[str]
-    continuous_data: NDArray[np.int_ | np.double] | None
-    class_labels: NDArray[np.int_]
-    class_names: NDArray[Any]
-    total_num_factors: int
-
-
-@set_metadata
-def preprocess(
-    metadata: dict[str, list[Any]] | dict[str, NDArray[Any]],
-    class_labels: ArrayLike | str,
-    continuous_factor_bins: Mapping[str, int | Iterable[float]] | None = None,
-    auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
-    exclude: Iterable[str] | None = None,
-    image_index_key: str = "_image_index",
-) -> Metadata:
-    """
-    Restructures the metadata to be in the correct format for the bias functions.
-
-    This identifies whether the incoming metadata is discrete or continuous,
-    and whether the data is already binned or still needs binning.
-    It accepts a list of dictionaries containing the per image metadata and
-    automatically adjusts for multiple targets in an image.
-
-    Parameters
-    ----------
-    metadata : dict[str, list[Any] | NDArray[Any]]
-        A flat dictionary which contains all of the metadata on a per image (classification)
-        or per object (object detection) basis. Length of lists/array should match the length
-        of the label list/array.
-    class_labels : ArrayLike or string
-        If arraylike, expects the labels for each image (image classification)
-        or each object (object detection). If the labels are included in the
-        metadata dictionary, pass in the key value.
-    continuous_factor_bins : Mapping[str, int or Iterable[float]] or None, default None
-        User provided dictionary specifying how to bin the continuous metadata
-        factors where the value is either an int to represent the number of bins,
-        or a list of floats representing the edges for each bin.
-    auto_bin_method : "uniform_width" or "uniform_count" or "clusters", default "uniform_width"
-        Method by which the function will automatically bin continuous metadata factors.
-        It is recommended that the user provide the bins through the `continuous_factor_bins`.
-    exclude : Iterable[str] or None, default None
-        User provided collection of metadata keys to exclude when processing metadata.
-    image_index_key : str, default "_image_index"
-        User provided metadata key which maps the metadata entry to the source image.
-
-    Returns
-    -------
-    Metadata
-        Output class containing the binned metadata
-
-    See Also
-    --------
-    merge
-    """
-    # Check that metadata is a single, flattened dictionary with uniform array lengths
-    check_length = -1
-    for k, v in metadata.items():
-        if not isinstance(v, (list, tuple, np.ndarray)):
-            raise TypeError(
-                "Metadata dictionary needs to be a single dictionary whose values "
-                "are arraylike containing the metadata on a per image or per object basis."
-            )
-        else:
-            if check_length == -1:
-                check_length = len(v)
-            else:
-                if check_length != len(v):
-                    raise ValueError(
-                        "The lists/arrays in the metadata dict have varying lengths. "
-                        "Preprocess needs them to be uniform in length."
-                    )
-
-    # Grab continuous factors if supplied
-    continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else None
-
-    # Drop any excluded metadata keys
-    for k in exclude or ():
-        metadata.pop(k, None)
-        if continuous_factor_bins:
-            continuous_factor_bins.pop(k, None)
-
-    # Get the class label array in numeric form and check its dimensions
-    class_array = as_numpy(metadata.pop(class_labels)) if isinstance(class_labels, str) else as_numpy(class_labels)
-    if class_array.ndim > 1:
-        raise ValueError(
-            f"Got class labels with {class_array.ndim}-dimensional "
-            f"shape {class_array.shape}, but expected a 1-dimensional array."
-        )
-    # Check if the label array is the same length as the metadata arrays
-    elif len(class_array) != check_length:
-        raise ValueError(
-            f"The length of the label array {len(class_array)} is not the same as "
-            f"the length of the metadata arrays {check_length}."
-        )
-    if not np.issubdtype(class_array.dtype, np.int_):
-        unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
-    else:
-        numerical_labels = class_array
-        unique_classes = np.unique(class_array)
-
-    # Determine if _image_index is given
-    image_indices = as_numpy(metadata[image_index_key]) if image_index_key in metadata else np.arange(check_length)
-
-    # Bin according to user supplied bins
-    continuous_metadata = {}
-    discrete_metadata = {}
-    if continuous_factor_bins is not None and continuous_factor_bins != {}:
-        invalid_keys = set(continuous_factor_bins.keys()) - set(metadata.keys())
-        if invalid_keys:
-            raise KeyError(
-                f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
-                "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
-                "or add corresponding entries to the `metadata` dictionary."
-            )
-        for factor, bins in continuous_factor_bins.items():
-            discrete_metadata[factor] = _digitize_data(metadata[factor], bins)
-            continuous_metadata[factor] = metadata[factor]
-
-    # Determine category of the rest of the keys
-    remaining_keys = set(metadata.keys()) - set(continuous_metadata.keys())
-    for key in remaining_keys:
-        data = to_numpy(metadata[key])
-        if np.issubdtype(data.dtype, np.number):
-            result = _is_continuous(data, image_indices)
-            if result:
-                continuous_metadata[key] = data
-            unique_samples, ordinal_data = np.unique(data, return_inverse=True)
-            if unique_samples.size <= np.max([20, data.size * 0.01]):
-                discrete_metadata[key] = ordinal_data
-            else:
-                warnings.warn(
-                    f"A user defined binning was not provided for {key}. "
-                    f"Using the {auto_bin_method} method to discretize the data. "
-                    "It is recommended that the user rerun and supply the desired "
-                    "bins using the continuous_factor_bins parameter.",
-                    UserWarning,
-                )
-                discrete_metadata[key] = _bin_data(data, auto_bin_method)
-        else:
-            _, discrete_metadata[key] = np.unique(data, return_inverse=True)
-
-    # Split out the dictionaries into the keys and values
-    discrete_factor_names = list(discrete_metadata.keys())
-    discrete_data = np.stack(list(discrete_metadata.values()), axis=-1)
-    continuous_factor_names = list(continuous_metadata.keys())
-    continuous_data = np.stack(list(continuous_metadata.values()), axis=-1) if continuous_metadata else None
-    total_num_factors = len(discrete_factor_names + continuous_factor_names) + 1
-
-    return Metadata(
-        discrete_factor_names,
-        discrete_data,
-        continuous_factor_names,
-        continuous_data,
-        numerical_labels,
-        unique_classes,
-        total_num_factors,
-    )
-
-
-def _digitize_data(data: list[Any] | NDArray[Any], bins: int | Iterable[float]) -> NDArray[np.intp]:
-    """
-    Digitizes a list of values into a given number of bins.
-
-    Parameters
-    ----------
-    data : list | NDArray
-        The values to be digitized.
-    bins : int | Iterable[float]
-        The number of bins or list of bin edges for the discrete values that data will be digitized into.
-
-    Returns
-    -------
-    NDArray[np.intp]
-        The digitized values
-    """
-
-    if not np.all([np.issubdtype(type(n), np.number) for n in data]):
-        raise TypeError(
-            "Encountered a data value with non-numeric type when digitizing a factor. "
-            "Ensure all occurrences of continuous factors are numeric types."
-        )
-    if isinstance(bins, int):
-        _, bin_edges = np.histogram(data, bins=bins)
-        bin_edges[-1] = np.inf
-        bin_edges[0] = -np.inf
-    else:
-        bin_edges = list(bins)
-    return np.digitize(data, bin_edges)
-
-
-def _bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
-    """
-    Bins continuous data through either equal width bins, equal amounts in each bin, or by clusters.
-    """
-    if bin_method == "clusters":
-        # bin_edges = _binning_by_clusters(data)
-        warnings.warn(
-            "Binning by clusters is currently unavailable until changes to the clustering function go through.",
-            UserWarning,
-        )
-        bin_method = "uniform_width"
-
-    # if bin_method != "clusters":  # restore this when clusters bin_method is available
-    counts, bin_edges = np.histogram(data, bins="auto")
-    n_bins = counts.size
-    if counts[counts > 0].min() < 10:
-        counter = 20
-        while counts[counts > 0].min() < 10 and n_bins >= 2 and counter > 0:
-            counter -= 1
-            n_bins -= 1
-            counts, bin_edges = np.histogram(data, bins=n_bins)
-
-    if bin_method == "uniform_count":
-        quantiles = np.linspace(0, 100, n_bins + 1)
-        bin_edges = np.asarray(np.percentile(data, quantiles))
-
-    bin_edges[0] = -np.inf  # type: ignore # until the clusters speed up is merged
-    bin_edges[-1] = np.inf  # type: ignore # and the _binning_by_clusters can be uncommented
-    return np.digitize(data, bin_edges)  # type: ignore
-
-
-def _is_continuous(data: NDArray[np.number], image_indices: NDArray[np.number]) -> bool:
-    """
-    Determines whether the data is continuous or discrete using the Wasserstein distance.
-
-    Given a 1D sample, we consider the intervals between adjacent points. For a continuous distribution,
-    a point is equally likely to lie anywhere in the interval bounded by its two neighbors. Furthermore,
-    we can put all "between neighbor" locations on the same scale of 0 to 1 by subtracting the smaller
-    neighbor and dividing out the length of the interval. (Duplicates are either assigned to zero or
-    ignored, depending on context). These normalized locations will be much more uniformly distributed
-    for continuous data than for discrete, and this gives us a way to distinguish them. Call this the
-    Normalized Near Neighbor distribution (NNN), defined on the interval [0,1].
-
-    The Wasserstein distance is available in scipy.stats.wasserstein_distance. We can use it to measure
-    how close the NNN is to a uniform distribution over [0,1]. We found that as long as a sample has at
-    least 20 points, and furthermore at least half as many points as there are discrete values, we can
-    reliably distinguish discrete from continuous samples by testing that the Wasserstein distance
-    measured from a uniform distribution is greater or less than 0.054, respectively.
-    """
-    # Check if the metadata is image specific
-    _, data_indices_unsorted = np.unique(data, return_index=True)
-    if data_indices_unsorted.size == image_indices.size:
-        data_indices = np.sort(data_indices_unsorted)
-        if (data_indices == image_indices).all():
-            data = data[data_indices]
-
-    # OLD METHOD
-    # uvals = np.unique(data)
-    # pct_unique = uvals.size / data.size
-    # return pct_unique < threshold
-
-    n_examples = len(data)
-
-    if n_examples < CONTINUOUS_MIN_SAMPLE_SIZE:
-        warnings.warn(
-            f"All samples look discrete with so few data points (< {CONTINUOUS_MIN_SAMPLE_SIZE})", UserWarning
-        )
-        return False
-
-    # Require at least 3 unique values before bothering with NNN
-    xu = np.unique(data, axis=None)
-    if xu.size < 3:
-        return False
-
-    Xs = np.sort(data)
-
-    X0, X1 = Xs[0:-2], Xs[2:]  # left and right neighbors
-
-    dx = np.zeros(n_examples - 2)  # no dx at end points
-    gtz = (X1 - X0) > 0  # check for dups; dx will be zero for them
-    dx[np.logical_not(gtz)] = 0.0
-
-    dx[gtz] = (Xs[1:-1] - X0)[gtz] / (X1 - X0)[gtz]  # the core idea: dx is NNN samples.
-
-    shift = wd(dx, np.linspace(0, 1, dx.size))  # how far is dx from uniform, for this feature?
-
-    return shift < DISCRETE_MIN_WD  # if NNN is close enough to uniform, consider the sample continuous.
-
-
-def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
-    """
-    Returns columnwise unique counts for discrete data.
-
-    Parameters
-    ----------
-    data : NDArray
-        Array containing integer values for metadata factors
-    min_num_bins : int | None, default None
-        Minimum number of bins for bincount, helps force consistency across runs
-
-    Returns
-    -------
-    NDArray[np.int]
-        Bin counts per column of data.
-    """
-    max_value = data.max() + 1 if min_num_bins is None else min_num_bins
-    cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
-    for idx in range(data.shape[1]):
-        cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
-
-    return cnt_array
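
Illustrative note (not part of the diff): the hunks above appear to come from dataeval/utils/metadata.py per the file list, and they reduce its public surface to merge and flatten while adding targets_per_image and image_index_key to merge. The sketch below exercises that post-change API under those assumptions; the sample metadata values, printed output, and return shapes are inferences from the hunks and docstrings, not package documentation.

# Hedged usage sketch of the 0.81.0 merge/flatten API as shown in the hunks above.
from dataeval.utils.metadata import flatten, merge

per_image_metadata = [
    {"scene": {"weather": "clear"}, "detections": {"score": [0.9, 0.7]}},
    {"scene": {"weather": "rain"}, "detections": {"score": [0.8]}},
]

# flatten() collapses one nested dictionary; per the hunks it returns
# (output, size) unless return_dropped=True, and nested lists are dropped
# with a DropReason.
flat, size = flatten(per_image_metadata[0])

# merge() flattens every entry, expands lists, and (new in this version) can
# replicate per-image values to a known target count via targets_per_image.
merged, dropped = merge(
    per_image_metadata,
    return_dropped=True,
    targets_per_image=[2, 1],        # parameter added in this diff
    image_index_key="_image_index",  # default shown in the diff
)
print(sorted(merged))  # flattened factor names plus "keys" and the image index key
print(dropped)         # mapping of dropped keys to their drop reasons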
@@ -10,7 +10,8 @@ from __future__ import annotations

 __all__ = []

-from typing import NamedTuple, TypeVar
+from dataclasses import dataclass
+from typing import TypeVar

 import numpy as np
 import torch
@@ -18,7 +19,8 @@ import torch
 TGMMData = TypeVar("TGMMData")


-class GaussianMixtureModelParams(NamedTuple):
+@dataclass
+class GaussianMixtureModelParams:
     """
     phi : torch.Tensor
         Mixture component distribution weights.
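
Illustrative note (not part of the diff): the last two hunks, which per the file list belong to dataeval/utils/torch/{gmm.py → _gmm.py}, change GaussianMixtureModelParams from a NamedTuple to a dataclass. The sketch below shows the practical difference using hypothetical stand-in classes; only the phi field is visible in this diff, so everything else is an assumption.

# Sketch of the NamedTuple-vs-dataclass behavior, not the package's actual class body.
from dataclasses import dataclass
from typing import NamedTuple

import torch


class ParamsAsNamedTuple(NamedTuple):  # pre-0.81.0 style
    phi: torch.Tensor                  # mixture component distribution weights


@dataclass
class ParamsAsDataclass:               # post-0.81.0 style
    phi: torch.Tensor                  # same field, now a mutable dataclass attribute


nt = ParamsAsNamedTuple(torch.ones(3))
dc = ParamsAsDataclass(torch.ones(3))

(phi,) = nt          # NamedTuple instances unpack like tuples
dc.phi = dc.phi / 3  # dataclass instances allow in-place field reassignment
# (phi,) = dc        # would raise TypeError: a plain dataclass is not iterable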