dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/{output.py → _output.py} +14 -0
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +41 -30
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +0 -3
- dataeval/detectors/linters/duplicates.py +17 -8
- dataeval/detectors/linters/outliers.py +23 -14
- dataeval/detectors/ood/ae.py +29 -8
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/metadata_ks_compare.py +1 -1
- dataeval/detectors/ood/mixin.py +20 -5
- dataeval/detectors/ood/output.py +1 -1
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +5 -0
- dataeval/metadata/_ood.py +238 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
- dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
- dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
- dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
- dataeval/metrics/estimators/__init__.py +14 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
- dataeval/metrics/estimators/_clusterer.py +104 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/metrics/stats/{base.py → _base.py} +52 -16
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
- dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
- dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
- dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
- dataeval/typing.py +54 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +18 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +4 -4
- dataeval/utils/data/__init__.py +22 -0
- dataeval/utils/data/_embeddings.py +105 -0
- dataeval/utils/data/_images.py +65 -0
- dataeval/utils/data/_metadata.py +352 -0
- dataeval/utils/data/_selection.py +119 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
- dataeval/utils/data/_targets.py +73 -0
- dataeval/utils/data/_types.py +58 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +60 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/sufficiency.py +10 -9
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
- dataeval-0.81.0.dist-info/RECORD +94 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
dataeval/utils/metadata.py
CHANGED
@@ -1,27 +1,19 @@
|
|
1
1
|
"""
|
2
|
-
|
3
|
-
:class:`Metadata` objects for use within `DataEval`.
|
2
|
+
Utility functions that help organize raw metadata.
|
4
3
|
"""
|
5
4
|
|
6
5
|
from __future__ import annotations
|
7
6
|
|
8
|
-
__all__ = ["
|
7
|
+
__all__ = ["merge", "flatten"]
|
9
8
|
|
10
9
|
import warnings
|
11
|
-
from dataclasses import dataclass
|
12
10
|
from enum import Enum
|
13
|
-
from typing import Any, Iterable, Literal, Mapping,
|
11
|
+
from typing import Any, Iterable, Literal, Mapping, Sequence, overload
|
14
12
|
|
15
13
|
import numpy as np
|
16
|
-
from numpy.typing import
|
17
|
-
from scipy.stats import wasserstein_distance as wd
|
14
|
+
from numpy.typing import NDArray
|
18
15
|
|
19
|
-
|
20
|
-
from dataeval.output import Output, set_metadata
|
21
|
-
|
22
|
-
DISCRETE_MIN_WD = 0.054
|
23
|
-
CONTINUOUS_MIN_SAMPLE_SIZE = 20
|
24
|
-
DEFAULT_IMAGE_INDEX_KEY = "_image_index"
|
16
|
+
_TYPE_MAP = {int: 0, float: 1, str: 2}
|
25
17
|
|
26
18
|
|
27
19
|
class DropReason(Enum):
|
@@ -30,26 +22,15 @@ class DropReason(Enum):
|
|
30
22
|
NESTED_LIST = "nested_list"
|
31
23
|
|
32
24
|
|
33
|
-
T = TypeVar("T")
|
34
|
-
|
35
|
-
|
36
|
-
def _try_cast(v: Any, t: type[T]) -> T | None:
|
37
|
-
"""Casts a value to a type or returns None if unable"""
|
38
|
-
try:
|
39
|
-
return t(v) # type: ignore
|
40
|
-
except (TypeError, ValueError):
|
41
|
-
return None
|
42
|
-
|
43
|
-
|
44
25
|
@overload
|
45
|
-
def
|
26
|
+
def _simplify_type(data: list[str]) -> list[int] | list[float] | list[str]: ...
|
46
27
|
@overload
|
47
|
-
def
|
28
|
+
def _simplify_type(data: str) -> int | float | str: ...
|
48
29
|
|
49
30
|
|
50
|
-
def
|
31
|
+
def _simplify_type(data: list[str] | str) -> list[int] | list[float] | list[str] | int | float | str:
|
51
32
|
"""
|
52
|
-
|
33
|
+
Simplifies a value or a list of values to the simplest form possible,
|
53
34
|
in preferred order of `int`, `float`, or `string`.
|
54
35
|
|
55
36
|
Parameters
|
@@ -63,18 +44,20 @@ def _convert_type(data: list[str] | str) -> list[int] | list[float] | list[str]
|
|
63
44
|
The same values converted to the numerical type if possible
|
64
45
|
"""
|
65
46
|
if not isinstance(data, list):
|
66
|
-
|
47
|
+
try:
|
48
|
+
value = float(data)
|
49
|
+
except (TypeError, ValueError):
|
50
|
+
value = None
|
67
51
|
return str(data) if value is None else int(value) if value.is_integer() else value
|
68
52
|
|
69
53
|
converted = []
|
70
|
-
TYPE_MAP = {int: 0, float: 1, str: 2}
|
71
54
|
max_type = 0
|
72
55
|
for value in data:
|
73
|
-
value =
|
74
|
-
max_type = max(max_type,
|
56
|
+
value = _simplify_type(value)
|
57
|
+
max_type = max(max_type, _TYPE_MAP.get(type(value), 2))
|
75
58
|
converted.append(value)
|
76
59
|
for i in range(len(converted)):
|
77
|
-
converted[i] = list(
|
60
|
+
converted[i] = list(_TYPE_MAP)[max_type](converted[i])
|
78
61
|
return converted
|
79
62
|
|
80
63
|
|
@@ -92,7 +75,7 @@ def _get_key_indices(keys: Iterable[tuple[str, ...]]) -> dict[tuple[str, ...], i
|
|
92
75
|
dict[tuple[str, ...], int]
|
93
76
|
Mapping of tuple keys to starting index
|
94
77
|
"""
|
95
|
-
indices =
|
78
|
+
indices = dict.fromkeys(keys, -1)
|
96
79
|
ks = list(keys)
|
97
80
|
while len(ks) > 0:
|
98
81
|
seen: dict[tuple[str, ...], list[tuple[str, ...]]] = {}
|
@@ -144,6 +127,8 @@ def _flatten_dict_inner(
|
|
144
127
|
items: dict[tuple[str, ...], Any] = {}
|
145
128
|
for k, v in d.items():
|
146
129
|
new_keys: tuple[str, ...] = parent_keys + (k,)
|
130
|
+
if isinstance(v, np.ndarray):
|
131
|
+
v = v.tolist()
|
147
132
|
if isinstance(v, dict):
|
148
133
|
fd, size = _flatten_dict_inner(v, dropped, new_keys, size=size, nested=nested)
|
149
134
|
items.update(fd)
|
@@ -223,7 +208,7 @@ def flatten(
|
|
223
208
|
|
224
209
|
output = {}
|
225
210
|
for k, v in expanded.items():
|
226
|
-
cv =
|
211
|
+
cv = _simplify_type(v)
|
227
212
|
if isinstance(cv, list):
|
228
213
|
if len(cv) == size:
|
229
214
|
output[k] = cv
|
@@ -245,7 +230,8 @@ def flatten(
|
|
245
230
|
return output, size, _sorted_drop_reasons(dropped)
|
246
231
|
else:
|
247
232
|
if dropped:
|
248
|
-
|
233
|
+
dropped_items = "\n".join([f" {k}: {v}" for k, v in _sorted_drop_reasons(dropped).items()])
|
234
|
+
warnings.warn(f"Metadata entries were dropped:\n{dropped_items}")
|
249
235
|
return output, size
|
250
236
|
|
251
237
|
|
@@ -271,6 +257,8 @@ def merge(
|
|
271
257
|
ignore_lists: bool = False,
|
272
258
|
fully_qualified: bool = False,
|
273
259
|
return_numpy: bool = False,
|
260
|
+
targets_per_image: Sequence[int] | None = None,
|
261
|
+
image_index_key: str = "_image_index",
|
274
262
|
) -> tuple[dict[str, list[Any]] | dict[str, NDArray[Any]], dict[str, list[str]]]: ...
|
275
263
|
|
276
264
|
|
@@ -281,6 +269,8 @@ def merge(
|
|
281
269
|
ignore_lists: bool = False,
|
282
270
|
fully_qualified: bool = False,
|
283
271
|
return_numpy: bool = False,
|
272
|
+
targets_per_image: Sequence[int] | None = None,
|
273
|
+
image_index_key: str = "_image_index",
|
284
274
|
) -> dict[str, list[Any]] | dict[str, NDArray[Any]]: ...
|
285
275
|
|
286
276
|
|
@@ -290,6 +280,8 @@ def merge(
|
|
290
280
|
ignore_lists: bool = False,
|
291
281
|
fully_qualified: bool = False,
|
292
282
|
return_numpy: bool = False,
|
283
|
+
targets_per_image: Sequence[int] | None = None,
|
284
|
+
image_index_key: str = "_image_index",
|
293
285
|
):
|
294
286
|
"""
|
295
287
|
Merges a collection of metadata dictionaries into a single flattened
|
@@ -298,7 +290,7 @@ def merge(
|
|
298
290
|
Nested dictionaries are flattened, and lists are expanded. Nested lists are
|
299
291
|
dropped as the expanding into multiple hierarchical trees is not supported.
|
300
292
|
The function adds an internal "_image_index" key to the metadata dictionary
|
301
|
-
|
293
|
+
used by the `Metadata` class.
|
302
294
|
|
303
295
|
Parameters
|
304
296
|
----------
|
@@ -312,6 +304,10 @@ def merge(
|
|
312
304
|
Option to return dictionary keys full qualified instead of minimized
|
313
305
|
return_numpy : bool, default False
|
314
306
|
Option to return results as lists or NumPy arrays
|
307
|
+
targets_per_image : Sequence[int] or None, default None
|
308
|
+
Number of targets for each image metadata entry
|
309
|
+
image_index_key : str, default "_image_index"
|
310
|
+
User provided metadata key which maps the metadata entry to the source image.
|
315
311
|
|
316
312
|
Returns
|
317
313
|
-------
|
@@ -349,15 +345,24 @@ def merge(
|
|
349
345
|
else:
|
350
346
|
dicts = list(metadata)
|
351
347
|
|
348
|
+
if targets_per_image is not None and len(dicts) != len(targets_per_image):
|
349
|
+
raise ValueError("Number of targets per image must be equal to number of metadata entries.")
|
350
|
+
|
352
351
|
image_repeats = np.zeros(len(dicts), dtype=np.int_)
|
353
352
|
dropped: dict[str, set[DropReason]] = {}
|
354
353
|
for i, d in enumerate(dicts):
|
355
354
|
flattened, image_repeats[i], dropped_inner = flatten(
|
356
|
-
d,
|
357
|
-
return_dropped=True,
|
358
|
-
ignore_lists=ignore_lists,
|
359
|
-
fully_qualified=fully_qualified,
|
355
|
+
d, return_dropped=True, ignore_lists=ignore_lists, fully_qualified=fully_qualified
|
360
356
|
)
|
357
|
+
if targets_per_image is not None:
|
358
|
+
# check for mismatch in targets per image and force ignore_lists
|
359
|
+
if not ignore_lists and targets_per_image[i] != image_repeats[i]:
|
360
|
+
flattened, image_repeats[i], dropped_inner = flatten(
|
361
|
+
d, return_dropped=True, ignore_lists=True, fully_qualified=fully_qualified
|
362
|
+
)
|
363
|
+
if targets_per_image[i] != image_repeats[i]:
|
364
|
+
flattened = {k: [v] * targets_per_image[i] for k, v in flattened.items()}
|
365
|
+
image_repeats[i] = targets_per_image[i]
|
361
366
|
isect = isect.intersection(flattened.keys()) if isect else set(flattened.keys())
|
362
367
|
union.update(flattened.keys())
|
363
368
|
for k, v in dropped_inner.items():
|
@@ -384,345 +389,15 @@ def merge(
|
|
384
389
|
output["keys"] = np.array(keys) if return_numpy else keys
|
385
390
|
|
386
391
|
for k in (key for key in merged if key in isect):
|
387
|
-
cv =
|
392
|
+
cv = _simplify_type(merged[k])
|
388
393
|
output[k] = np.array(cv) if return_numpy else cv
|
389
|
-
|
394
|
+
if image_index_key not in output:
|
395
|
+
output[image_index_key] = image_indices if return_numpy else image_indices.tolist()
|
390
396
|
|
391
397
|
if return_dropped:
|
392
398
|
return output, _sorted_drop_reasons(dropped)
|
393
399
|
else:
|
394
400
|
if dropped:
|
395
|
-
|
401
|
+
dropped_items = "\n".join([f" {k}: {v}" for k, v in _sorted_drop_reasons(dropped).items()])
|
402
|
+
warnings.warn(f"Metadata entries were dropped:\n{dropped_items}")
|
396
403
|
return output
|
397
|
-
|
398
|
-
|
399
|
-
@dataclass(frozen=True)
|
400
|
-
class Metadata(Output):
|
401
|
-
"""
|
402
|
-
Dataclass containing binned metadata from the :func:`preprocess` function.
|
403
|
-
|
404
|
-
Attributes
|
405
|
-
----------
|
406
|
-
discrete_factor_names : list[str]
|
407
|
-
List containing factor names for the original data that was discrete and
|
408
|
-
the binned continuous data
|
409
|
-
discrete_data : NDArray[np.int]
|
410
|
-
Array containing values for the original data that was discrete and the
|
411
|
-
binned continuous data
|
412
|
-
continuous_factor_names : list[str]
|
413
|
-
List containing factor names for the original continuous data
|
414
|
-
continuous_data : NDArray[np.int or np.double] | None
|
415
|
-
Array containing values for the original continuous data or None if there
|
416
|
-
was no continuous data
|
417
|
-
class_labels : NDArray[np.int]
|
418
|
-
Numerical class labels for the images/objects
|
419
|
-
class_names : NDArray[Any]
|
420
|
-
Array of unique class names (for use with plotting)
|
421
|
-
total_num_factors : int
|
422
|
-
Sum of discrete_factor_names and continuous_factor_names plus 1 for class
|
423
|
-
"""
|
424
|
-
|
425
|
-
discrete_factor_names: list[str]
|
426
|
-
discrete_data: NDArray[np.int_]
|
427
|
-
continuous_factor_names: list[str]
|
428
|
-
continuous_data: NDArray[np.int_ | np.double] | None
|
429
|
-
class_labels: NDArray[np.int_]
|
430
|
-
class_names: NDArray[Any]
|
431
|
-
total_num_factors: int
|
432
|
-
|
433
|
-
|
434
|
-
@set_metadata
|
435
|
-
def preprocess(
|
436
|
-
metadata: dict[str, list[Any]] | dict[str, NDArray[Any]],
|
437
|
-
class_labels: ArrayLike | str,
|
438
|
-
continuous_factor_bins: Mapping[str, int | Iterable[float]] | None = None,
|
439
|
-
auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
|
440
|
-
exclude: Iterable[str] | None = None,
|
441
|
-
image_index_key: str = "_image_index",
|
442
|
-
) -> Metadata:
|
443
|
-
"""
|
444
|
-
Restructures the metadata to be in the correct format for the bias functions.
|
445
|
-
|
446
|
-
This identifies whether the incoming metadata is discrete or continuous,
|
447
|
-
and whether the data is already binned or still needs binning.
|
448
|
-
It accepts a list of dictionaries containing the per image metadata and
|
449
|
-
automatically adjusts for multiple targets in an image.
|
450
|
-
|
451
|
-
Parameters
|
452
|
-
----------
|
453
|
-
metadata : dict[str, list[Any] | NDArray[Any]]
|
454
|
-
A flat dictionary which contains all of the metadata on a per image (classification)
|
455
|
-
or per object (object detection) basis. Length of lists/array should match the length
|
456
|
-
of the label list/array.
|
457
|
-
class_labels : ArrayLike or string
|
458
|
-
If arraylike, expects the labels for each image (image classification)
|
459
|
-
or each object (object detection). If the labels are included in the
|
460
|
-
metadata dictionary, pass in the key value.
|
461
|
-
continuous_factor_bins : Mapping[str, int or Iterable[float]] or None, default None
|
462
|
-
User provided dictionary specifying how to bin the continuous metadata
|
463
|
-
factors where the value is either an int to represent the number of bins,
|
464
|
-
or a list of floats representing the edges for each bin.
|
465
|
-
auto_bin_method : "uniform_width" or "uniform_count" or "clusters", default "uniform_width"
|
466
|
-
Method by which the function will automatically bin continuous metadata factors.
|
467
|
-
It is recommended that the user provide the bins through the `continuous_factor_bins`.
|
468
|
-
exclude : Iterable[str] or None, default None
|
469
|
-
User provided collection of metadata keys to exclude when processing metadata.
|
470
|
-
image_index_key : str, default "_image_index"
|
471
|
-
User provided metadata key which maps the metadata entry to the source image.
|
472
|
-
|
473
|
-
Returns
|
474
|
-
-------
|
475
|
-
Metadata
|
476
|
-
Output class containing the binned metadata
|
477
|
-
|
478
|
-
See Also
|
479
|
-
--------
|
480
|
-
merge
|
481
|
-
"""
|
482
|
-
# Check that metadata is a single, flattened dictionary with uniform array lengths
|
483
|
-
check_length = -1
|
484
|
-
for k, v in metadata.items():
|
485
|
-
if not isinstance(v, (list, tuple, np.ndarray)):
|
486
|
-
raise TypeError(
|
487
|
-
"Metadata dictionary needs to be a single dictionary whose values "
|
488
|
-
"are arraylike containing the metadata on a per image or per object basis."
|
489
|
-
)
|
490
|
-
else:
|
491
|
-
if check_length == -1:
|
492
|
-
check_length = len(v)
|
493
|
-
else:
|
494
|
-
if check_length != len(v):
|
495
|
-
raise ValueError(
|
496
|
-
"The lists/arrays in the metadata dict have varying lengths. "
|
497
|
-
"Preprocess needs them to be uniform in length."
|
498
|
-
)
|
499
|
-
|
500
|
-
# Grab continuous factors if supplied
|
501
|
-
continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else None
|
502
|
-
|
503
|
-
# Drop any excluded metadata keys
|
504
|
-
for k in exclude or ():
|
505
|
-
metadata.pop(k, None)
|
506
|
-
if continuous_factor_bins:
|
507
|
-
continuous_factor_bins.pop(k, None)
|
508
|
-
|
509
|
-
# Get the class label array in numeric form and check its dimensions
|
510
|
-
class_array = as_numpy(metadata.pop(class_labels)) if isinstance(class_labels, str) else as_numpy(class_labels)
|
511
|
-
if class_array.ndim > 1:
|
512
|
-
raise ValueError(
|
513
|
-
f"Got class labels with {class_array.ndim}-dimensional "
|
514
|
-
f"shape {class_array.shape}, but expected a 1-dimensional array."
|
515
|
-
)
|
516
|
-
# Check if the label array is the same length as the metadata arrays
|
517
|
-
elif len(class_array) != check_length:
|
518
|
-
raise ValueError(
|
519
|
-
f"The length of the label array {len(class_array)} is not the same as "
|
520
|
-
f"the length of the metadata arrays {check_length}."
|
521
|
-
)
|
522
|
-
if not np.issubdtype(class_array.dtype, np.int_):
|
523
|
-
unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
|
524
|
-
else:
|
525
|
-
numerical_labels = class_array
|
526
|
-
unique_classes = np.unique(class_array)
|
527
|
-
|
528
|
-
# Determine if _image_index is given
|
529
|
-
image_indices = as_numpy(metadata[image_index_key]) if image_index_key in metadata else np.arange(check_length)
|
530
|
-
|
531
|
-
# Bin according to user supplied bins
|
532
|
-
continuous_metadata = {}
|
533
|
-
discrete_metadata = {}
|
534
|
-
if continuous_factor_bins is not None and continuous_factor_bins != {}:
|
535
|
-
invalid_keys = set(continuous_factor_bins.keys()) - set(metadata.keys())
|
536
|
-
if invalid_keys:
|
537
|
-
raise KeyError(
|
538
|
-
f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
|
539
|
-
"but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
|
540
|
-
"or add corresponding entries to the `metadata` dictionary."
|
541
|
-
)
|
542
|
-
for factor, bins in continuous_factor_bins.items():
|
543
|
-
discrete_metadata[factor] = _digitize_data(metadata[factor], bins)
|
544
|
-
continuous_metadata[factor] = metadata[factor]
|
545
|
-
|
546
|
-
# Determine category of the rest of the keys
|
547
|
-
remaining_keys = set(metadata.keys()) - set(continuous_metadata.keys())
|
548
|
-
for key in remaining_keys:
|
549
|
-
data = to_numpy(metadata[key])
|
550
|
-
if np.issubdtype(data.dtype, np.number):
|
551
|
-
result = _is_continuous(data, image_indices)
|
552
|
-
if result:
|
553
|
-
continuous_metadata[key] = data
|
554
|
-
unique_samples, ordinal_data = np.unique(data, return_inverse=True)
|
555
|
-
if unique_samples.size <= np.max([20, data.size * 0.01]):
|
556
|
-
discrete_metadata[key] = ordinal_data
|
557
|
-
else:
|
558
|
-
warnings.warn(
|
559
|
-
f"A user defined binning was not provided for {key}. "
|
560
|
-
f"Using the {auto_bin_method} method to discretize the data. "
|
561
|
-
"It is recommended that the user rerun and supply the desired "
|
562
|
-
"bins using the continuous_factor_bins parameter.",
|
563
|
-
UserWarning,
|
564
|
-
)
|
565
|
-
discrete_metadata[key] = _bin_data(data, auto_bin_method)
|
566
|
-
else:
|
567
|
-
_, discrete_metadata[key] = np.unique(data, return_inverse=True)
|
568
|
-
|
569
|
-
# Split out the dictionaries into the keys and values
|
570
|
-
discrete_factor_names = list(discrete_metadata.keys())
|
571
|
-
discrete_data = np.stack(list(discrete_metadata.values()), axis=-1)
|
572
|
-
continuous_factor_names = list(continuous_metadata.keys())
|
573
|
-
continuous_data = np.stack(list(continuous_metadata.values()), axis=-1) if continuous_metadata else None
|
574
|
-
total_num_factors = len(discrete_factor_names + continuous_factor_names) + 1
|
575
|
-
|
576
|
-
return Metadata(
|
577
|
-
discrete_factor_names,
|
578
|
-
discrete_data,
|
579
|
-
continuous_factor_names,
|
580
|
-
continuous_data,
|
581
|
-
numerical_labels,
|
582
|
-
unique_classes,
|
583
|
-
total_num_factors,
|
584
|
-
)
|
585
|
-
|
586
|
-
|
587
|
-
def _digitize_data(data: list[Any] | NDArray[Any], bins: int | Iterable[float]) -> NDArray[np.intp]:
|
588
|
-
"""
|
589
|
-
Digitizes a list of values into a given number of bins.
|
590
|
-
|
591
|
-
Parameters
|
592
|
-
----------
|
593
|
-
data : list | NDArray
|
594
|
-
The values to be digitized.
|
595
|
-
bins : int | Iterable[float]
|
596
|
-
The number of bins or list of bin edges for the discrete values that data will be digitized into.
|
597
|
-
|
598
|
-
Returns
|
599
|
-
-------
|
600
|
-
NDArray[np.intp]
|
601
|
-
The digitized values
|
602
|
-
"""
|
603
|
-
|
604
|
-
if not np.all([np.issubdtype(type(n), np.number) for n in data]):
|
605
|
-
raise TypeError(
|
606
|
-
"Encountered a data value with non-numeric type when digitizing a factor. "
|
607
|
-
"Ensure all occurrences of continuous factors are numeric types."
|
608
|
-
)
|
609
|
-
if isinstance(bins, int):
|
610
|
-
_, bin_edges = np.histogram(data, bins=bins)
|
611
|
-
bin_edges[-1] = np.inf
|
612
|
-
bin_edges[0] = -np.inf
|
613
|
-
else:
|
614
|
-
bin_edges = list(bins)
|
615
|
-
return np.digitize(data, bin_edges)
|
616
|
-
|
617
|
-
|
618
|
-
def _bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
|
619
|
-
"""
|
620
|
-
Bins continuous data through either equal width bins, equal amounts in each bin, or by clusters.
|
621
|
-
"""
|
622
|
-
if bin_method == "clusters":
|
623
|
-
# bin_edges = _binning_by_clusters(data)
|
624
|
-
warnings.warn(
|
625
|
-
"Binning by clusters is currently unavailable until changes to the clustering function go through.",
|
626
|
-
UserWarning,
|
627
|
-
)
|
628
|
-
bin_method = "uniform_width"
|
629
|
-
|
630
|
-
# if bin_method != "clusters": # restore this when clusters bin_method is available
|
631
|
-
counts, bin_edges = np.histogram(data, bins="auto")
|
632
|
-
n_bins = counts.size
|
633
|
-
if counts[counts > 0].min() < 10:
|
634
|
-
counter = 20
|
635
|
-
while counts[counts > 0].min() < 10 and n_bins >= 2 and counter > 0:
|
636
|
-
counter -= 1
|
637
|
-
n_bins -= 1
|
638
|
-
counts, bin_edges = np.histogram(data, bins=n_bins)
|
639
|
-
|
640
|
-
if bin_method == "uniform_count":
|
641
|
-
quantiles = np.linspace(0, 100, n_bins + 1)
|
642
|
-
bin_edges = np.asarray(np.percentile(data, quantiles))
|
643
|
-
|
644
|
-
bin_edges[0] = -np.inf # type: ignore # until the clusters speed up is merged
|
645
|
-
bin_edges[-1] = np.inf # type: ignore # and the _binning_by_clusters can be uncommented
|
646
|
-
return np.digitize(data, bin_edges) # type: ignore
|
647
|
-
|
648
|
-
|
649
|
-
def _is_continuous(data: NDArray[np.number], image_indices: NDArray[np.number]) -> bool:
|
650
|
-
"""
|
651
|
-
Determines whether the data is continuous or discrete using the Wasserstein distance.
|
652
|
-
|
653
|
-
Given a 1D sample, we consider the intervals between adjacent points. For a continuous distribution,
|
654
|
-
a point is equally likely to lie anywhere in the interval bounded by its two neighbors. Furthermore,
|
655
|
-
we can put all "between neighbor" locations on the same scale of 0 to 1 by subtracting the smaller
|
656
|
-
neighbor and dividing out the length of the interval. (Duplicates are either assigned to zero or
|
657
|
-
ignored, depending on context). These normalized locations will be much more uniformly distributed
|
658
|
-
for continuous data than for discrete, and this gives us a way to distinguish them. Call this the
|
659
|
-
Normalized Near Neighbor distribution (NNN), defined on the interval [0,1].
|
660
|
-
|
661
|
-
The Wasserstein distance is available in scipy.stats.wasserstein_distance. We can use it to measure
|
662
|
-
how close the NNN is to a uniform distribution over [0,1]. We found that as long as a sample has at
|
663
|
-
least 20 points, and furthermore at least half as many points as there are discrete values, we can
|
664
|
-
reliably distinguish discrete from continuous samples by testing that the Wasserstein distance
|
665
|
-
measured from a uniform distribution is greater or less than 0.054, respectively.
|
666
|
-
"""
|
667
|
-
# Check if the metadata is image specific
|
668
|
-
_, data_indices_unsorted = np.unique(data, return_index=True)
|
669
|
-
if data_indices_unsorted.size == image_indices.size:
|
670
|
-
data_indices = np.sort(data_indices_unsorted)
|
671
|
-
if (data_indices == image_indices).all():
|
672
|
-
data = data[data_indices]
|
673
|
-
|
674
|
-
# OLD METHOD
|
675
|
-
# uvals = np.unique(data)
|
676
|
-
# pct_unique = uvals.size / data.size
|
677
|
-
# return pct_unique < threshold
|
678
|
-
|
679
|
-
n_examples = len(data)
|
680
|
-
|
681
|
-
if n_examples < CONTINUOUS_MIN_SAMPLE_SIZE:
|
682
|
-
warnings.warn(
|
683
|
-
f"All samples look discrete with so few data points (< {CONTINUOUS_MIN_SAMPLE_SIZE})", UserWarning
|
684
|
-
)
|
685
|
-
return False
|
686
|
-
|
687
|
-
# Require at least 3 unique values before bothering with NNN
|
688
|
-
xu = np.unique(data, axis=None)
|
689
|
-
if xu.size < 3:
|
690
|
-
return False
|
691
|
-
|
692
|
-
Xs = np.sort(data)
|
693
|
-
|
694
|
-
X0, X1 = Xs[0:-2], Xs[2:] # left and right neighbors
|
695
|
-
|
696
|
-
dx = np.zeros(n_examples - 2) # no dx at end points
|
697
|
-
gtz = (X1 - X0) > 0 # check for dups; dx will be zero for them
|
698
|
-
dx[np.logical_not(gtz)] = 0.0
|
699
|
-
|
700
|
-
dx[gtz] = (Xs[1:-1] - X0)[gtz] / (X1 - X0)[gtz] # the core idea: dx is NNN samples.
|
701
|
-
|
702
|
-
shift = wd(dx, np.linspace(0, 1, dx.size)) # how far is dx from uniform, for this feature?
|
703
|
-
|
704
|
-
return shift < DISCRETE_MIN_WD # if NNN is close enough to uniform, consider the sample continuous.
|
705
|
-
|
706
|
-
|
707
|
-
def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
|
708
|
-
"""
|
709
|
-
Returns columnwise unique counts for discrete data.
|
710
|
-
|
711
|
-
Parameters
|
712
|
-
----------
|
713
|
-
data : NDArray
|
714
|
-
Array containing integer values for metadata factors
|
715
|
-
min_num_bins : int | None, default None
|
716
|
-
Minimum number of bins for bincount, helps force consistency across runs
|
717
|
-
|
718
|
-
Returns
|
719
|
-
-------
|
720
|
-
NDArray[np.int]
|
721
|
-
Bin counts per column of data.
|
722
|
-
"""
|
723
|
-
max_value = data.max() + 1 if min_num_bins is None else min_num_bins
|
724
|
-
cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
|
725
|
-
for idx in range(data.shape[1]):
|
726
|
-
cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
|
727
|
-
|
728
|
-
return cnt_array
|
@@ -10,7 +10,8 @@ from __future__ import annotations
|
|
10
10
|
|
11
11
|
__all__ = []
|
12
12
|
|
13
|
-
from
|
13
|
+
from dataclasses import dataclass
|
14
|
+
from typing import TypeVar
|
14
15
|
|
15
16
|
import numpy as np
|
16
17
|
import torch
|
@@ -18,7 +19,8 @@ import torch
|
|
18
19
|
TGMMData = TypeVar("TGMMData")
|
19
20
|
|
20
21
|
|
21
|
-
|
22
|
+
@dataclass
|
23
|
+
class GaussianMixtureModelParams:
|
22
24
|
"""
|
23
25
|
phi : torch.Tensor
|
24
26
|
Mixture component distribution weights.
|