dataeval 0.61.0__py3-none-any.whl → 0.64.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. dataeval/__init__.py +3 -3
  2. dataeval/_internal/detectors/clusterer.py +45 -16
  3. dataeval/_internal/detectors/drift/base.py +15 -12
  4. dataeval/_internal/detectors/drift/cvm.py +12 -8
  5. dataeval/_internal/detectors/drift/ks.py +7 -3
  6. dataeval/_internal/detectors/drift/mmd.py +15 -12
  7. dataeval/_internal/detectors/drift/uncertainty.py +6 -5
  8. dataeval/_internal/detectors/duplicates.py +35 -11
  9. dataeval/_internal/detectors/linter.py +85 -16
  10. dataeval/_internal/detectors/ood/ae.py +7 -5
  11. dataeval/_internal/detectors/ood/aegmm.py +6 -5
  12. dataeval/_internal/detectors/ood/base.py +15 -13
  13. dataeval/_internal/detectors/ood/llr.py +8 -5
  14. dataeval/_internal/detectors/ood/vae.py +6 -4
  15. dataeval/_internal/detectors/ood/vaegmm.py +6 -4
  16. dataeval/_internal/interop.py +43 -0
  17. dataeval/_internal/metrics/balance.py +180 -0
  18. dataeval/_internal/metrics/base.py +2 -84
  19. dataeval/_internal/metrics/ber.py +77 -53
  20. dataeval/_internal/metrics/coverage.py +80 -55
  21. dataeval/_internal/metrics/divergence.py +62 -54
  22. dataeval/_internal/metrics/diversity.py +206 -0
  23. dataeval/_internal/metrics/parity.py +292 -163
  24. dataeval/_internal/metrics/stats.py +48 -35
  25. dataeval/_internal/metrics/uap.py +31 -26
  26. dataeval/_internal/metrics/utils.py +237 -2
  27. dataeval/_internal/utils.py +64 -0
  28. dataeval/_internal/workflows/__init__.py +0 -0
  29. dataeval/metrics/__init__.py +25 -5
  30. dataeval/utils/__init__.py +9 -0
  31. {dataeval-0.61.0.dist-info → dataeval-0.64.0.dist-info}/METADATA +1 -2
  32. dataeval-0.64.0.dist-info/RECORD +60 -0
  33. dataeval/_internal/metrics/hash.py +0 -79
  34. dataeval-0.61.0.dist-info/RECORD +0 -55
  35. {dataeval-0.61.0.dist-info → dataeval-0.64.0.dist-info}/LICENSE.txt +0 -0
  36. {dataeval-0.61.0.dist-info → dataeval-0.64.0.dist-info}/WHEEL +0 -0
@@ -1,92 +1,10 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Callable, Dict, Generic, List, TypeVar
2
+ from typing import Generic, TypeVar
3
3
 
4
4
  TOutput = TypeVar("TOutput", bound=dict)
5
- TMethods = TypeVar("TMethods")
6
- TCallable = TypeVar("TCallable", bound=Callable)
7
-
8
-
9
- class MetricMixin(ABC, Generic[TOutput]):
10
- @abstractmethod
11
- def update(self, preds, targets): ...
12
-
13
- @abstractmethod
14
- def compute(self) -> TOutput: ...
15
-
16
- @abstractmethod
17
- def reset(self): ...
18
5
 
19
6
 
20
7
  class EvaluateMixin(ABC, Generic[TOutput]):
21
8
  @abstractmethod
22
- def evaluate(self) -> TOutput:
9
+ def evaluate(self, *args, **kwargs) -> TOutput:
23
10
  """Abstract method to calculate metric based off of constructor parameters"""
