dataeval 0.63.0__tar.gz → 0.64.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {dataeval-0.63.0 → dataeval-0.64.0}/PKG-INFO +1 -1
  2. {dataeval-0.63.0 → dataeval-0.64.0}/pyproject.toml +1 -1
  3. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/__init__.py +3 -3
  4. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/clusterer.py +2 -1
  5. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/drift/base.py +2 -1
  6. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/drift/cvm.py +2 -1
  7. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/drift/ks.py +2 -1
  8. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/drift/mmd.py +4 -3
  9. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/drift/uncertainty.py +1 -2
  10. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/duplicates.py +2 -1
  11. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/linter.py +1 -1
  12. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/ood/ae.py +2 -1
  13. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/ood/aegmm.py +2 -1
  14. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/ood/base.py +2 -1
  15. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/ood/llr.py +3 -2
  16. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/ood/vae.py +2 -1
  17. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/ood/vaegmm.py +2 -1
  18. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/interop.py +2 -11
  19. dataeval-0.64.0/src/dataeval/_internal/metrics/balance.py +180 -0
  20. dataeval-0.64.0/src/dataeval/_internal/metrics/base.py +10 -0
  21. dataeval-0.64.0/src/dataeval/_internal/metrics/ber.py +148 -0
  22. {dataeval-0.63.0/src/dataeval/_internal/functional → dataeval-0.64.0/src/dataeval/_internal/metrics}/coverage.py +44 -14
  23. dataeval-0.64.0/src/dataeval/_internal/metrics/divergence.py +102 -0
  24. dataeval-0.64.0/src/dataeval/_internal/metrics/diversity.py +206 -0
  25. dataeval-0.64.0/src/dataeval/_internal/metrics/parity.py +309 -0
  26. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/metrics/stats.py +7 -5
  27. dataeval-0.64.0/src/dataeval/_internal/metrics/uap.py +50 -0
  28. dataeval-0.64.0/src/dataeval/_internal/metrics/utils.py +393 -0
  29. dataeval-0.64.0/src/dataeval/_internal/utils.py +64 -0
  30. dataeval-0.64.0/src/dataeval/metrics/__init__.py +27 -0
  31. dataeval-0.64.0/src/dataeval/utils/__init__.py +9 -0
  32. dataeval-0.63.0/src/dataeval/_internal/functional/ber.py +0 -63
  33. dataeval-0.63.0/src/dataeval/_internal/functional/divergence.py +0 -16
  34. dataeval-0.63.0/src/dataeval/_internal/functional/hash.py +0 -79
  35. dataeval-0.63.0/src/dataeval/_internal/functional/metadata.py +0 -136
  36. dataeval-0.63.0/src/dataeval/_internal/functional/metadataparity.py +0 -190
  37. dataeval-0.63.0/src/dataeval/_internal/functional/uap.py +0 -6
  38. dataeval-0.63.0/src/dataeval/_internal/functional/utils.py +0 -158
  39. dataeval-0.63.0/src/dataeval/_internal/maite/utils.py +0 -30
  40. dataeval-0.63.0/src/dataeval/_internal/metrics/base.py +0 -92
  41. dataeval-0.63.0/src/dataeval/_internal/metrics/ber.py +0 -74
  42. dataeval-0.63.0/src/dataeval/_internal/metrics/coverage.py +0 -96
  43. dataeval-0.63.0/src/dataeval/_internal/metrics/divergence.py +0 -102
  44. dataeval-0.63.0/src/dataeval/_internal/metrics/metadata.py +0 -610
  45. dataeval-0.63.0/src/dataeval/_internal/metrics/metadataparity.py +0 -67
  46. dataeval-0.63.0/src/dataeval/_internal/metrics/parity.py +0 -164
  47. dataeval-0.63.0/src/dataeval/_internal/metrics/uap.py +0 -42
  48. dataeval-0.63.0/src/dataeval/_internal/models/tensorflow/__init__.py +0 -0
  49. dataeval-0.63.0/src/dataeval/_internal/workflows/__init__.py +0 -0
  50. dataeval-0.63.0/src/dataeval/metrics/__init__.py +0 -8
  51. {dataeval-0.63.0 → dataeval-0.64.0}/LICENSE.txt +0 -0
  52. {dataeval-0.63.0 → dataeval-0.64.0}/README.md +0 -0
  53. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/__init__.py +0 -0
  54. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/drift/__init__.py +0 -0
  55. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/drift/torch.py +0 -0
  56. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/detectors/ood/__init__.py +0 -0
  57. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/flags.py +0 -0
  58. {dataeval-0.63.0/src/dataeval/_internal/functional → dataeval-0.64.0/src/dataeval/_internal/metrics}/__init__.py +0 -0
  59. {dataeval-0.63.0/src/dataeval/_internal/maite → dataeval-0.64.0/src/dataeval/_internal/models}/__init__.py +0 -0
  60. {dataeval-0.63.0/src/dataeval/_internal/metrics → dataeval-0.64.0/src/dataeval/_internal/models/pytorch}/__init__.py +0 -0
  61. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/models/pytorch/autoencoder.py +0 -0
  62. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/models/pytorch/blocks.py +0 -0
  63. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/models/pytorch/utils.py +0 -0
  64. {dataeval-0.63.0/src/dataeval/_internal/models → dataeval-0.64.0/src/dataeval/_internal/models/tensorflow}/__init__.py +0 -0
  65. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/models/tensorflow/autoencoder.py +0 -0
  66. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/models/tensorflow/gmm.py +0 -0
  67. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/models/tensorflow/losses.py +0 -0
  68. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/models/tensorflow/pixelcnn.py +0 -0
  69. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/models/tensorflow/trainer.py +0 -0
  70. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/models/tensorflow/utils.py +0 -0
  71. {dataeval-0.63.0/src/dataeval/_internal/models/pytorch → dataeval-0.64.0/src/dataeval/_internal/workflows}/__init__.py +0 -0
  72. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/_internal/workflows/sufficiency.py +0 -0
  73. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/detectors/__init__.py +0 -0
  74. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/flags/__init__.py +0 -0
  75. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/models/__init__.py +0 -0
  76. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/models/tensorflow/__init__.py +0 -0
  77. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/models/torch/__init__.py +0 -0
  78. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/py.typed +0 -0
  79. {dataeval-0.63.0 → dataeval-0.64.0}/src/dataeval/workflows/__init__.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.63.0
+Version: 0.64.0
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT

pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "dataeval"
-version = "0.63.0" # dynamic
+version = "0.64.0" # dynamic
 description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
 license = "MIT"
 readme = "README.md"

src/dataeval/__init__.py
@@ -2,14 +2,14 @@ from importlib.util import find_spec

 from . import detectors, flags, metrics

-__version__ = "0.63.0"
+__version__ = "0.64.0"

 __all__ = ["detectors", "flags", "metrics"]

 if find_spec("torch") is not None:  # pragma: no cover
-    from . import models, workflows
+    from . import models, utils, workflows

-    __all__ += ["models", "workflows"]
+    __all__ += ["models", "utils", "workflows"]

 elif find_spec("tensorflow") is not None:  # pragma: no cover
     from . import models


src/dataeval/_internal/detectors/clusterer.py
@@ -1,10 +1,11 @@
 from typing import Dict, Iterable, List, NamedTuple, Tuple, Union, cast

 import numpy as np
+from numpy.typing import ArrayLike
 from scipy.cluster.hierarchy import linkage
 from scipy.spatial.distance import pdist, squareform

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy


 def extend_linkage(link_arr: np.ndarray) -> np.ndarray:

src/dataeval/_internal/detectors/drift/base.py
@@ -11,8 +11,9 @@ from functools import wraps
 from typing import Callable, Dict, Literal, Optional, Tuple, Union

 import numpy as np
+from numpy.typing import ArrayLike

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy


 def update_x_ref(fn):

src/dataeval/_internal/detectors/drift/cvm.py
@@ -9,9 +9,10 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple

 import numpy as np
+from numpy.typing import ArrayLike
 from scipy.stats import cramervonmises_2samp

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy

 from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x


src/dataeval/_internal/detectors/drift/ks.py
@@ -9,9 +9,10 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple

 import numpy as np
+from numpy.typing import ArrayLike
 from scipy.stats import ks_2samp

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy

 from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x


src/dataeval/_internal/detectors/drift/mmd.py
@@ -9,8 +9,9 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Dict, Optional, Tuple, Union

 import torch
+from numpy.typing import ArrayLike

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy

 from .base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
 from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
@@ -74,7 +75,7 @@ class DriftMMD(BaseDrift):
         super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)

         self.infer_sigma = configure_kernel_from_x_ref
-        if configure_kernel_from_x_ref and isinstance(sigma, ArrayLike):
+        if configure_kernel_from_x_ref and sigma is not None:
             self.infer_sigma = False

         self.n_permutations = n_permutations  # nb of iterations through permutation test
@@ -83,7 +84,7 @@ class DriftMMD(BaseDrift):
         self.device = get_device(device)

         # initialize kernel
-        sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if isinstance(sigma, ArrayLike) else None
+        sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if sigma is not None else None
         self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel

         # compute kernel matrix for the reference data
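
The two mmd.py hunks above replace runtime `isinstance(sigma, ArrayLike)` checks with `sigma is not None`: `numpy.typing.ArrayLike` is a static typing construct, not a runtime-checkable protocol like the removed maite fallback. A minimal caller-side sketch (keyword names taken from the visible constructor body; the import path is an assumption about the public re-exports):

    import numpy as np
    from dataeval.detectors import DriftMMD  # assumed public re-export

    x_ref = np.random.randn(100, 8).astype(np.float32)

    # sigma=None (default): with configure_kernel_from_x_ref=True the kernel
    # bandwidth is inferred from x_ref, since sigma_tensor stays None
    detector = DriftMMD(x_ref)

    # explicit sigma: converted via to_numpy/torch.from_numpy, and bandwidth
    # inference is disabled (self.infer_sigma = False)
    detector = DriftMMD(x_ref, sigma=np.array([0.5]))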

src/dataeval/_internal/detectors/drift/uncertainty.py
@@ -10,11 +10,10 @@ from functools import partial
 from typing import Callable, Dict, Literal, Optional, Union

 import numpy as np
+from numpy.typing import ArrayLike
 from scipy.special import softmax
 from scipy.stats import entropy

-from dataeval._internal.interop import ArrayLike
-
 from .base import UpdateStrategy
 from .ks import DriftKS
 from .torch import get_device, preprocess_drift

src/dataeval/_internal/detectors/duplicates.py
@@ -1,7 +1,8 @@
 from typing import Dict, Iterable, List, Literal

+from numpy.typing import ArrayLike
+
 from dataeval._internal.flags import ImageHash
-from dataeval._internal.interop import ArrayLike
 from dataeval._internal.metrics.stats import ImageStats



src/dataeval/_internal/detectors/linter.py
@@ -1,9 +1,9 @@
 from typing import Iterable, Literal, Optional, Sequence, Union

 import numpy as np
+from numpy.typing import ArrayLike

 from dataeval._internal.flags import ImageProperty, ImageVisuals, LinterFlags
-from dataeval._internal.interop import ArrayLike
 from dataeval._internal.metrics.stats import ImageStats



src/dataeval/_internal/detectors/ood/ae.py
@@ -10,9 +10,10 @@ from typing import Callable

 import keras
 import numpy as np
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import AE
 from dataeval._internal.models.tensorflow.utils import predict_batch


src/dataeval/_internal/detectors/ood/aegmm.py
@@ -9,9 +9,10 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable

 import keras
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import AEGMM
 from dataeval._internal.models.tensorflow.gmm import gmm_energy
 from dataeval._internal.models.tensorflow.losses import LossGMM

src/dataeval/_internal/detectors/ood/base.py
@@ -12,8 +12,9 @@ from typing import Callable, Dict, List, Literal, NamedTuple, Optional, Tuple, c
 import keras
 import numpy as np
 import tensorflow as tf
+from numpy.typing import ArrayLike

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.gmm import GaussianMixtureModelParams, gmm_params
 from dataeval._internal.models.tensorflow.trainer import trainer


src/dataeval/_internal/detectors/ood/llr.py
@@ -14,9 +14,10 @@ import numpy as np
 import tensorflow as tf
 from keras.layers import Input
 from keras.models import Model
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
 from dataeval._internal.models.tensorflow.trainer import trainer
 from dataeval._internal.models.tensorflow.utils import predict_batch
@@ -180,7 +181,7 @@ class OOD_LLR(OODBase):

         # create background data
         mutate_fn = partial(mutate_fn, **mutate_fn_kwargs)
-        X_back = predict_batch(x_ref, mutate_fn, batch_size=mutate_batch_size, dtype=x_ref.dtype)
+        X_back = predict_batch(x_ref, mutate_fn, batch_size=mutate_batch_size, dtype=x_ref.dtype)  # type: ignore

         # prepare sequential data
         if self.sequential and not self.has_log_prob:

src/dataeval/_internal/detectors/ood/vae.py
@@ -10,9 +10,10 @@ from typing import Callable

 import keras
 import numpy as np
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import VAE
 from dataeval._internal.models.tensorflow.losses import Elbo
 from dataeval._internal.models.tensorflow.utils import predict_batch

src/dataeval/_internal/detectors/ood/vaegmm.py
@@ -10,9 +10,10 @@ from typing import Callable

 import keras
 import numpy as np
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import VAEGMM
 from dataeval._internal.models.tensorflow.gmm import gmm_energy
 from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM

src/dataeval/_internal/interop.py
@@ -1,7 +1,8 @@
 from importlib import import_module
-from typing import Any, Iterable, Optional, runtime_checkable
+from typing import Iterable, Optional

 import numpy as np
+from numpy.typing import ArrayLike

 module_cache = {}

@@ -19,16 +20,6 @@ def try_import(module_name):
     return module


-try:
-    from maite.protocols import ArrayLike  # type: ignore
-except ImportError:  # pragma: no cover - covered by test_mindeps.py
-    from typing import Protocol
-
-    @runtime_checkable
-    class ArrayLike(Protocol):
-        def __array__(self) -> Any: ...
-
-
 def to_numpy(array: Optional[ArrayLike]) -> np.ndarray:
     if array is None:
         return np.ndarray([])
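
With the maite protocol fallback removed, interop.py sources ArrayLike from numpy.typing. A standalone sketch of the practical difference: numpy's ArrayLike is a type-checker construct that already admits sequences, scalars, and ndarrays, so callers no longer need an object implementing `__array__` (the names below are illustrative, not dataeval code):

    import numpy as np
    from numpy.typing import ArrayLike

    def as_array(x: ArrayLike) -> np.ndarray:
        # np.asarray accepts anything numpy can coerce
        return np.asarray(x)

    as_array([1, 2, 3])         # a plain list is a valid ArrayLike
    as_array(np.zeros((2, 2)))  # so is an ndarray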

src/dataeval/_internal/metrics/balance.py (new)
@@ -0,0 +1,180 @@
+import warnings
+from typing import Dict, List, NamedTuple, Sequence
+
+import numpy as np
+from numpy.typing import NDArray
+from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
+
+from dataeval._internal.metrics.utils import entropy, preprocess_metadata
+
+
+class BalanceOutput(NamedTuple):
+    """
+    Attributes
+    ----------
+    mutual_information : NDArray[np.float64]
+        Estimate of mutual information between metadata factors and class label
+    """
+
+    mutual_information: NDArray[np.float64]
+
+
+def validate_num_neighbors(num_neighbors: int) -> int:
+    if not isinstance(num_neighbors, (int, float)):
+        raise TypeError(
+            f"Variable {num_neighbors} is not real-valued numeric type."
+            "num_neighbors should be an int, greater than 0 and less than"
+            "the number of samples in the dataset"
+        )
+    if num_neighbors < 1:
+        raise ValueError(
+            f"Invalid value for {num_neighbors}."
+            "Choose a value greater than 0 and less than number of samples"
+            "in the dataset."
+        )
+    if isinstance(num_neighbors, float):
+        num_neighbors = int(num_neighbors)
+        warnings.warn(f"Variable {num_neighbors} is currently type float and will be truncated to type int.")
+
+    return num_neighbors
+
+
+def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
+    """
+    Mutual information (MI) between factors (class label, metadata, label/image properties)
+
+    Parameters
+    ----------
+    class_labels: Sequence[int]
+        List of class labels for each image
+    metadata: List[Dict]
+        List of metadata factors for each image
+    num_neighbors: int, default 5
+        Number of nearest neighbors to use for computing MI between discrete
+        and continuous variables.
+
+    Returns
+    -------
+    BalanceOutput
+        (num_factors+1) x (num_factors+1) estimate of mutual information
+        between num_factors metadata factors and class label. Symmetry is enforced.
+
+    Notes
+    -----
+    We use `mutual_info_classif` from sklearn since class label is categorical.
+    `mutual_info_classif` outputs are consistent up to O(1e-4) and depend on a random
+    seed. MI is computed differently for categorical and continuous variables, and
+    we attempt to infer whether a variable is categorical by the fraction of unique
+    values in the dataset.
+
+    See Also
+    --------
+    sklearn.feature_selection.mutual_info_classif
+    sklearn.feature_selection.mutual_info_regression
+    sklearn.metrics.mutual_info_score
+    """
+    num_neighbors = validate_num_neighbors(num_neighbors)
+    data, names, is_categorical = preprocess_metadata(class_labels, metadata)
+    num_factors = len(names)
+    mi = np.empty((num_factors, num_factors))
+    mi[:] = np.nan
+
+    for idx in range(num_factors):
+        tgt = data[:, idx]
+
+        if is_categorical[idx]:
+            # categorical target
+            mi[idx, :] = mutual_info_classif(
+                data,
+                tgt,
+                discrete_features=is_categorical,  # type: ignore
+                n_neighbors=num_neighbors,
+            )
+        else:
+            # continuous variables
+            mi[idx, :] = mutual_info_regression(
+                data,
+                tgt,
+                discrete_features=is_categorical,  # type: ignore
+                n_neighbors=num_neighbors,
+            )
+
+    ent_all = entropy(data, names, is_categorical, normalized=False)
+    norm_factor = 0.5 * np.add.outer(ent_all, ent_all) + 1e-6
+    # in principle MI should be symmetric, but it is not in practice.
+    nmi = 0.5 * (mi + mi.T) / norm_factor
+
+    return BalanceOutput(nmi)
+
+
+def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
+    """
+    Compute mutual information (analogous to correlation) between metadata factors
+    (class label, metadata, label/image properties) with individual class labels.
+
+    Parameters
+    ----------
+    class_labels: Sequence[int]
+        List of class labels for each image
+    metadata: List[Dict]
+        List of metadata factors for each image
+    num_neighbors: int, default 5
+        Number of nearest neighbors to use for computing MI between discrete
+        and continuous variables.
+
+    Notes
+    -----
+    We use `mutual_info_classif` from sklearn since class label is categorical.
+    `mutual_info_classif` outputs are consistent up to O(1e-4) and depend on a random
+    seed. MI is computed differently for categorical and continuous variables, so we
+    have to specify with is_categorical.
+
+    Returns
+    -------
+    BalanceOutput
+        (num_classes x num_factors) estimate of mutual information between
+        num_factors metadata factors and individual class labels.
+
+    See Also
+    --------
+    sklearn.feature_selection.mutual_info_classif
+    sklearn.feature_selection.mutual_info_regression
+    sklearn.metrics.mutual_info_score
+    compute_mutual_information
+    """
+    num_neighbors = validate_num_neighbors(num_neighbors)
+    data, names, is_categorical = preprocess_metadata(class_labels, metadata)
+    num_factors = len(names)
+    # unique class labels
+    class_idx = names.index("class_label")
+    class_data = data[:, class_idx]
+    u_cls = np.unique(class_data)
+    num_classes = len(u_cls)
+
+    data_no_class = np.concatenate((data[:, :class_idx], data[:, (class_idx + 1) :]), axis=1)
+
+    # assume class is a factor
+    mi = np.empty((num_classes, num_factors - 1))
+    mi[:] = np.nan
+
+    # categorical variables, excluding class label
+    cat_mask = np.concatenate((is_categorical[:class_idx], is_categorical[(class_idx + 1) :]), axis=0).astype(int)
+
+    # classification MI for discrete/categorical features
+    for idx, cls in enumerate(u_cls):
+        tgt = class_data == cls
+        # units: nat
+        mi[idx, :] = mutual_info_classif(
+            data_no_class,
+            tgt,
+            discrete_features=cat_mask,  # type: ignore
+            n_neighbors=num_neighbors,
+        )
+
+    # let this recompute for all features including class label
+    ent_all = entropy(data, names, is_categorical)
+    ent_tgt = ent_all[class_idx]
+    ent_all = np.concatenate((ent_all[:class_idx], ent_all[(class_idx + 1) :]), axis=0)
+    norm_factor = 0.5 * np.add.outer(ent_tgt, ent_all) + 1e-6
+    nmi = mi / norm_factor
+    return BalanceOutput(nmi)
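
A usage sketch for the new balance metrics (the metadata keys and values are hypothetical, exact MI values vary with sklearn's randomized estimators, and the import path assumes the new metrics/__init__.py re-exports these functions):

    from dataeval.metrics import balance, balance_classwise

    class_labels = [0, 0, 1, 1, 0, 1]
    metadata = [
        {"altitude": 100, "time": 1.1},
        {"altitude": 150, "time": 2.3},
        {"altitude": 120, "time": 3.2},
        {"altitude": 400, "time": 4.1},
        {"altitude": 110, "time": 5.0},
        {"altitude": 380, "time": 6.2},
    ]

    # (num_factors + 1) x (num_factors + 1) normalized MI, class label included;
    # symmetry is enforced via 0.5 * (mi + mi.T)
    nmi = balance(class_labels, metadata).mutual_information

    # num_classes x num_factors MI against each class label individually
    nmi_cw = balance_classwise(class_labels, metadata).mutual_information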

src/dataeval/_internal/metrics/base.py (new)
@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
+
+TOutput = TypeVar("TOutput", bound=dict)
+
+
+class EvaluateMixin(ABC, Generic[TOutput]):
+    @abstractmethod
+    def evaluate(self, *args, **kwargs) -> TOutput:
+        """Abstract method to calculate metric based off of constructor parameters"""

src/dataeval/_internal/metrics/ber.py (new)
@@ -0,0 +1,148 @@
+"""
+This module contains the implementation of the
+FR Test Statistic based estimate and the
+KNN based estimate for the Bayes Error Rate
+
+Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4)
+https://arxiv.org/abs/1811.06419
+"""
+
+from typing import Literal, NamedTuple, Tuple
+
+import numpy as np
+from numpy.typing import ArrayLike, NDArray
+from scipy.sparse import coo_matrix
+from scipy.stats import mode
+
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.metrics.utils import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
+
+
+class BEROutput(NamedTuple):
+    """
+    Attributes
+    ----------
+    ber : float
+        The upper bounds of the Bayes Error Rate
+    ber_lower : float
+        The lower bounds of the Bayes Error Rate
+    """
+
+    ber: float
+    ber_lower: float
+
+
+def ber_mst(X: NDArray, y: NDArray) -> Tuple[float, float]:
+    """Calculates the Bayes Error Rate using a minimum spanning tree
+
+    Parameters
+    ----------
+    X : NDArray, shape - (N, ... )
+        n_samples containing n_features
+    y : NDArray, shape - (N, 1)
+        Labels corresponding to each sample
+
+    Returns
+    -------
+    Tuple[float, float]
+        The upper and lower bounds of the bayes error rate
+    """
+    M, N = get_classes_counts(y)
+
+    tree = coo_matrix(minimum_spanning_tree(X))
+    matches = np.sum([y[tree.row[i]] != y[tree.col[i]] for i in range(N - 1)])
+    deltas = matches / (2 * N)
+    upper = 2 * deltas
+    lower = ((M - 1) / (M)) * (1 - max(1 - 2 * ((M) / (M - 1)) * deltas, 0) ** 0.5)
+    return upper, lower
+
+
+def ber_knn(X: NDArray, y: NDArray, k: int) -> Tuple[float, float]:
+    """Calculates the Bayes Error Rate using K-nearest neighbors
+
+    Parameters
+    ----------
+    X : NDArray, shape - (N, ... )
+        n_samples containing n_features
+    y : NDArray, shape - (N, 1)
+        Labels corresponding to each sample
+
+    Returns
+    -------
+    Tuple[float, float]
+        The upper and lower bounds of the bayes error rate
+    """
+    M, N = get_classes_counts(y)
+
+    # All features belong on second dimension
+    X = X.reshape((X.shape[0], -1))
+    nn_indices = compute_neighbors(X, X, k=k)
+    nn_indices = np.expand_dims(nn_indices, axis=1) if nn_indices.ndim == 1 else nn_indices
+    modal_class = mode(y[nn_indices], axis=1, keepdims=True).mode.squeeze()
+    upper = float(np.count_nonzero(modal_class - y) / N)
+    lower = knn_lowerbound(upper, M, k)
+    return upper, lower
+
+
+def knn_lowerbound(value: float, classes: int, k: int) -> float:
+    """Several cases for computing the BER lower bound"""
+    if value <= 1e-10:
+        return 0.0
+
+    if classes == 2 and k != 1:
+        if k > 5:
+            # Property 2 (Devroye, 1981) cited in Snoopy paper, not in snoopy repo
+            alpha = 0.3399
+            beta = 0.9749
+            a_k = alpha * np.sqrt(k) / (k - 3.25) * (1 + beta / (np.sqrt(k - 3)))
+            return value / (1 + a_k)
+        if k > 2:
+            return value / (1 + (1 / np.sqrt(k)))
+        # k == 2:
+        return value / 2
+
+    return ((classes - 1) / classes) * (1 - np.sqrt(max(0, 1 - ((classes / (classes - 1)) * value))))
+
+
+BER_FN_MAP = {"KNN": ber_knn, "MST": ber_mst}
+
+
+def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
+    """
+    An estimator for Multi-class Bayes Error Rate using FR or KNN test statistic basis
+
+    Parameters
+    ----------
+    images : ArrayLike (N, ... )
+        Array of images or image embeddings
+    labels : ArrayLike (N, 1)
+        Array of labels for each image or image embedding
+    k : int, default 1
+        Number of nearest neighbors for KNN estimator -- ignored by MST estimator
+    method : Literal["KNN", "MST"], default "KNN"
+        Method to use when estimating the Bayes error rate
+
+    Returns
+    -------
+    BEROutput
+        The upper and lower bounds of the Bayes Error Rate
+
+    References
+    ----------
+    [1] `Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4) <https://arxiv.org/abs/1811.06419>`_
+
+    Examples
+    --------
+    >>> import sklearn.datasets as dsets
+    >>> from dataeval.metrics import ber
+
+    >>> images, labels = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)
+
+    >>> ber(images, labels)
+    BEROutput(ber=0.04, ber_lower=0.020416847668728033)
+    """
+    ber_fn = get_method(BER_FN_MAP, method)
+    X = to_numpy(images)
+    y = to_numpy(labels)
+    upper, lower = ber_fn(X, y, k) if method == "KNN" else ber_fn(X, y)
+    return BEROutput(upper, lower)
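
The docstring example above exercises the default KNN estimator; per the signature, the MST estimator is selected with method="MST" and ignores k (a sketch reusing the same blobs data, with output omitted since the MST bounds differ from the KNN ones):

    result = ber(images, labels, method="MST")
    result.ber, result.ber_lower  # upper/lower bounds from the MST test statistic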