dataeval 0.61.0__py3-none-any.whl → 0.64.0__py3-none-any.whl
This diff shows the changes between publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
- dataeval/__init__.py +3 -3
- dataeval/_internal/detectors/clusterer.py +45 -16
- dataeval/_internal/detectors/drift/base.py +15 -12
- dataeval/_internal/detectors/drift/cvm.py +12 -8
- dataeval/_internal/detectors/drift/ks.py +7 -3
- dataeval/_internal/detectors/drift/mmd.py +15 -12
- dataeval/_internal/detectors/drift/uncertainty.py +6 -5
- dataeval/_internal/detectors/duplicates.py +35 -11
- dataeval/_internal/detectors/linter.py +85 -16
- dataeval/_internal/detectors/ood/ae.py +7 -5
- dataeval/_internal/detectors/ood/aegmm.py +6 -5
- dataeval/_internal/detectors/ood/base.py +15 -13
- dataeval/_internal/detectors/ood/llr.py +8 -5
- dataeval/_internal/detectors/ood/vae.py +6 -4
- dataeval/_internal/detectors/ood/vaegmm.py +6 -4
- dataeval/_internal/interop.py +43 -0
- dataeval/_internal/metrics/balance.py +180 -0
- dataeval/_internal/metrics/base.py +2 -84
- dataeval/_internal/metrics/ber.py +77 -53
- dataeval/_internal/metrics/coverage.py +80 -55
- dataeval/_internal/metrics/divergence.py +62 -54
- dataeval/_internal/metrics/diversity.py +206 -0
- dataeval/_internal/metrics/parity.py +292 -163
- dataeval/_internal/metrics/stats.py +48 -35
- dataeval/_internal/metrics/uap.py +31 -26
- dataeval/_internal/metrics/utils.py +237 -2
- dataeval/_internal/utils.py +64 -0
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/metrics/__init__.py +25 -5
- dataeval/utils/__init__.py +9 -0
- {dataeval-0.61.0.dist-info → dataeval-0.64.0.dist-info}/METADATA +1 -2
- dataeval-0.64.0.dist-info/RECORD +60 -0
- dataeval/_internal/metrics/hash.py +0 -79
- dataeval-0.61.0.dist-info/RECORD +0 -55
- {dataeval-0.61.0.dist-info → dataeval-0.64.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.61.0.dist-info → dataeval-0.64.0.dist-info}/WHEEL +0 -0
dataeval/_internal/metrics/base.py
@@ -1,92 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import
+from typing import Generic, TypeVar
 
 TOutput = TypeVar("TOutput", bound=dict)
-TMethods = TypeVar("TMethods")
-TCallable = TypeVar("TCallable", bound=Callable)
-
-
-class MetricMixin(ABC, Generic[TOutput]):
-    @abstractmethod
-    def update(self, preds, targets): ...
-
-    @abstractmethod
-    def compute(self) -> TOutput: ...
-
-    @abstractmethod
-    def reset(self): ...
 
 
 class EvaluateMixin(ABC, Generic[TOutput]):
     @abstractmethod
-    def evaluate(self) -> TOutput:
+    def evaluate(self, *args, **kwargs) -> TOutput:
         """Abstract method to calculate metric based off of constructor parameters"""
-
-
-class MethodsMixin(ABC, Generic[TMethods, TCallable]):
-    """
-    Use this mixin to define a mapping of functions to method names which
-    can be queried by the user and called internally with the appropriate
-    method name as the key.
-
-    Explicitly defining the Callable generic helps with type safety and
-    hinting for function signatures and recommended but optional.
-
-    e.g.:
-
-    def _mult(x: float, y: float) -> float:
-        return x * y
-
-    class MyMetric(MethodsMixin[Callable[float, float], float]):
-
-        def _methods(cls) -> Dict[str, Callable[float, float], float]:
-            return {
-                "ADD": lambda x, y: x + y,
-                "MULT": _mult,
-                ...
-            }
-
-    Then during evaluate, you can call the method specified with the getter.
-
-    e.g.:
-
-    def evaluate(self):
-        return self._method(x, y)
-
-    The resulting class can be used like so.
-
-    m = MyMetric(1.0, 2.0, "ADD")
-    m.evaluate() # returns 3.0
-    m.method # returns "ADD"
-    MyMetric.methods() # returns "['ADD', 'MULT']
-    m.method = "MULT"
-    m.evaluate() # returns 2.0
-    """
-
-    @classmethod
-    @abstractmethod
-    def _methods(cls) -> Dict[str, TCallable]:
-        """Abstract method returning available method functions for class"""
-
-    @property
-    def _method(self) -> TCallable:
-        return self._methods()[self.method]
-
-    @classmethod
-    def methods(cls) -> List[str]:
-        return list(cls._methods().keys())
-
-    @property
-    def method(self) -> str:
-        return self._method_key
-
-    @method.setter
-    def method(self, value: TMethods):
-        self._set_method(value)
-
-    def _set_method(self, value: TMethods):
-        """This setter is to fix pyright incorrect detection of
-        incorrectly overriding the 'method' property"""
-        if value not in self.methods():
-            raise KeyError(f"Specified method not available for class ({self.methods()}).")
-        self._method_key = value
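The removed MethodsMixin machinery is replaced in 0.64.0 by module-level function maps that are resolved by name through get_method (see the ber.py and divergence.py diffs below). A minimal sketch of that pattern follows, reusing the ADD/MULT example from the old docstring; EXAMPLE_FN_MAP and example_metric are illustrative names, not part of the dataeval API.

    # Illustrative sketch only: the function-map pattern that replaces MethodsMixin.
    from typing import Callable, Dict

    def _add(x: float, y: float) -> float:
        return x + y

    def _mult(x: float, y: float) -> float:
        return x * y

    EXAMPLE_FN_MAP: Dict[str, Callable[[float, float], float]] = {"ADD": _add, "MULT": _mult}

    def example_metric(x: float, y: float, method: str = "ADD") -> float:
        # dataeval's get_method(FN_MAP, method) is assumed to behave like this lookup:
        # validate the key against the map and return the mapped callable.
        if method not in EXAMPLE_FN_MAP:
            raise KeyError(f"Specified method not available ({list(EXAMPLE_FN_MAP)}).")
        return EXAMPLE_FN_MAP[method](x, y)

    print(example_metric(1.0, 2.0, "ADD"))   # 3.0
    print(example_metric(1.0, 2.0, "MULT"))  # 2.0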
dataeval/_internal/metrics/ber.py
@@ -7,19 +7,46 @@ Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4)
 https://arxiv.org/abs/1811.06419
 """
 
-from typing import
+from typing import Literal, NamedTuple, Tuple
 
 import numpy as np
-from
+from numpy.typing import ArrayLike, NDArray
 from scipy.sparse import coo_matrix
 from scipy.stats import mode
 
-from dataeval._internal.
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.metrics.utils import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
 
-
+
+class BEROutput(NamedTuple):
+    """
+    Attributes
+    ----------
+    ber : float
+        The upper bounds of the Bayes Error Rate
+    ber_lower : float
+        The lower bounds of the Bayes Error Rate
+    """
+
+    ber: float
+    ber_lower: float
 
 
-def _mst(X: np.ndarray, y: np.ndarray, _: int) -> Tuple[float, float]:
+def ber_mst(X: NDArray, y: NDArray) -> Tuple[float, float]:
+    """Calculates the Bayes Error Rate using a minimum spanning tree
+
+    Parameters
+    ----------
+    X : NDArray, shape - (N, ... )
+        n_samples containing n_features
+    y : NDArray, shape - (N, 1)
+        Labels corresponding to each sample
+
+    Returns
+    -------
+    Tuple[float, float]
+        The upper and lower bounds of the bayes error rate
+    """
     M, N = get_classes_counts(y)
 
     tree = coo_matrix(minimum_spanning_tree(X))
@@ -30,7 +57,21 @@ def _mst(X: np.ndarray, y: np.ndarray, _: int) -> Tuple[float, float]:
     return upper, lower
 
 
-def _knn(X: np.ndarray, y: np.ndarray, k: int) -> Tuple[float, float]:
+def ber_knn(X: NDArray, y: NDArray, k: int) -> Tuple[float, float]:
+    """Calculates the Bayes Error Rate using K-nearest neighbors
+
+    Parameters
+    ----------
+    X : NDArray, shape - (N, ... )
+        n_samples containing n_features
+    y : NDArray, shape - (N, 1)
+        Labels corresponding to each sample
+
+    Returns
+    -------
+    Tuple[float, float]
+        The upper and lower bounds of the bayes error rate
+    """
     M, N = get_classes_counts(y)
 
     # All features belong on second dimension
@@ -39,12 +80,12 @@ def _knn(X: np.ndarray, y: np.ndarray, k: int) -> Tuple[float, float]:
     nn_indices = np.expand_dims(nn_indices, axis=1) if nn_indices.ndim == 1 else nn_indices
     modal_class = mode(y[nn_indices], axis=1, keepdims=True).mode.squeeze()
     upper = float(np.count_nonzero(modal_class - y) / N)
-    lower = _knn_lowerbound(upper, M, k)
+    lower = knn_lowerbound(upper, M, k)
     return upper, lower
 
 
-def _knn_lowerbound(value: float, classes: int, k: int) -> float:
-    "Several cases for computing the BER lower bound"
+def knn_lowerbound(value: float, classes: int, k: int) -> float:
+    """Several cases for computing the BER lower bound"""
     if value <= 1e-10:
         return 0.0
 
@@ -63,62 +104,45 @@ def _knn_lowerbound(value: float, classes: int, k: int) -> float:
     return ((classes - 1) / classes) * (1 - np.sqrt(max(0, 1 - ((classes / (classes - 1)) * value))))
 
 
-
-_FUNCTION = Callable[[np.ndarray, np.ndarray, int], Tuple[float, float]]
+BER_FN_MAP = {"KNN": ber_knn, "MST": ber_mst}
 
 
-
+def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
     """
     An estimator for Multi-class Bayes Error Rate using FR or KNN test statistic basis
 
     Parameters
     ----------
-
+    images : ArrayLike (N, ... )
         Array of images or image embeddings
-    labels :
+    labels : ArrayLike (N, 1)
        Array of labels for each image or image embedding
-    method : Literal["MST", "KNN"], default "KNN"
-        Method to use when estimating the Bayes error rate
     k : int, default 1
-
+        Number of nearest neighbors for KNN estimator -- ignored by MST estimator
+    method : Literal["KNN", "MST"], default "KNN"
+        Method to use when estimating the Bayes error rate
 
+    Returns
+    -------
+    BEROutput
+        The upper and lower bounds of the Bayes Error Rate
 
-
+    References
+    ----------
+    [1] `Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4) <https://arxiv.org/abs/1811.06419>`_
+
+    Examples
     --------
-
+    >>> import sklearn.datasets as dsets
+    >>> from dataeval.metrics import ber
 
-
+    >>> images, labels = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)
 
-
-
-
-
-
-
-
-
-        cls,
-    ) -> Dict[str, _FUNCTION]:
-        return {"MST": _mst, "KNN": _knn}
-
-    def evaluate(self) -> Dict[str, float]:
-        """
-        Calculates the Bayes Error Rate estimate using the provided method
-
-        Returns
-        -------
-        Dict[str, float]
-            ber : float
-                The estimated lower bounds of the Bayes Error Rate
-            ber_lower : float
-                The estimated upper bounds of the Bayes Error Rate
-
-        Raises
-        ------
-        ValueError
-            If unique classes M < 2
-        """
-        data = np.asarray(self.data)
-        labels = np.asarray(self.labels)
-        upper, lower = self._method(data, labels, self.k)
-        return {"ber": upper, "ber_lower": lower}
+    >>> ber(images, labels)
+    BEROutput(ber=0.04, ber_lower=0.020416847668728033)
+    """
+    ber_fn = get_method(BER_FN_MAP, method)
+    X = to_numpy(images)
+    y = to_numpy(labels)
+    upper, lower = ber_fn(X, y, k) if method == "KNN" else ber_fn(X, y)
    return BEROutput(upper, lower)
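For reference, the new functional ber API can be exercised directly with the docstring's make_blobs data. The sketch below restates that example and adds the MST variant; the sklearn blobs are illustrative data, not a dataeval requirement.

    # Usage sketch based on the docstring example above.
    import sklearn.datasets as dsets
    from dataeval.metrics import ber

    images, labels = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)

    knn_out = ber(images, labels, k=1, method="KNN")  # BEROutput(ber=..., ber_lower=...)
    mst_out = ber(images, labels, method="MST")       # k is ignored by the MST estimator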
dataeval/_internal/metrics/coverage.py
@@ -1,80 +1,105 @@
 import math
-from typing import Literal,
+from typing import Literal, NamedTuple
 
 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from scipy.spatial.distance import pdist, squareform
 
+from dataeval._internal.interop import to_numpy
 
-
+
+class CoverageOutput(NamedTuple):
+    """
+    Attributes
+    ----------
+    indices : np.ndarray
+        Array of uncovered indices
+    radii : np.ndarray
+        Array of critical value radii
+    critical_value : float
+        Radius for coverage
     """
-    Class for evaluating coverage and identifying images/samples that are in undercovered regions.
 
-
+    indices: NDArray[np.intp]
+    radii: NDArray[np.float64]
+    critical_value: float
+
+
+def coverage(
+    embeddings: ArrayLike,
+    radius_type: Literal["adaptive", "naive"] = "adaptive",
+    k: int = 20,
+    percent: np.float64 = np.float64(0.01),
+) -> CoverageOutput:
+    """
+    Class for evaluating coverage and identifying images/samples that are in undercovered regions.
 
     Parameters
     ----------
-    embeddings :
-
+    embeddings : ArrayLike, shape - (N, P)
+        A dataset in an ArrayLike format.
+        Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
     radius_type : Literal["adaptive", "naive"], default "adaptive"
        The function used to determine radius.
     k: int, default 20
        Number of observations required in order to be covered.
+        [1] suggests that a minimum of 20-50 samples is necessary.
     percent: np.float64, default np.float(0.01)
        Percent of observations to be considered uncovered. Only applies to adaptive radius.
 
+    Returns
+    -------
+    CoverageOutput
+        Array of uncovered indices, critical value radii, and the radius for coverage
+
+    Raises
+    ------
+    ValueError
+        If length of embeddings is less than or equal to k
+    ValueError
+        If radius_type is unknown
+
     Note
     ----
     Embeddings should be on the unit interval.
-    """
 
-
-
-
-
-
-
-
-
-
-
-
+    Example
+    -------
+    >>> coverage(embeddings)
+    CoverageOutput(indices=array([], dtype=int64), radii=array([0.59307666, 0.56956307, 0.56328616, 0.70660265, 0.57778087,
+           0.53738624, 0.58968217, 1.27721334, 0.84378694, 0.67767021,
+           0.69680335, 1.35532621, 0.59764166, 0.8691945 , 0.83627602,
+           0.84187303, 0.62212358, 1.09039732, 0.67956797, 0.60134383,
+           0.83713908, 0.91784263, 1.12901193, 0.73907618, 0.63943983,
+           0.61188447, 0.47872713, 0.57207771, 0.92885883, 0.54750511,
+           0.83015726, 1.20721778, 0.50421928, 0.98312246, 0.59764166,
+           0.61009202, 0.73864073, 1.0381061 , 0.77598609, 0.72984036,
+           0.67573006, 0.48056064, 1.00050879, 0.89532971, 0.58395529,
+           0.95954793, 0.60134383, 1.10096454, 0.51955314, 0.73038702]), critical_value=0)
 
-
-
-
-
-
-    Returns
-    -------
-    np.ndarray
-        Array of uncovered indices
-    np.ndarray
-        Array of critical value radii
-
-    Raises
-    ------
-    ValueError
-        If length of embeddings is less than or equal to k
-    ValueError
-        If radius_type is unknown
-    """
+    Reference
+    ---------
+    This implementation is based on https://dl.acm.org/doi/abs/10.1145/3448016.3457315.
+    [1] Seymour Sudman. 1976. Applied sampling. Academic Press New York (1976).
+    """ # noqa: E501
 
-
-
-
-
-
-
-
+    # Calculate distance matrix, look at the (k+1)th farthest neighbor for each image.
+    embeddings = to_numpy(embeddings)
+    n = len(embeddings)
+    if n <= k:
+        raise ValueError("Number of observations less than or equal to the specified number of neighbors.")
+    mat = squareform(pdist(embeddings)).astype(np.float64)
+    sorted_dists = np.sort(mat, axis=1)
+    crit = sorted_dists[:, k + 1]
 
-
-
-
-
-
-
-
-
-
-
-
+    d = np.shape(embeddings)[1]
+    if radius_type == "naive":
+        rho = (1 / math.sqrt(math.pi)) * ((2 * k * math.gamma(d / 2 + 1)) / (n)) ** (1 / d)
+        pvals = np.where(crit > rho)[0]
+    elif radius_type == "adaptive":
+        # Use data adaptive cutoff as rho
+        rho = int(n * percent)
+        pvals = np.argsort(crit)[::-1][:rho]
+    else:
+        raise ValueError("Invalid radius type.")
+    return CoverageOutput(pvals, crit, rho)
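A usage sketch for the new coverage function follows; the random embeddings are illustrative, kept on the unit interval as the Note requires, and it is assumed that coverage is exported from dataeval.metrics alongside ber.

    # Usage sketch; the data and the import path are assumptions for illustration.
    import numpy as np
    from dataeval.metrics import coverage

    rng = np.random.default_rng(0)
    embeddings = rng.uniform(0.0, 1.0, size=(50, 16))  # N=50 observations in a P=16 dimensional space

    result = coverage(embeddings, radius_type="adaptive", k=20, percent=np.float64(0.01))
    uncovered_indices, radii, critical_value = result  # CoverageOutput fields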
dataeval/_internal/metrics/divergence.py
@@ -3,49 +3,69 @@ This module contains the implementation of HP Divergence
 using the Fast Nearest Neighbor and Minimum Spanning Tree algorithms
 """
 
-from typing import
+from typing import Literal, NamedTuple
 
 import numpy as np
+from numpy.typing import ArrayLike
 
-from dataeval._internal.
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.metrics.utils import compute_neighbors, get_method, minimum_spanning_tree
 
-from .utils import compute_neighbors, minimum_spanning_tree
 
+class DivergenceOutput(NamedTuple):
+    """
+    Attributes
+    ----------
+    divergence : float
+        Divergence value calculated between 2 datasets ranging between 0.0 and 1.0
+    errors : int
+        The number of differing edges between the datasets
+    """
+
+    divergence: float
+    errors: int
 
-
+
+def divergence_mst(data: np.ndarray, labels: np.ndarray) -> int:
     mst = minimum_spanning_tree(data).toarray()
     edgelist = np.transpose(np.nonzero(mst))
     errors = np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
     return errors
 
 
-def
+def divergence_fnn(data: np.ndarray, labels: np.ndarray) -> int:
     nn_indices = compute_neighbors(data, data)
     errors = np.sum(np.abs(labels[nn_indices] - labels))
     return errors
 
 
-
-_FUNCTION = Callable[[np.ndarray, np.ndarray], int]
+DIVERGENCE_FN_MAP = {"FNN": divergence_fnn, "MST": divergence_mst}
 
 
-
+def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST"] = "FNN") -> DivergenceOutput:
     """
-    Calculates the
+    Calculates the divergence and any errors between the datasets
 
     Parameters
     ----------
-    data_a :
-
-
-
-
+    data_a : ArrayLike, shape - (N, P)
+        A dataset in an ArrayLike format to compare.
+        Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
+    data_b : ArrayLike, shape - (N, P)
+        A dataset in an ArrayLike format to compare.
+        Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
+    method : Literal["MST, "FNN"], default "FNN"
        Method used to estimate dataset divergence
 
-
-
-
-    and
+    Returns
+    -------
+    DivergenceOutput
+        The divergence value (0.0..1.0) and the number of differing edges between the datasets
+
+    Notes
+    -----
+    The divergence value indicates how similar the 2 datasets are
+    with 0 indicating approximately identical data distributions.
 
     Warning
     -------
@@ -55,40 +75,28 @@ class Divergence(EvaluateMixin, MethodsMixin[_METHODS, _FUNCTION]):
     Source of slowdown:
     conversion to and from CSR format adds ~10% of the time diff between
     1nn and scipy mst function the remaining 90%
-    """
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            the number of differing edges
-        """
-        N = self.data_a.shape[0]
-        M = self.data_b.shape[0]
-
-        stacked_data = np.vstack((self.data_a, self.data_b))
-        labels = np.vstack([np.zeros([N, 1]), np.ones([M, 1])])
-
-        errors = self._method(stacked_data, labels)
-        dp = max(0.0, 1 - ((M + N) / (2 * M * N)) * errors)
-        return {"divergence": dp, "error": errors}
+    References
+    ----------
+    For more information about this divergence, its formal definition,
+    and its associated estimators see https://arxiv.org/abs/1412.6534.
+
+    Examples
+    --------
+    Evaluate the datasets:
+
+    >>> divergence(datasetA, datasetB)
+    DivergenceOutput(divergence=0.28, errors=36.0)
+    """
+    div_fn = get_method(DIVERGENCE_FN_MAP, method)
+    a = to_numpy(data_a)
+    b = to_numpy(data_b)
+    N = a.shape[0]
+    M = b.shape[0]
+
+    stacked_data = np.vstack((a, b))
+    labels = np.vstack([np.zeros([N, 1]), np.ones([M, 1])])
+
+    errors = div_fn(stacked_data, labels)
+    dp = max(0.0, 1 - ((M + N) / (2 * M * N)) * errors)
+    return DivergenceOutput(dp, errors)
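A usage sketch for the new divergence function; the two blob datasets are illustrative only, and the import path is assumed to mirror the ber example above.

    # Usage sketch; the datasets and import path are assumptions for illustration.
    import sklearn.datasets as dsets
    from dataeval.metrics import divergence

    dataset_a, _ = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)
    dataset_b, _ = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=1)

    fnn_out = divergence(dataset_a, dataset_b, method="FNN")  # DivergenceOutput(divergence=..., errors=...)
    mst_out = divergence(dataset_a, dataset_b, method="MST")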