24
-
25
-
26
- class MethodsMixin(ABC, Generic[TMethods, TCallable]):
27
- """
28
- Use this mixin to define a mapping of functions to method names which
29
- can be queried by the user and called internally with the appropriate
30
- method name as the key.
31
-
32
- Explicitly defining the Callable generic helps with type safety and
33
- hinting for function signatures and recommended but optional.
34
-
35
- e.g.:
36
-
37
- def _mult(x: float, y: float) -> float:
38
- return x * y
39
-
40
- class MyMetric(MethodsMixin[Callable[float, float], float]):
41
-
42
- def _methods(cls) -> Dict[str, Callable[float, float], float]:
43
- return {
44
- "ADD": lambda x, y: x + y,
45
- "MULT": _mult,
46
- ...
47
- }
48
-
49
- Then during evaluate, you can call the method specified with the getter.
50
-
51
- e.g.:
52
-
53
- def evaluate(self):
54
- return self._method(x, y)
55
-
56
- The resulting class can be used like so.
57
-
58
- m = MyMetric(1.0, 2.0, "ADD")
59
- m.evaluate() # returns 3.0
60
- m.method # returns "ADD"
61
- MyMetric.methods() # returns "['ADD', 'MULT']
62
- m.method = "MULT"
63
- m.evaluate() # returns 2.0
64
- """
65
-
66
- @classmethod
67
- @abstractmethod
68
- def _methods(cls) -> Dict[str, TCallable]:
69
- """Abstract method returning available method functions for class"""
70
-
71
- @property
72
- def _method(self) -> TCallable:
73
- return self._methods()[self.method]
74
-
75
- @classmethod
76
- def methods(cls) -> List[str]:
77
- return list(cls._methods().keys())
78
-
79
- @property
80
- def method(self) -> str:
81
- return self._method_key
82
-
83
- @method.setter
84
- def method(self, value: TMethods):
85
- self._set_method(value)
86
-
87
- def _set_method(self, value: TMethods):
88
- """This setter is to fix pyright incorrect detection of
89
- incorrectly overriding the 'method' property"""
90
- if value not in self.methods():
91
- raise KeyError(f"Specified method not available for class ({self.methods()}).")
92
- self._method_key = value
@@ -7,19 +7,46 @@ Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4)
7
7
  https://arxiv.org/abs/1811.06419
