dataeval 0.74.0__py3-none-any.whl → 0.74.2__py3-none-any.whl
This diff compares the contents of two package versions as publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- dataeval/__init__.py +23 -10
- dataeval/detectors/__init__.py +2 -10
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/mmd.py +1 -1
- dataeval/detectors/linters/clusterer.py +3 -3
- dataeval/detectors/linters/duplicates.py +4 -4
- dataeval/detectors/linters/outliers.py +4 -4
- dataeval/detectors/ood/__init__.py +5 -12
- dataeval/detectors/ood/base.py +5 -5
- dataeval/detectors/ood/metadata_ks_compare.py +12 -13
- dataeval/interop.py +15 -3
- dataeval/logging.py +16 -0
- dataeval/metrics/bias/balance.py +3 -3
- dataeval/metrics/bias/coverage.py +3 -3
- dataeval/metrics/bias/diversity.py +3 -3
- dataeval/metrics/bias/metadata_preprocessing.py +3 -3
- dataeval/metrics/bias/parity.py +4 -4
- dataeval/metrics/estimators/ber.py +3 -3
- dataeval/metrics/estimators/divergence.py +3 -3
- dataeval/metrics/estimators/uap.py +3 -3
- dataeval/metrics/stats/base.py +2 -2
- dataeval/metrics/stats/boxratiostats.py +1 -1
- dataeval/metrics/stats/datasetstats.py +6 -6
- dataeval/metrics/stats/dimensionstats.py +1 -1
- dataeval/metrics/stats/hashstats.py +1 -1
- dataeval/metrics/stats/labelstats.py +3 -3
- dataeval/metrics/stats/pixelstats.py +1 -1
- dataeval/metrics/stats/visualstats.py +1 -1
- dataeval/output.py +81 -57
- dataeval/utils/__init__.py +1 -7
- dataeval/utils/split_dataset.py +306 -279
- dataeval/workflows/sufficiency.py +4 -4
- {dataeval-0.74.0.dist-info → dataeval-0.74.2.dist-info}/METADATA +3 -8
- dataeval-0.74.2.dist-info/RECORD +66 -0
- dataeval/detectors/ood/ae.py +0 -76
- dataeval/detectors/ood/aegmm.py +0 -67
- dataeval/detectors/ood/base_tf.py +0 -109
- dataeval/detectors/ood/llr.py +0 -302
- dataeval/detectors/ood/vae.py +0 -98
- dataeval/detectors/ood/vaegmm.py +0 -76
- dataeval/utils/lazy.py +0 -26
- dataeval/utils/tensorflow/__init__.py +0 -19
- dataeval/utils/tensorflow/_internal/gmm.py +0 -103
- dataeval/utils/tensorflow/_internal/loss.py +0 -121
- dataeval/utils/tensorflow/_internal/models.py +0 -1394
- dataeval/utils/tensorflow/_internal/trainer.py +0 -114
- dataeval/utils/tensorflow/_internal/utils.py +0 -256
- dataeval/utils/tensorflow/loss/__init__.py +0 -11
- dataeval-0.74.0.dist-info/RECORD +0 -79
- {dataeval-0.74.0.dist-info → dataeval-0.74.2.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.74.0.dist-info → dataeval-0.74.2.dist-info}/WHEEL +0 -0
dataeval/detectors/ood/vae.py
DELETED
@@ -1,98 +0,0 @@
"""
Source code derived from Alibi-Detect 0.11.4
https://github.com/SeldonIO/alibi-detect/tree/v0.11.4

Original code Copyright (c) 2023 Seldon Technologies Ltd
Licensed under Apache Software License (Apache 2.0)
"""

from __future__ import annotations

__all__ = ["OOD_VAE"]

from typing import TYPE_CHECKING, Callable

import numpy as np
from numpy.typing import ArrayLike

from dataeval.detectors.ood.base import OODScoreOutput
from dataeval.detectors.ood.base_tf import OODBase
from dataeval.interop import to_numpy
from dataeval.utils.lazy import lazyload
from dataeval.utils.tensorflow._internal.loss import Elbo
from dataeval.utils.tensorflow._internal.utils import predict_batch

if TYPE_CHECKING:
    import tensorflow as tf
    import tf_keras as keras

    import dataeval.utils.tensorflow._internal.models as tf_models
else:
    tf = lazyload("tensorflow")
    keras = lazyload("tf_keras")
    tf_models = lazyload("dataeval.utils.tensorflow._internal.models")


class OOD_VAE(OODBase):
    """
    VAE based outlier detector.

    Parameters
    ----------
    model : VAE
        A VAE model.
    samples : int, default 10
        Number of samples drawn to evaluate each instance.

    Examples
    --------
    Instantiate an OOD detector metric with a generic dataset - batch of images with shape (3,25,25)

    >>> metric = OOD_VAE(create_model("VAE", dataset[0].shape))

    Adjust fit parameters, including setting the fit threshold at 85% for a training set
    with about 15% out-of-distribution

    >>> metric.fit(dataset, threshold_perc=85, batch_size=128, verbose=False)

    Detect :term:`out of distribution<Out-of-Distribution (OOD)>` samples at the 'feature' level

    >>> result = metric.predict(dataset, ood_type="feature")
    """

    def __init__(self, model: tf_models.VAE, samples: int = 10) -> None:
        super().__init__(model)
        self.samples = samples

    def fit(
        self,
        x_ref: ArrayLike,
        threshold_perc: float = 100.0,
        loss_fn: Callable[..., tf.Tensor] | None = Elbo(0.05),
        optimizer: keras.optimizers.Optimizer | None = None,
        epochs: int = 20,
        batch_size: int = 64,
        verbose: bool = True,
    ) -> None:
        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)

    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
        self._validate(X := to_numpy(X))

        # sample reconstructed instances
        X_samples = np.repeat(X, self.samples, axis=0)
        X_recon = predict_batch(X_samples, model=self.model, batch_size=batch_size)

        # compute feature scores
        fscore = np.power(X_samples - X_recon, 2)
        fscore = fscore.reshape((-1, self.samples) + X_samples.shape[1:])
        fscore = np.mean(fscore, axis=1)

        # compute instance scores
        fscore_flat = fscore.reshape(fscore.shape[0], -1).copy()
        n_score_features = int(np.ceil(fscore_flat.shape[1]))
        sorted_fscore = np.sort(fscore_flat, axis=1)
        sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
        iscore = np.mean(sorted_fscore_perc, axis=1)

        return OODScoreOutput(iscore, fscore)
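A note on the scoring logic above: `n_score_features` is the ceiling of the full flattened feature count, so the sorted "top-k" slice keeps every feature and the instance score reduces to a plain mean over the per-feature reconstruction errors. A minimal NumPy sketch of that aggregation, using hypothetical toy shapes (4 instances, 10 samples each, 3x25x25 images):

import numpy as np

n, samples = 4, 10
X_samples = np.random.rand(n * samples, 3, 25, 25)  # each instance repeated `samples` times
X_recon = np.random.rand(n * samples, 3, 25, 25)    # stand-in for the VAE reconstructions

fscore = np.power(X_samples - X_recon, 2)
fscore = fscore.reshape((-1, samples) + X_samples.shape[1:])  # (n, samples, 3, 25, 25)
fscore = fscore.mean(axis=1)                                  # (n, 3, 25, 25) feature scores

# top-k with k = all features, i.e. the mean feature score per instance
iscore = fscore.reshape(n, -1).mean(axis=1)  # (n,) instance scores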
dataeval/detectors/ood/vaegmm.py
DELETED
@@ -1,76 +0,0 @@
"""
Source code derived from Alibi-Detect 0.11.4
https://github.com/SeldonIO/alibi-detect/tree/v0.11.4

Original code Copyright (c) 2023 Seldon Technologies Ltd
Licensed under Apache Software License (Apache 2.0)
"""

from __future__ import annotations

__all__ = ["OOD_VAEGMM"]

from typing import TYPE_CHECKING, Callable

import numpy as np
from numpy.typing import ArrayLike

from dataeval.detectors.ood.base import OODScoreOutput
from dataeval.detectors.ood.base_tf import OODBaseGMM
from dataeval.interop import to_numpy
from dataeval.utils.lazy import lazyload
from dataeval.utils.tensorflow._internal.gmm import gmm_energy
from dataeval.utils.tensorflow._internal.loss import Elbo, LossGMM
from dataeval.utils.tensorflow._internal.utils import predict_batch

if TYPE_CHECKING:
    import tensorflow as tf
    import tf_keras as keras

    import dataeval.utils.tensorflow._internal.models as tf_models
else:
    tf = lazyload("tensorflow")
    keras = lazyload("tf_keras")
    tf_models = lazyload("dataeval.utils.tensorflow._internal.models")


class OOD_VAEGMM(OODBaseGMM):
    """
    VAE with Gaussian Mixture Model based outlier detector.

    Parameters
    ----------
    model : VAEGMM
        A VAEGMM model.
    samples : int, default 10
        Number of samples drawn to evaluate each instance.
    """

    def __init__(self, model: tf_models.VAEGMM, samples: int = 10) -> None:
        super().__init__(model)
        self.samples = samples

    def fit(
        self,
        x_ref: ArrayLike,
        threshold_perc: float = 100.0,
        loss_fn: Callable[..., tf.Tensor] | None = LossGMM(elbo=Elbo(0.05)),
        optimizer: keras.optimizers.Optimizer | None = None,
        epochs: int = 20,
        batch_size: int = 64,
        verbose: bool = True,
    ) -> None:
        super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)

    def _score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
        self._validate(X := to_numpy(X))

        # draw samples from latent space
        X_samples = np.repeat(X, self.samples, axis=0)
        _, z, _ = predict_batch(X_samples, self.model, batch_size=batch_size)

        # compute average energy for samples
        energy, _ = gmm_energy(z, self._gmm_params, return_mean=False)
        energy_samples = energy.numpy().reshape((-1, self.samples))  # type: ignore
        iscore = np.mean(energy_samples, axis=-1)
        return OODScoreOutput(iscore)
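Usage mirrors the OOD_VAE example above; a hedged sketch for 0.74.0, assuming the same hypothetical `dataset` and the `create_model` helper from dataeval.utils.tensorflow:

>>> metric = OOD_VAEGMM(create_model("VAEGMM", dataset[0].shape))
>>> metric.fit(dataset, threshold_perc=85, batch_size=128, verbose=False)
>>> result = metric.predict(dataset)

Since `_score` returns OODScoreOutput(iscore) with no feature scores, this detector only supports instance-level predictions.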
dataeval/utils/lazy.py
DELETED
@@ -1,26 +0,0 @@
from __future__ import annotations

from functools import cached_property
from importlib import import_module
from typing import Any


class LazyModule:
    def __init__(self, name: str) -> None:
        self._name = name

    def __getattr__(self, key: str) -> Any:
        return getattr(self._module, key)

    @cached_property
    def _module(self):
        return import_module(self._name)


LAZY_MODULES: dict[str, LazyModule] = {}


def lazyload(name: str) -> LazyModule:
    if name not in LAZY_MODULES:
        LAZY_MODULES[name] = LazyModule(name)
    return LAZY_MODULES[name]
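The module-level cache means repeated `lazyload` calls for the same name share one LazyModule, and the real import is deferred until the first attribute access (via the `_module` cached property). A minimal sketch using the standard-library `json` module to keep it dependency-free:

>>> mod = lazyload("json")   # no import happens yet
>>> mod is lazyload("json")  # same cached LazyModule instance
True
>>> mod.dumps({"a": 1})      # first attribute access triggers import_module
'{"a": 1}'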
dataeval/utils/tensorflow/__init__.py
DELETED
@@ -1,19 +0,0 @@
"""
TensorFlow models are used in :term:`out of distribution<Out-of-distribution (OOD)>` detectors in the
:mod:`dataeval.detectors.ood` module.

DataEval provides basic default models through the utility :func:`dataeval.utils.tensorflow.create_model`.
"""

from dataeval import _IS_TENSORFLOW_AVAILABLE

__all__ = []


if _IS_TENSORFLOW_AVAILABLE:
    import dataeval.utils.tensorflow.loss as loss
    from dataeval.utils.tensorflow._internal.utils import create_model

    __all__ = ["create_model", "loss"]

del _IS_TENSORFLOW_AVAILABLE
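Because `__all__` is populated only when TensorFlow is available, 0.74.0 callers typically had to guard the optional import themselves; a minimal sketch (not part of the package):

try:
    from dataeval.utils.tensorflow import create_model, loss
except ImportError:
    create_model = loss = None  # TensorFlow extra not installed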
dataeval/utils/tensorflow/_internal/gmm.py
DELETED
@@ -1,103 +0,0 @@
"""
Source code derived from Alibi-Detect 0.11.4
https://github.com/SeldonIO/alibi-detect/tree/v0.11.4

Original code Copyright (c) 2023 Seldon Technologies Ltd
Licensed under Apache Software License (Apache 2.0)
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from dataeval.utils.gmm import GaussianMixtureModelParams
from dataeval.utils.lazy import lazyload

if TYPE_CHECKING:
    import tensorflow as tf
else:
    tf = lazyload("tensorflow")


def gmm_params(z: tf.Tensor, gamma: tf.Tensor) -> GaussianMixtureModelParams[tf.Tensor]:
    """
    Compute parameters of Gaussian Mixture Model.

    Parameters
    ----------
    z : tf.Tensor
        Observations.
    gamma : tf.Tensor
        Mixture probabilities to derive mixture distribution weights from.

    Returns
    -------
    GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
        The parameters used to calculate energy.
    """
    # compute gmm parameters phi, mu and cov
    N = gamma.shape[0]  # nb of samples in batch
    sum_gamma = tf.reduce_sum(gamma, 0)  # K
    phi = sum_gamma / N  # K
    mu = tf.reduce_sum(tf.expand_dims(gamma, -1) * tf.expand_dims(z, 1), 0) / tf.expand_dims(
        sum_gamma, -1
    )  # K x D (D = latent_dim)
    z_mu = tf.expand_dims(z, 1) - tf.expand_dims(mu, 0)  # N x K x D
    z_mu_outer = tf.expand_dims(z_mu, -1) * tf.expand_dims(z_mu, -2)  # N x K x D x D
    cov = tf.reduce_sum(tf.expand_dims(tf.expand_dims(gamma, -1), -1) * z_mu_outer, 0) / tf.expand_dims(
        tf.expand_dims(sum_gamma, -1), -1
    )  # K x D x D

    # cholesky decomposition of covariance and determinant derivation
    D = tf.shape(cov)[1]  # type: ignore
    eps = 1e-6
    L = tf.linalg.cholesky(cov + tf.eye(D) * eps)  # K x D x D
    log_det_cov = 2.0 * tf.reduce_sum(tf.math.log(tf.linalg.diag_part(L)), 1)  # K

    return GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)


def gmm_energy(
    z: tf.Tensor,
    params: GaussianMixtureModelParams[tf.Tensor],
    return_mean: bool = True,
) -> tuple[tf.Tensor, tf.Tensor]:
    """
    Compute sample energy from Gaussian Mixture Model.

    Parameters
    ----------
    params : GaussianMixtureModelParams
        The gaussian mixture model parameters.
    return_mean : bool, default True
        Take mean across all sample energies in a batch.

    Returns
    -------
    sample_energy
        The sample energy of the GMM.
    cov_diag
        The inverse sum of the diagonal components of the covariance matrix.
    """
    D = tf.shape(params.cov)[1]  # type: ignore
    z_mu = tf.expand_dims(z, 1) - tf.expand_dims(params.mu, 0)  # N x K x D
    z_mu_T = tf.transpose(z_mu, perm=[1, 2, 0])  # K x D x N
    v = tf.linalg.triangular_solve(params.L, z_mu_T, lower=True)  # K x D x N

    # rewrite sample energy in logsumexp format for numerical stability
    logits = tf.math.log(tf.expand_dims(params.phi, -1)) - 0.5 * (
        tf.reduce_sum(tf.square(v), 1)
        + tf.cast(D, tf.float32) * tf.math.log(2.0 * np.pi)  # type: ignore py38
        + tf.expand_dims(params.log_det_cov, -1)
    )  # K x N
    sample_energy = -tf.reduce_logsumexp(logits, axis=0)  # N

    if return_mean:
        sample_energy = tf.reduce_mean(sample_energy)

    # inverse sum of variances
    cov_diag = tf.reduce_sum(tf.divide(1, tf.linalg.diag_part(params.cov)))

    return sample_energy, cov_diag
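For reference, the logsumexp above is the numerically stable form of the GMM negative log-likelihood, which is what "sample energy" means here:

$$E(z) = -\log \sum_{k=1}^{K} \phi_k \, \mathcal{N}(z;\, \mu_k, \Sigma_k) = -\operatorname{logsumexp}_k \left[ \log \phi_k - \tfrac{1}{2} \left( \lVert v_k \rVert^2 + D \log 2\pi + \log\det \Sigma_k \right) \right]$$

where $v_k = L_k^{-1}(z - \mu_k)$ is the triangular solve against the Cholesky factor $L_k$, so $\lVert v_k \rVert^2 = (z - \mu_k)^\top \Sigma_k^{-1} (z - \mu_k)$, and $\log\det\Sigma_k = 2 \sum_i \log\, (L_k)_{ii}$ as computed in gmm_params.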
dataeval/utils/tensorflow/_internal/loss.py
DELETED
@@ -1,121 +0,0 @@
"""
Source code derived from Alibi-Detect 0.11.4
https://github.com/SeldonIO/alibi-detect/tree/v0.11.4

Original code Copyright (c) 2023 Seldon Technologies Ltd
Licensed under Apache Software License (Apache 2.0)
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal, cast

import numpy as np
from numpy.typing import NDArray

from dataeval.utils.lazy import lazyload
from dataeval.utils.tensorflow._internal.gmm import gmm_energy, gmm_params

if TYPE_CHECKING:
    import tensorflow as tf
    import tensorflow_probability.python.distributions.mvn_diag as mvn_diag
    import tensorflow_probability.python.distributions.mvn_tril as mvn_tril
    import tensorflow_probability.python.stats as tfp_stats
    import tf_keras as keras
else:
    tf = lazyload("tensorflow")
    keras = lazyload("tf_keras")
    mvn_diag = lazyload("tensorflow_probability.python.distributions.mvn_diag")
    mvn_tril = lazyload("tensorflow_probability.python.distributions.mvn_tril")
    tfp_stats = lazyload("tensorflow_probability.python.stats")


class Elbo:
    """
    Compute ELBO loss.

    The covariance matrix can be specified by passing the full covariance matrix, the matrix
    diagonal, or a scale identity multiplier. Only one of these should be specified. If none are
    specified, the identity matrix is used.

    Parameters
    ----------
    cov_type : Union[Literal["cov_full", "cov_diag"], float], default 1.0
        Full covariance matrix, diagonal :term:`variance<Variance>` matrix, or scale identity multiplier.
    x : ArrayLike, optional - default None
        Dataset used to calculate the covariance matrix. Required for full and diagonal covariance matrix types.
    """

    def __init__(
        self,
        cov_type: Literal["cov_full", "cov_diag"] | float = 1.0,
        x: tf.Tensor | NDArray[np.float32] | None = None,
    ):
        if isinstance(cov_type, float):
            self._cov = ("sim", cov_type)
        elif cov_type in ["cov_full", "cov_diag"]:
            x_np: NDArray[np.float32] = x.numpy().astype(np.float32) if tf.is_tensor(x) else x  # type: ignore
            cov = tfp_stats.covariance(x_np.reshape(x_np.shape[0], -1))  # type: ignore py38
            if cov_type == "cov_diag":  # infer standard deviation from covariance matrix
                cov = tf.math.sqrt(tf.linalg.diag_part(cov))
            self._cov = (cov_type, cov)
        else:
            raise ValueError("Only cov_full, cov_diag or sim value should be specified.")

    def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        y_pred_flat = cast(tf.Tensor, keras.layers.Flatten()(y_pred))

        if self._cov[0] == "cov_full":
            y_mn = mvn_tril.MultivariateNormalTriL(y_pred_flat, scale_tril=tf.linalg.cholesky(self._cov[1]))
        else:  # cov_diag and sim
            cov_diag = self._cov[1] if self._cov[0] == "cov_diag" else self._cov[1] * tf.ones(y_pred_flat.shape[-1])
            y_mn = mvn_diag.MultivariateNormalDiag(y_pred_flat, scale_diag=cov_diag)

        loss = -tf.reduce_mean(y_mn.log_prob(keras.layers.Flatten()(y_true)))
        return loss


class LossGMM:
    """
    Loss function used for AE and VAE with GMM.

    Parameters
    ----------
    w_recon : float, default 1e-7
        Weight on elbo loss term.
    w_energy : float, default 0.1
        Weight on sample energy loss term.
    w_cov_diag : float, default 0.005
        Weight on covariance regularizing loss term.
    elbo : Elbo, optional - default None
        ELBO loss function used to calculate w_recon.
    """

    def __init__(
        self,
        w_recon: float = 1e-7,
        w_energy: float = 0.1,
        w_cov_diag: float = 0.005,
        elbo: Elbo | None = None,
    ):
        self.w_recon = w_recon
        self.w_energy = w_energy
        self.w_cov_diag = w_cov_diag
        self.elbo = elbo

    def __call__(
        self,
        x_true: tf.Tensor,
        x_pred: tf.Tensor,
        z: tf.Tensor,
        gamma: tf.Tensor,
    ) -> tf.Tensor:
        w_recon = (
            tf.reduce_mean(tf.subtract(x_true, x_pred) ** 2)
            if self.elbo is None
            else tf.multiply(self.w_recon, self.elbo(x_true, x_pred))
        )
        sample_energy, cov_diag = gmm_energy(z, gmm_params(z, gamma))
        w_energy = tf.multiply(self.w_energy, sample_energy)
        w_cov_diag = tf.multiply(self.w_cov_diag, cov_diag)
        return w_recon + w_energy + w_cov_diag
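Putting the terms together, LossGMM.__call__ evaluates

$$\mathcal{L} = \mathcal{L}_{\mathrm{recon}} + w_{\mathrm{energy}} \, \bar{E}(z) + w_{\mathrm{cov\_diag}} \sum_{k,i} \frac{1}{(\Sigma_k)_{ii}}$$

where $\bar{E}(z)$ is the batch-mean sample energy from gmm_energy, and $\mathcal{L}_{\mathrm{recon}}$ is $w_{\mathrm{recon}} \cdot \mathrm{Elbo}(x_{\mathrm{true}}, x_{\mathrm{pred}})$ when an Elbo instance is supplied, or an unweighted mean squared error otherwise (note that `w_recon` is not applied in the MSE branch).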