dataeval 0.63.0__py3-none-any.whl → 0.64.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/_internal/detectors/clusterer.py +2 -1
- dataeval/_internal/detectors/drift/base.py +2 -1
- dataeval/_internal/detectors/drift/cvm.py +2 -1
- dataeval/_internal/detectors/drift/ks.py +2 -1
- dataeval/_internal/detectors/drift/mmd.py +4 -3
- dataeval/_internal/detectors/drift/uncertainty.py +1 -2
- dataeval/_internal/detectors/duplicates.py +2 -1
- dataeval/_internal/detectors/linter.py +1 -1
- dataeval/_internal/detectors/ood/ae.py +2 -1
- dataeval/_internal/detectors/ood/aegmm.py +2 -1
- dataeval/_internal/detectors/ood/base.py +2 -1
- dataeval/_internal/detectors/ood/llr.py +3 -2
- dataeval/_internal/detectors/ood/vae.py +2 -1
- dataeval/_internal/detectors/ood/vaegmm.py +2 -1
- dataeval/_internal/interop.py +2 -11
- dataeval/_internal/metrics/balance.py +180 -0
- dataeval/_internal/metrics/base.py +1 -83
- dataeval/_internal/metrics/ber.py +122 -48
- dataeval/_internal/metrics/coverage.py +83 -74
- dataeval/_internal/metrics/divergence.py +67 -67
- dataeval/_internal/metrics/diversity.py +206 -0
- dataeval/_internal/metrics/parity.py +300 -155
- dataeval/_internal/metrics/stats.py +7 -5
- dataeval/_internal/metrics/uap.py +37 -29
- dataeval/_internal/metrics/utils.py +393 -0
- dataeval/_internal/utils.py +64 -0
- dataeval/metrics/__init__.py +25 -6
- dataeval/utils/__init__.py +9 -0
- {dataeval-0.63.0.dist-info → dataeval-0.64.0.dist-info}/METADATA +1 -1
- dataeval-0.64.0.dist-info/RECORD +60 -0
- dataeval/_internal/functional/__init__.py +0 -0
- dataeval/_internal/functional/ber.py +0 -63
- dataeval/_internal/functional/coverage.py +0 -75
- dataeval/_internal/functional/divergence.py +0 -16
- dataeval/_internal/functional/hash.py +0 -79
- dataeval/_internal/functional/metadata.py +0 -136
- dataeval/_internal/functional/metadataparity.py +0 -190
- dataeval/_internal/functional/uap.py +0 -6
- dataeval/_internal/functional/utils.py +0 -158
- dataeval/_internal/maite/__init__.py +0 -0
- dataeval/_internal/maite/utils.py +0 -30
- dataeval/_internal/metrics/metadata.py +0 -610
- dataeval/_internal/metrics/metadataparity.py +0 -67
- dataeval-0.63.0.dist-info/RECORD +0 -68
- {dataeval-0.63.0.dist-info → dataeval-0.64.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.63.0.dist-info → dataeval-0.64.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -2,14 +2,14 @@ from importlib.util import find_spec

 from . import detectors, flags, metrics

-__version__ = "0.63.0"
+__version__ = "0.64.0"

 __all__ = ["detectors", "flags", "metrics"]

 if find_spec("torch") is not None:  # pragma: no cover
-    from . import models, workflows
+    from . import models, utils, workflows

-    __all__ += ["models", "workflows"]
+    __all__ += ["models", "utils", "workflows"]

 elif find_spec("tensorflow") is not None:  # pragma: no cover
     from . import models

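The torch-gated block now also attaches the new utils subpackage. A minimal usage sketch of the change (it assumes torch is installed so the guard above runs; the contents of dataeval.utils are not shown in this diff):

    from importlib.util import find_spec

    import dataeval

    print(dataeval.__version__)  # "0.64.0"

    # "utils" is only attached when torch is importable, mirroring the guard above
    if find_spec("torch") is not None:
        from dataeval import utils  # new public subpackage in 0.64.0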
dataeval/_internal/detectors/clusterer.py
CHANGED
@@ -1,10 +1,11 @@
 from typing import Dict, Iterable, List, NamedTuple, Tuple, Union, cast

 import numpy as np
+from numpy.typing import ArrayLike
 from scipy.cluster.hierarchy import linkage
 from scipy.spatial.distance import pdist, squareform

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy


 def extend_linkage(link_arr: np.ndarray) -> np.ndarray:
dataeval/_internal/detectors/drift/base.py
CHANGED
@@ -11,8 +11,9 @@ from functools import wraps
 from typing import Callable, Dict, Literal, Optional, Tuple, Union

 import numpy as np
+from numpy.typing import ArrayLike

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy


 def update_x_ref(fn):
dataeval/_internal/detectors/drift/cvm.py
CHANGED
@@ -9,9 +9,10 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple

 import numpy as np
+from numpy.typing import ArrayLike
 from scipy.stats import cramervonmises_2samp

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy

 from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x

dataeval/_internal/detectors/drift/ks.py
CHANGED
@@ -9,9 +9,10 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple

 import numpy as np
+from numpy.typing import ArrayLike
 from scipy.stats import ks_2samp

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy

 from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x

dataeval/_internal/detectors/drift/mmd.py
CHANGED
@@ -9,8 +9,9 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Dict, Optional, Tuple, Union

 import torch
+from numpy.typing import ArrayLike

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy

 from .base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
 from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
@@ -74,7 +75,7 @@ class DriftMMD(BaseDrift):
         super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)

         self.infer_sigma = configure_kernel_from_x_ref
-        if configure_kernel_from_x_ref and isinstance(sigma, ArrayLike):
+        if configure_kernel_from_x_ref and sigma is not None:
             self.infer_sigma = False

         self.n_permutations = n_permutations  # nb of iterations through permutation test
@@ -83,7 +84,7 @@ class DriftMMD(BaseDrift):
         self.device = get_device(device)

         # initialize kernel
-        sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if isinstance(sigma, ArrayLike) else None
+        sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if sigma is not None else None
         self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel

         # compute kernel matrix for the reference data
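The two DriftMMD hunks replace runtime checks against the old interop ArrayLike protocol with plain None checks, since the numpy.typing alias cannot be isinstance-checked. A minimal sketch of the resulting behavior (it assumes sigma and configure_kernel_from_x_ref are constructor keywords, as the hunks suggest, and leaves the other arguments at their defaults):

    import numpy as np
    from dataeval._internal.detectors.drift.mmd import DriftMMD

    x_ref = np.random.randn(128, 16).astype(np.float32)

    # sigma=None: infer_sigma stays True and the kernel bandwidth is inferred from x_ref
    detector_inferred = DriftMMD(x_ref)

    # sigma given: the new "sigma is not None" check disables inference
    detector_fixed = DriftMMD(x_ref, sigma=np.array([1.0]))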
dataeval/_internal/detectors/drift/uncertainty.py
CHANGED
@@ -10,11 +10,10 @@ from functools import partial
 from typing import Callable, Dict, Literal, Optional, Union

 import numpy as np
+from numpy.typing import ArrayLike
 from scipy.special import softmax
 from scipy.stats import entropy

-from dataeval._internal.interop import ArrayLike
-
 from .base import UpdateStrategy
 from .ks import DriftKS
 from .torch import get_device, preprocess_drift
dataeval/_internal/detectors/linter.py
CHANGED
@@ -1,9 +1,9 @@
 from typing import Iterable, Literal, Optional, Sequence, Union

 import numpy as np
+from numpy.typing import ArrayLike

 from dataeval._internal.flags import ImageProperty, ImageVisuals, LinterFlags
-from dataeval._internal.interop import ArrayLike
 from dataeval._internal.metrics.stats import ImageStats


dataeval/_internal/detectors/ood/ae.py
CHANGED
@@ -10,9 +10,10 @@ from typing import Callable

 import keras
 import numpy as np
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import AE
 from dataeval._internal.models.tensorflow.utils import predict_batch

dataeval/_internal/detectors/ood/aegmm.py
CHANGED
@@ -9,9 +9,10 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable

 import keras
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import AEGMM
 from dataeval._internal.models.tensorflow.gmm import gmm_energy
 from dataeval._internal.models.tensorflow.losses import LossGMM
dataeval/_internal/detectors/ood/base.py
CHANGED
@@ -12,8 +12,9 @@ from typing import Callable, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
 import keras
 import numpy as np
 import tensorflow as tf
+from numpy.typing import ArrayLike

-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.gmm import GaussianMixtureModelParams, gmm_params
 from dataeval._internal.models.tensorflow.trainer import trainer

dataeval/_internal/detectors/ood/llr.py
CHANGED
@@ -14,9 +14,10 @@ import numpy as np
 import tensorflow as tf
 from keras.layers import Input
 from keras.models import Model
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
 from dataeval._internal.models.tensorflow.trainer import trainer
 from dataeval._internal.models.tensorflow.utils import predict_batch
@@ -180,7 +181,7 @@ class OOD_LLR(OODBase):

         # create background data
         mutate_fn = partial(mutate_fn, **mutate_fn_kwargs)
-        X_back = predict_batch(x_ref, mutate_fn, batch_size=mutate_batch_size, dtype=x_ref.dtype)
+        X_back = predict_batch(x_ref, mutate_fn, batch_size=mutate_batch_size, dtype=x_ref.dtype)  # type: ignore

         # prepare sequential data
         if self.sequential and not self.has_log_prob:
dataeval/_internal/detectors/ood/vae.py
CHANGED
@@ -10,9 +10,10 @@ from typing import Callable

 import keras
 import numpy as np
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import VAE
 from dataeval._internal.models.tensorflow.losses import Elbo
 from dataeval._internal.models.tensorflow.utils import predict_batch
dataeval/_internal/detectors/ood/vaegmm.py
CHANGED
@@ -10,9 +10,10 @@ from typing import Callable

 import keras
 import numpy as np
+from numpy.typing import ArrayLike

 from dataeval._internal.detectors.ood.base import OODGMMBase, OODScore
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 from dataeval._internal.models.tensorflow.autoencoder import VAEGMM
 from dataeval._internal.models.tensorflow.gmm import gmm_energy
 from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
dataeval/_internal/interop.py
CHANGED
@@ -1,7 +1,8 @@
 from importlib import import_module
-from typing import Any, Iterable, Optional, runtime_checkable
+from typing import Iterable, Optional

 import numpy as np
+from numpy.typing import ArrayLike

 module_cache = {}

@@ -19,16 +20,6 @@ def try_import(module_name):
     return module


-try:
-    from maite.protocols import ArrayLike  # type: ignore
-except ImportError:  # pragma: no cover - covered by test_mindeps.py
-    from typing import Protocol
-
-    @runtime_checkable
-    class ArrayLike(Protocol):
-        def __array__(self) -> Any: ...
-
-
 def to_numpy(array: Optional[ArrayLike]) -> np.ndarray:
     if array is None:
         return np.ndarray([])
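With the maite-backed Protocol removed, ArrayLike is now a purely static alias from numpy.typing (it cannot be used with isinstance), and to_numpy remains the single runtime conversion point. A minimal sketch (it assumes the unshown remainder of to_numpy converts framework tensors and generic array-likes to ndarray, which this diff does not confirm):

    import numpy as np
    from dataeval._internal.interop import to_numpy

    assert isinstance(to_numpy(np.ones((2, 2))), np.ndarray)  # ndarray in, ndarray out
    assert isinstance(to_numpy(None), np.ndarray)  # None yields an empty ndarray, per the body above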
dataeval/_internal/metrics/balance.py
ADDED
@@ -0,0 +1,180 @@
+import warnings
+from typing import Dict, List, NamedTuple, Sequence
+
+import numpy as np
+from numpy.typing import NDArray
+from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
+
+from dataeval._internal.metrics.utils import entropy, preprocess_metadata
+
+
+class BalanceOutput(NamedTuple):
+    """
+    Attributes
+    ----------
+    mutual_information : NDArray[np.float64]
+        Estimate of mutual information between metadata factors and class label
+    """
+
+    mutual_information: NDArray[np.float64]
+
+
+def validate_num_neighbors(num_neighbors: int) -> int:
+    if not isinstance(num_neighbors, (int, float)):
+        raise TypeError(
+            f"Variable {num_neighbors} is not real-valued numeric type."
+            "num_neighbors should be an int, greater than 0 and less than"
+            "the number of samples in the dataset"
+        )
+    if num_neighbors < 1:
+        raise ValueError(
+            f"Invalid value for {num_neighbors}."
+            "Choose a value greater than 0 and less than number of samples"
+            "in the dataset."
+        )
+    if isinstance(num_neighbors, float):
+        num_neighbors = int(num_neighbors)
+        warnings.warn(f"Variable {num_neighbors} is currently type float and will be truncated to type int.")
+
+    return num_neighbors
+
+
+def balance(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
+    """
+    Mutual information (MI) between factors (class label, metadata, label/image properties)
+
+    Parameters
+    ----------
+    class_labels: Sequence[int]
+        List of class labels for each image
+    metadata: List[Dict]
+        List of metadata factors for each image
+    num_neighbors: int, default 5
+        Number of nearest neighbors to use for computing MI between discrete
+        and continuous variables.
+
+    Returns
+    -------
+    BalanceOutput
+        (num_factors+1) x (num_factors+1) estimate of mutual information
+        between num_factors metadata factors and class label. Symmetry is enforced.
+
+    Notes
+    -----
+    We use `mutual_info_classif` from sklearn since class label is categorical.
+    `mutual_info_classif` outputs are consistent up to O(1e-4) and depend on a random
+    seed. MI is computed differently for categorical and continuous variables, and
+    we attempt to infer whether a variable is categorical by the fraction of unique
+    values in the dataset.
+
+    See Also
+    --------
+    sklearn.feature_selection.mutual_info_classif
+    sklearn.feature_selection.mutual_info_regression
+    sklearn.metrics.mutual_info_score
+    """
+    num_neighbors = validate_num_neighbors(num_neighbors)
+    data, names, is_categorical = preprocess_metadata(class_labels, metadata)
+    num_factors = len(names)
+    mi = np.empty((num_factors, num_factors))
+    mi[:] = np.nan
+
+    for idx in range(num_factors):
+        tgt = data[:, idx]
+
+        if is_categorical[idx]:
+            # categorical target
+            mi[idx, :] = mutual_info_classif(
+                data,
+                tgt,
+                discrete_features=is_categorical,  # type: ignore
+                n_neighbors=num_neighbors,
+            )
+        else:
+            # continuous variables
+            mi[idx, :] = mutual_info_regression(
+                data,
+                tgt,
+                discrete_features=is_categorical,  # type: ignore
+                n_neighbors=num_neighbors,
+            )
+
+    ent_all = entropy(data, names, is_categorical, normalized=False)
+    norm_factor = 0.5 * np.add.outer(ent_all, ent_all) + 1e-6
+    # in principle MI should be symmetric, but it is not in practice.
+    nmi = 0.5 * (mi + mi.T) / norm_factor
+
+    return BalanceOutput(nmi)
+
+
+def balance_classwise(class_labels: Sequence[int], metadata: List[Dict], num_neighbors: int = 5) -> BalanceOutput:
+    """
+    Compute mutual information (analogous to correlation) between metadata factors
+    (class label, metadata, label/image properties) with individual class labels.
+
+    Parameters
+    ----------
+    class_labels: Sequence[int]
+        List of class labels for each image
+    metadata: List[Dict]
+        List of metadata factors for each image
+    num_neighbors: int, default 5
+        Number of nearest neighbors to use for computing MI between discrete
+        and continuous variables.
+
+    Notes
+    -----
+    We use `mutual_info_classif` from sklearn since class label is categorical.
+    `mutual_info_classif` outputs are consistent up to O(1e-4) and depend on a random
+    seed. MI is computed differently for categorical and continuous variables, so we
+    have to specify with is_categorical.
+
+    Returns
+    -------
+    BalanceOutput
+        (num_classes x num_factors) estimate of mutual information between
+        num_factors metadata factors and individual class labels.
+
+    See Also
+    --------
+    sklearn.feature_selection.mutual_info_classif
+    sklearn.feature_selection.mutual_info_regression
+    sklearn.metrics.mutual_info_score
+    compute_mutual_information
+    """
+    num_neighbors = validate_num_neighbors(num_neighbors)
+    data, names, is_categorical = preprocess_metadata(class_labels, metadata)
+    num_factors = len(names)
+    # unique class labels
+    class_idx = names.index("class_label")
+    class_data = data[:, class_idx]
+    u_cls = np.unique(class_data)
+    num_classes = len(u_cls)
+
+    data_no_class = np.concatenate((data[:, :class_idx], data[:, (class_idx + 1) :]), axis=1)
+
+    # assume class is a factor
+    mi = np.empty((num_classes, num_factors - 1))
+    mi[:] = np.nan
+
+    # categorical variables, excluding class label
+    cat_mask = np.concatenate((is_categorical[:class_idx], is_categorical[(class_idx + 1) :]), axis=0).astype(int)
+
+    # classification MI for discrete/categorical features
+    for idx, cls in enumerate(u_cls):
+        tgt = class_data == cls
+        # units: nat
+        mi[idx, :] = mutual_info_classif(
+            data_no_class,
+            tgt,
+            discrete_features=cat_mask,  # type: ignore
+            n_neighbors=num_neighbors,
+        )
+
+    # let this recompute for all features including class label
+    ent_all = entropy(data, names, is_categorical)
+    ent_tgt = ent_all[class_idx]
+    ent_all = np.concatenate((ent_all[:class_idx], ent_all[(class_idx + 1) :]), axis=0)
+    norm_factor = 0.5 * np.add.outer(ent_tgt, ent_all) + 1e-6
+    nmi = mi / norm_factor
+    return BalanceOutput(nmi)
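A minimal usage sketch for the new balance metrics, imported from the internal path added above (the metrics/__init__.py changes in the file list suggest a public re-export as well; the factor names below are hypothetical, and string-valued factors are assumed to be treated as categorical by preprocess_metadata). Each returned entry is MI normalized by 0.5 * (H_i + H_j), the mean of the two factors' entropies:

    from dataeval._internal.metrics.balance import balance, balance_classwise

    class_labels = [0, 0, 1, 1, 0, 1]
    metadata = [  # one dict of factors per image; "lighting" and "altitude" are made-up names
        {"lighting": "day", "altitude": 100},
        {"lighting": "day", "altitude": 150},
        {"lighting": "night", "altitude": 120},
        {"lighting": "night", "altitude": 110},
        {"lighting": "day", "altitude": 130},
        {"lighting": "night", "altitude": 140},
    ]

    nmi = balance(class_labels, metadata).mutual_information  # (num_factors+1) x (num_factors+1), symmetrized
    per_class = balance_classwise(class_labels, metadata).mutual_information  # num_classes x num_factors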
dataeval/_internal/metrics/base.py
CHANGED
@@ -1,92 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, Generic, List, TypeVar
+from typing import Generic, TypeVar

 TOutput = TypeVar("TOutput", bound=dict)
-TMethods = TypeVar("TMethods")
-TCallable = TypeVar("TCallable", bound=Callable)
-
-
-class MetricMixin(ABC, Generic[TOutput]):
-    @abstractmethod
-    def update(self, *args, **kwargs): ...
-
-    @abstractmethod
-    def compute(self) -> TOutput: ...
-
-    @abstractmethod
-    def reset(self): ...


 class EvaluateMixin(ABC, Generic[TOutput]):
     @abstractmethod
     def evaluate(self, *args, **kwargs) -> TOutput:
         """Abstract method to calculate metric based off of constructor parameters"""
-
-
-class MethodsMixin(ABC, Generic[TMethods, TCallable]):
-    """
-    Use this mixin to define a mapping of functions to method names which
-    can be queried by the user and called internally with the appropriate
-    method name as the key.
-
-    Explicitly defining the Callable generic helps with type safety and
-    hinting for function signatures and recommended but optional.
-
-    e.g.:
-
-    def _mult(x: float, y: float) -> float:
-        return x * y
-
-    class MyMetric(MethodsMixin[Callable[float, float], float]):
-
-        def _methods(cls) -> Dict[str, Callable[float, float], float]:
-            return {
-                "ADD": lambda x, y: x + y,
-                "MULT": _mult,
-                ...
-            }
-
-    Then during evaluate, you can call the method specified with the getter.
-
-    e.g.:
-
-    def evaluate(self):
-        return self._method(x, y)
-
-    The resulting class can be used like so.
-
-    m = MyMetric(1.0, 2.0, "ADD")
-    m.evaluate() # returns 3.0
-    m.method # returns "ADD"
-    MyMetric.methods() # returns "['ADD', 'MULT']
-    m.method = "MULT"
-    m.evaluate() # returns 2.0
-    """
-
-    @classmethod
-    @abstractmethod
-    def _methods(cls) -> Dict[str, TCallable]:
-        """Abstract method returning available method functions for class"""
-
-    @property
-    def _method(self) -> TCallable:
-        return self._methods()[self.method]
-
-    @classmethod
-    def methods(cls) -> List[str]:
-        return list(cls._methods().keys())
-
-    @property
-    def method(self) -> str:
-        return self._method_key
-
-    @method.setter
-    def method(self, value: TMethods):
-        self._set_method(value)
-
-    def _set_method(self, value: TMethods):
-        """This setter is to fix pyright incorrect detection of
-        incorrectly overriding the 'method' property"""
-        if value not in self.methods():
-            raise KeyError(f"Specified method not available for class ({self.methods()}).")
-        self._method_key = value