8
8
  """
9
9
 
10
- from typing import Callable, Dict, Literal, Tuple
10
+ from typing import Literal, NamedTuple, Tuple
11
11
 
12
12
  import numpy as np
13
- from maite.protocols import ArrayLike
13
+ from numpy.typing import ArrayLike, NDArray
14
14
  from scipy.sparse import coo_matrix
15
15
  from scipy.stats import mode
16
16
 
17
- from dataeval._internal.metrics.base import EvaluateMixin, MethodsMixin
17
+ from dataeval._internal.interop import to_numpy
18
+ from dataeval._internal.metrics.utils import compute_neighbors, get_classes_counts, get_method, minimum_spanning_tree
18
19
 
19
- from .utils import compute_neighbors, get_classes_counts, minimum_spanning_tree
20
+
21
+ class BEROutput(NamedTuple):
22
+ """
23
+ Attributes
24
+ ----------
25
+ ber : float
26
+ The upper bounds of the Bayes Error Rate
27
+ ber_lower : float
28
+ The lower bounds of the Bayes Error Rate
29
+ """
30
+
31
+ ber: float
32
+ ber_lower: float
20
33
 
21
34
 
22
- def _mst(X: np.ndarray, y: np.ndarray, _: int) -> Tuple[float, float]:
35
+ def ber_mst(X: NDArray, y: NDArray) -> Tuple[float, float]:
36
+ """Calculates the Bayes Error Rate using a minimum spanning tree
37
+
38
+ Parameters
39
+ ----------
40
+ X : NDArray, shape - (N, ... )
41
+ n_samples containing n_features
42
+ y : NDArray, shape - (N, 1)
43
+ Labels corresponding to each sample
44
+
45
+ Returns
46
+ -------
47
+ Tuple[float, float]
48
+ The upper and lower bounds of the bayes error rate
49
+ """
23
50
  M, N = get_classes_counts(y)
24
51
 
25
52
  tree = coo_matrix(minimum_spanning_tree(X))
@@ -30,7 +57,21 @@ def _mst(X: np.ndarray, y: np.ndarray, _: int) -> Tuple[float, float]:
30
57
  return upper, lower
31
58
 
32
59
 
33
- def _knn(X: np.ndarray, y: np.ndarray, k: int) -> Tuple[float, float]:
60
+ def ber_knn(X: NDArray, y: NDArray, k: int) -> Tuple[float, float]:
61
+ """Calculates the Bayes Error Rate using K-nearest neighbors
62
+
63
+ Parameters
64
+ ----------
65
+ X : NDArray, shape - (N, ... )
66
+ n_samples containing n_features
67
+ y : NDArray, shape - (N, 1)
68
+ Labels corresponding to each sample
69
+
70
+ Returns
71
+ -------
72
+ Tuple[float, float]
73
+ The upper and lower bounds of the bayes error rate
74
+ """
34
75
  M, N = get_classes_counts(y)
35
76
 
36
77
  # All features belong on second dimension
@@ -39,12 +80,12 @@ def _knn(X: np.ndarray, y: np.ndarray, k: int) -> Tuple[float, float]:
39
80
  nn_indices = np.expand_dims(nn_indices, axis=1) if nn_indices.ndim == 1 else nn_indices
40
81
  modal_class = mode(y[nn_indices], axis=1, keepdims=True).mode.squeeze()
41
82
  upper = float(np.count_nonzero(modal_class - y) / N)
42
- lower = _knn_lowerbound(upper, M, k)
83
+ lower = knn_lowerbound(upper, M, k)
43
84
  return upper, lower
44
85
 
45
86
 
46
- def _knn_lowerbound(value: float, classes: int, k: int) -> float:
47
- "Several cases for computing the BER lower bound"
87
+ def knn_lowerbound(value: float, classes: int, k: int) -> float:
88
+ """Several cases for computing the BER lower bound"""
48
89
  if value <= 1e-10:
49
90
  return 0.0
50
91
 
@@ -63,62 +104,45 @@ def _knn_lowerbound(value: float, classes: int, k: int) -> float:
63
104
  return ((classes - 1) / classes) * (1 - np.sqrt(max(0, 1 - ((classes / (classes - 1)) * value))))
64
105
 
65
106
 
66
- _METHODS = Literal["MST", "KNN"]
67
- _FUNCTION = Callable[[np.ndarray, np.ndarray, int], Tuple[float, float]]
107
+ BER_FN_MAP = {"KNN": ber_knn, "MST": ber_mst}
68
108
 
69
109
 
70
- class BER(EvaluateMixin, MethodsMixin[_METHODS, _FUNCTION]):
110
+ def ber(images: ArrayLike, labels: ArrayLike, k: int = 1, method: Literal["KNN", "MST"] = "KNN") -> BEROutput:
71
111
  """
72
112
  An estimator for Multi-class Bayes Error Rate using FR or KNN test statistic basis
73
113
 
74
114
  Parameters
75
115
  ----------
76
- data : np.ndarray
116
+ images : ArrayLike (N, ... )
77
117
  Array of images or image embeddings
78
- labels : np.ndarray
118
+ labels : ArrayLike (N, 1)
79
119
  Array of labels for each image or image embedding
80
- method : Literal["MST", "KNN"], default "KNN"
81
- Method to use when estimating the Bayes error rate
82
120
  k : int, default 1
83
- number of nearest neighbors for KNN estimator -- ignored by MST estimator
121
+ Number of nearest neighbors for KNN estimator -- ignored by MST estimator
122
+ method : Literal["KNN", "MST"], default "KNN"
123
+ Method to use when estimating the Bayes error rate
84
124
 
125
+ Returns
126
+ -------
127
+ BEROutput
128
+ The upper and lower bounds of the Bayes Error Rate
85
129
 
86
- See Also
130
+ References
131
+ ----------
132
+ [1] `Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4) <https://arxiv.org/abs/1811.06419>`_
133
+
134
+ Examples
87
135
  --------
88
- `Learning to Bound the Multi-class Bayes Error (Th. 3 and Th. 4) <https://arxiv.org/abs/1811.06419>`_
136
+ >>> import sklearn.datasets as dsets
137
+ >>> from dataeval.metrics import ber
89
138
 
90
- """
139
+ >>> images, labels = dsets.make_blobs(n_samples=50, centers=2, n_features=2, random_state=0)
91
140
 
92
- def __init__(self, data: ArrayLike, labels: ArrayLike, method: _METHODS = "KNN", k: int = 1) -> None:
93
- self.data = data
94
- self.labels = labels
95
- self.k = k
96
- self._set_method(method)
97
-
98
- @classmethod
99
- def _methods(
100
- cls,
101
- ) -> Dict[str, _FUNCTION]:
102
- return {"MST": _mst, "KNN": _knn}
103
-
104
- def evaluate(self) -> Dict[str, float]:
105
- """
106
- Calculates the Bayes Error Rate estimate using the provided method
107
-
108
- Returns
109
- -------
110
- Dict[str, float]
111
- ber : float
112
- The estimated lower bounds of the Bayes Error Rate
113
- ber_lower : float
114
- The estimated upper bounds of the Bayes Error Rate
115
-
116
- Raises
117
- ------
118
- ValueError
119
- If unique classes M < 2
120
- """
121
- data = np.asarray(self.data)
122
- labels = np.asarray(self.labels)
123
- upper, lower = self._method(data, labels, self.k)
124
- return {"ber": upper, "ber_lower": lower}
141
+ >>> ber(images, labels)
142
+ BEROutput(ber=0.04, ber_lower=0.020416847668728033)
143
+ """
144
+ ber_fn = get_method(BER_FN_MAP, method)
145
+ X = to_numpy(images)
146
+ y = to_numpy(labels)
147
+ upper, lower = ber_fn(X, y, k) if method == "KNN" else ber_fn(X, y)
148
+ return BEROutput(upper, lower)
@@ -1,80 +1,105 @@
1
1
  import math
2
- from typing import Literal, Tuple
2
+ from typing import Literal, NamedTuple
3
3
 
4
4
  import numpy as np
5
+ from numpy.typing import ArrayLike, NDArray
5
6
  from scipy.spatial.distance import pdist, squareform
6
7
 
8
+ from dataeval._internal.interop import to_numpy
7
9
 
8
- class Coverage:
10
+
11
+ class CoverageOutput(NamedTuple):
12
+ """
13
+ Attributes
14
+ ----------
15
+ indices : np.ndarray
16
+ Array of uncovered indices
17
+ radii : np.ndarray
18
+ Array of critical value radii
19
+ critical_value : float
20
+ Radius for coverage
9
21
  """
10
- Class for evaluating coverage and identifying images/samples that are in undercovered regions.
11
22
 
12
- This implementation is based on https://dl.acm.org/doi/abs/10.1145/3448016.3457315.
23
+ indices: NDArray[np.intp]
24
+ radii: NDArray[np.float64]
25
+ critical_value: float
26
+
27
+
28
+ def coverage(
29
+ embeddings: ArrayLike,
30
+ radius_type: Literal["adaptive", "naive"] = "adaptive",
31
+ k: int = 20,
32
+ percent: np.float64 = np.float64(0.01),
33
+ ) -> CoverageOutput:
34
+ """
35
+ Evaluates coverage and identifies images/samples that are in undercovered regions.
13
36
 
14
37
  Parameters
15
38
  ----------
16
- embeddings : np.ndarray
17
- n x p array of image embeddings from the dataset.
39
+ embeddings : ArrayLike, shape - (N, P)
40
+ A dataset in an ArrayLike format.
41
+ Function expects the data to have 2 dimensions, N number of observations in a P-dimensional space.
18
42
  radius_type : Literal["adaptive", "naive"], default "adaptive"
19
43
  The function used to determine radius.
20
44
  k: int, default 20
21
45
  Number of observations required in order to be covered.
46
+ [1] suggests that a minimum of 20-50 samples is necessary.
22
47
  percent: np.float64, default np.float64(0.01)
23
48
  Percent of observations to be considered uncovered. Only applies to adaptive radius.
24
49
 
50
+ Returns
51
+ -------
52
+ CoverageOutput
53
+ Array of uncovered indices, critical value radii, and the radius for coverage
54
+
55
+ Raises
56
+ ------
57
+ ValueError
58
+ If length of embeddings is less than or equal to k
59
+ ValueError
60
+ If radius_type is unknown
61
+
25
62
  Note
26
63
  ----
27
64
  Embeddings should be on the unit interval.
28
- """
29
65
 
30
- def __init__(
31
- self,
32
- embeddings: np.ndarray,
33
- radius_type: Literal["adaptive", "naive"] = "adaptive",
34
- k: int = 20,
35
- percent: np.float64 = np.float64(0.01),
36
- ):
37
- self.embeddings = embeddings
38
- self.radius_type = radius_type
39
- self.k = k
40
- self.percent = percent
66
+ Example
67
+ -------
68
+ >>> coverage(embeddings)
69
+ CoverageOutput(indices=array([], dtype=int64), radii=array([0.59307666, 0.56956307, 0.56328616, 0.70660265, 0.57778087,
70
+ 0.53738624, 0.58968217, 1.27721334, 0.84378694, 0.67767021,
71
+ 0.69680335, 1.35532621, 0.59764166, 0.8691945 , 0.83627602,
72
+ 0.84187303, 0.62212358, 1.09039732, 0.67956797, 0.60134383,
73
+ 0.83713908, 0.91784263, 1.12901193, 0.73907618, 0.63943983,
74
+ 0.61188447, 0.47872713, 0.57207771, 0.92885883, 0.54750511,
75
+ 0.83015726, 1.20721778, 0.50421928, 0.98312246, 0.59764166,
76
+ 0.61009202, 0.73864073, 1.0381061 , 0.77598609, 0.72984036,
77
+ 0.67573006, 0.48056064, 1.00050879, 0.89532971, 0.58395529,
78
+ 0.95954793, 0.60134383, 1.10096454, 0.51955314, 0.73038702]), critical_value=0)
41
79
 
42
- def evaluate(self) -> Tuple[np.ndarray, np.ndarray]:
43
- """
44
- Perform a one-way chi-squared test between observation frequencies and expected frequencies that
45
- tests the null hypothesis that the observed data has the expected frequencies.
46
-
47
- Returns
48
- -------
49
- np.ndarray
50
- Array of uncovered indices
51
- np.ndarray
52
- Array of critical value radii
53
-
54
- Raises
55
- ------
56
- ValueError
57
- If length of embeddings is less than or equal to k
58
- ValueError
59
- If radius_type is unknown
60
- """
80
+ Reference
81
+ ---------
82
+ This implementation is based on https://dl.acm.org/doi/abs/10.1145/3448016.3457315.
83
+ [1] Seymour Sudman. 1976. Applied sampling. Academic Press New York (1976).
84
+ """ # noqa: E501
61
85
 
62
- # Calculate distance matrix, look at the (k+1)th farthest neighbor for each image.
63
- n = len(self.embeddings)
64
- if n <= self.k:
65
- raise ValueError("Number of observations less than or equal to the specified number of neighbors.")
66
- mat = squareform(pdist(self.embeddings))
67
- sorted_dists = np.sort(mat, axis=1)
68
- crit = sorted_dists[:, self.k + 1]
86
+ # Calculate distance matrix, look at the (k+1)th farthest neighbor for each image.
87
+ embeddings = to_numpy(embeddings)
88
+ n = len(embeddings)
89
+ if n <= k:
90
+ raise ValueError("Number of observations less than or equal to the specified number of neighbors.")
91
+ mat = squareform(pdist(embeddings)).astype(np.float64)
92
+ sorted_dists = np.sort(mat, axis=1)
93
+ crit = sorted_dists[:, k + 1]
69
94
 
70
- d = np.shape(self.embeddings)[1]
71
- if self.radius_type == "naive":
72
- self.rho = (1 / math.sqrt(math.pi)) * ((2 * self.k * math.gamma(d / 2 + 1)) / (n)) ** (1 / d)
73
- pvals = np.where(crit > self.rho)[0]
74
- elif self.radius_type == "adaptive":
75
- # Use data adaptive cutoff
76
- cutoff = int(n * self.percent)
77
- pvals = np.argsort(crit)[::-1][:cutoff]
78
- else:
79
- raise ValueError("Invalid radius type.")
80
- return pvals, crit
95
+ d = np.shape(embeddings)[1]
96
+ if radius_type == "naive":
97
+ rho = (1 / math.sqrt(math.pi)) * ((2 * k * math.gamma(d / 2 + 1)) / (n)) ** (1 / d)
98
+ pvals = np.where(crit > rho)[0]
99
+ elif radius_type == "adaptive":
100
+ # Use data adaptive cutoff as rho
101
+ rho = int(n * percent)
102
+ pvals = np.argsort(crit)[::-1][:rho]
103
+ else:
104
+ raise ValueError("Invalid radius type.")
105
+ return CoverageOutput(pvals, crit, rho)
@@ -3,49 +3,69 @@ This module contains the implementation of HP Divergence
3
3
  using the Fast Nearest Neighbor and Minimum Spanning Tree algorithms
4
4
  """
5
5
 
6
- from typing import Any, Callable, Dict, Literal
6
+ from typing import Literal, NamedTuple
7
7
 
8
8
  import numpy as np
9
+ from numpy.typing import ArrayLike
9
10
 
10
- from dataeval._internal.metrics.base import EvaluateMixin, MethodsMixin
11
+ from dataeval._internal.interop import to_numpy
12
+ from dataeval._internal.metrics.utils import compute_neighbors, get_method, minimum_spanning_tree
11
13
 
12
- from .utils import compute_neighbors, minimum_spanning_tree
13
14
 
15
+ class DivergenceOutput(NamedTuple):
16
+ """
17
+ Attributes
18
+ ----------
19
+ divergence : float
20
+ Divergence value calculated between 2 datasets ranging between 0.0 and 1.0
21
+ errors : int
22
+ The number of differing edges between the datasets
23
+ """
24
+
25
+ divergence: float
26
+ errors: int
14
27
 
15
- def _mst(data: np.ndarray, labels: np.ndarray) -> int:
28
+
29
+ def divergence_mst(data: np.ndarray, labels: np.ndarray) -> int:
16
30
  mst = minimum_spanning_tree(data).toarray()
17
31
  edgelist = np.transpose(np.nonzero(mst))
18
32
  errors = np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
19
33
  return errors
20
34
 
21
35
 
22
- def _fnn(data: np.ndarray, labels: np.ndarray) -> int:
36
+ def divergence_fnn(data: np.ndarray, labels: np.ndarray) -> int:
23
37
  nn_indices = compute_neighbors(data, data)
24
38
  errors = np.sum(np.abs(labels[nn_indices] - labels))
25
39
  return errors
26
40
 
27
41
 
28
- _METHODS = Literal["MST", "FNN"]
29
- _FUNCTION = Callable[[np.ndarray, np.ndarray], int]
42
+ DIVERGENCE_FN_MAP = {"FNN": divergence_fnn, "MST": divergence_mst}
30
43
 
31
44
 
32
- class Divergence(EvaluateMixin, MethodsMixin[_METHODS, _FUNCTION]):
45
+ def divergence(data_a: ArrayLike, data_b: ArrayLike, method: Literal["FNN", "MST"] = "FNN") -> DivergenceOutput:
33
46
  """
34
- Calculates the estimated divergence between two datasets
47
+ Calculates the divergence and any errors between the datasets
35
48
 
36
49
  Parameters
37
50
  ----------
38
- data_a : np.ndarray
39
- Array of images or image embeddings to compare
40
- data_b : np.ndarray
41
- Array of images or image embeddings to compare
42
- method : Literal["MST, "FNN"], default "MST"
51
+ data_a : ArrayLike, shape - (N, P)
52
+ A dataset in an ArrayLike format to compare.
53
+ Function expects the data to have 2 dimensions, N number of observations in a P-dimesionial space.
54
+ data_b : ArrayLike, shape - (N, P)
55
+ A dataset in an ArrayLike format to compare.
56
+ Function expects the data to have 2 dimensions, N number of observations in a P-dimensional space.
57
+ method : Literal["MST", "FNN"], default "FNN"
43
58
  Method used to estimate dataset divergence
44
59
 
45
- See Also
46
- --------
47
- For more information about this divergence, its formal definition,
48
- and its associated estimators see https://arxiv.org/abs/1412.6534.
60
+ Returns
61
+ -------
62
+ DivergenceOutput
63
+ The divergence value (0.0..1.0) and the number of differing edges between the datasets
64
+
65
+ Notes
66
+ -----
67
+ The divergence value indicates how similar the 2 datasets are
68
+ with 0 indicating approximately identical data distributions.
49
69
 
50
70
  Warning
51
71
  -------
@@ -55,40 +75,28 @@ class Divergence(EvaluateMixin, MethodsMixin[_METHODS, _FUNCTION]):
55
75
  Source of slowdown:
56
76
  conversion to and from CSR format adds ~10% of the time diff between
57
77
  1nn and scipy mst function the remaining 90%
58
- """
59
78
 
60
- def __init__(
61
- self,
62
- data_a: np.ndarray,
63
- data_b: np.ndarray,
64
- method: _METHODS = "MST",
65
- ) -> None:
66
- self.data_a = data_a
67
- self.data_b = data_b
68
- self._set_method(method)
69
-
70
- @classmethod
71
- def _methods(cls) -> Dict[str, _FUNCTION]:
72
- return {"FNN": _fnn, "MST": _mst}
73
-
74
- def evaluate(self) -> Dict[str, Any]:
75
- """
76
- Calculates the divergence and any errors between the datasets
77
-
78
- Returns
79
- -------
80
- Dict[str, Any]
81
- dp : float
82
- divergence value between 0.0 and 1.0
83
- errors : int
84
- the number of differing edges
85
- """
86
- N = self.data_a.shape[0]
87
- M = self.data_b.shape[0]
88
-
89
- stacked_data = np.vstack((self.data_a, self.data_b))
90
- labels = np.vstack([np.zeros([N, 1]), np.ones([M, 1])])
91
-
92
- errors = self._method(stacked_data, labels)
93
- dp = max(0.0, 1 - ((M + N) / (2 * M * N)) * errors)
94
- return {"divergence": dp, "error": errors}
79
+ References
80
+ ----------
81
+ For more information about this divergence, its formal definition,
82
+ and its associated estimators see https://arxiv.org/abs/1412.6534.
83
+
84
+ Examples
85
+ --------
86
+ Evaluate the datasets:
87
+
88
+ >>> divergence(datasetA, datasetB)
89
+ DivergenceOutput(divergence=0.28, errors=36.0)
90
+ """
91
+ div_fn = get_method(DIVERGENCE_FN_MAP, method)
92
+ a = to_numpy(data_a)
93
+ b = to_numpy(data_b)
94
+ N = a.shape[0]
95
+ M = b.shape[0]
96
+
97
+ stacked_data = np.vstack((a, b))
98
+ labels = np.vstack([np.zeros([N, 1]), np.ones([M, 1])])
99
+
100
+ errors = div_fn(stacked_data, labels)
101
+ dp = max(0.0, 1 - ((M + N) / (2 * M * N)) * errors)
102
+ return DivergenceOutput(dp, errors)