dataeval 0.61.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. dataeval/__init__.py +18 -0
  2. dataeval/_internal/detectors/__init__.py +0 -0
  3. dataeval/_internal/detectors/clusterer.py +469 -0
  4. dataeval/_internal/detectors/drift/__init__.py +0 -0
  5. dataeval/_internal/detectors/drift/base.py +265 -0
  6. dataeval/_internal/detectors/drift/cvm.py +97 -0
  7. dataeval/_internal/detectors/drift/ks.py +100 -0
  8. dataeval/_internal/detectors/drift/mmd.py +166 -0
  9. dataeval/_internal/detectors/drift/torch.py +310 -0
  10. dataeval/_internal/detectors/drift/uncertainty.py +149 -0
  11. dataeval/_internal/detectors/duplicates.py +49 -0
  12. dataeval/_internal/detectors/linter.py +78 -0
  13. dataeval/_internal/detectors/ood/__init__.py +0 -0
  14. dataeval/_internal/detectors/ood/ae.py +77 -0
  15. dataeval/_internal/detectors/ood/aegmm.py +69 -0
  16. dataeval/_internal/detectors/ood/base.py +199 -0
  17. dataeval/_internal/detectors/ood/llr.py +284 -0
  18. dataeval/_internal/detectors/ood/vae.py +86 -0
  19. dataeval/_internal/detectors/ood/vaegmm.py +79 -0
  20. dataeval/_internal/flags.py +47 -0
  21. dataeval/_internal/metrics/__init__.py +0 -0
  22. dataeval/_internal/metrics/base.py +92 -0
  23. dataeval/_internal/metrics/ber.py +124 -0
  24. dataeval/_internal/metrics/coverage.py +80 -0
  25. dataeval/_internal/metrics/divergence.py +94 -0
  26. dataeval/_internal/metrics/hash.py +79 -0
  27. dataeval/_internal/metrics/parity.py +180 -0
  28. dataeval/_internal/metrics/stats.py +332 -0
  29. dataeval/_internal/metrics/uap.py +45 -0
  30. dataeval/_internal/metrics/utils.py +158 -0
  31. dataeval/_internal/models/__init__.py +0 -0
  32. dataeval/_internal/models/pytorch/__init__.py +0 -0
  33. dataeval/_internal/models/pytorch/autoencoder.py +202 -0
  34. dataeval/_internal/models/pytorch/blocks.py +46 -0
  35. dataeval/_internal/models/pytorch/utils.py +67 -0
  36. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  37. dataeval/_internal/models/tensorflow/autoencoder.py +317 -0
  38. dataeval/_internal/models/tensorflow/gmm.py +115 -0
  39. dataeval/_internal/models/tensorflow/losses.py +107 -0
  40. dataeval/_internal/models/tensorflow/pixelcnn.py +1106 -0
  41. dataeval/_internal/models/tensorflow/trainer.py +102 -0
  42. dataeval/_internal/models/tensorflow/utils.py +254 -0
  43. dataeval/_internal/workflows/sufficiency.py +555 -0
  44. dataeval/detectors/__init__.py +29 -0
  45. dataeval/flags/__init__.py +3 -0
  46. dataeval/metrics/__init__.py +7 -0
  47. dataeval/models/__init__.py +15 -0
  48. dataeval/models/tensorflow/__init__.py +6 -0
  49. dataeval/models/torch/__init__.py +8 -0
  50. dataeval/py.typed +0 -0
  51. dataeval/workflows/__init__.py +8 -0
  52. dataeval-0.61.0.dist-info/LICENSE.txt +21 -0
  53. dataeval-0.61.0.dist-info/METADATA +114 -0
  54. dataeval-0.61.0.dist-info/RECORD +55 -0
  55. dataeval-0.61.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,317 @@
+ """
+ Source code derived from Alibi-Detect 0.11.4
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ # pyright: reportIncompatibleMethodOverride=false
+
+ from typing import Callable, Tuple, cast
+
+ import keras
+ import tensorflow as tf
+ from keras.layers import (
+     Dense,
+     Flatten,
+     Layer,
+ )
+
+
+ def relative_euclidean_distance(x: tf.Tensor, y: tf.Tensor, eps: float = 1e-12, axis: int = -1) -> tf.Tensor:
+     """
+     Relative Euclidean distance.
+
+     Parameters
+     ----------
+     x
+         Tensor used in distance computation.
+     y
+         Tensor used in distance computation.
+     eps
+         Epsilon added to denominator for numerical stability.
+     axis
+         Axis used to compute distance.
+
+     Returns
+     -------
+     Tensor with relative Euclidean distance across specified axis.
+     """
+     denom = tf.concat(
+         [
+             tf.reshape(tf.norm(x, ord=2, axis=axis), (-1, 1)),  # type: ignore
+             tf.reshape(tf.norm(y, ord=2, axis=axis), (-1, 1)),  # type: ignore
+         ],
+         axis=1,
+     )
+     dist = tf.norm(tf.math.subtract(x, y), ord=2, axis=axis) / (tf.reduce_min(denom, axis=axis) + eps)  # type: ignore
+     return dist
+
+
+ def eucl_cosim_features(x: tf.Tensor, y: tf.Tensor, max_eucl: float = 1e2) -> tf.Tensor:
+     """
+     Compute features extracted from the reconstructed instance using the
+     relative Euclidean distance and cosine similarity between 2 tensors.
+
+     Parameters
+     ----------
+     x
+         Tensor used in feature computation.
+     y
+         Tensor used in feature computation.
+     max_eucl
+         Maximum value to clip relative Euclidean distance by.
+
+     Returns
+     -------
+     Tensor concatenating the relative Euclidean distance and cosine similarity features.
+     """
+     if len(x.shape) > 2 or len(y.shape) > 2:
+         x = cast(tf.Tensor, Flatten()(x))
+         y = cast(tf.Tensor, Flatten()(y))
+     rec_cos = tf.reshape(keras.losses.cosine_similarity(y, x, -1), (-1, 1))
+     rec_euc = tf.reshape(relative_euclidean_distance(y, x, axis=-1), (-1, 1))
+     # rec_euc could become very large so should be clipped
+     rec_euc = tf.clip_by_value(rec_euc, 0, max_eucl)
+     return cast(tf.Tensor, tf.concat([rec_cos, rec_euc], -1))
+
+
+ class Sampling(Layer):
+     """Reparametrization trick. Uses (z_mean, z_log_var) to sample the latent vector z."""
+
+     def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
+         """
+         Sample z.
+
+         Parameters
+         ----------
+         inputs
+             Tuple with mean and log variance.
+
+         Returns
+         -------
+         Sampled vector z.
+         """
+         z_mean, z_log_var = inputs
+         batch, dim = tuple(tf.shape(z_mean).numpy().ravel()[:2])  # type: ignore
+         epsilon = cast(tf.Tensor, keras.backend.random_normal(shape=(batch, dim)))
+         return z_mean + tf.exp(tf.math.multiply(0.5, z_log_var)) * epsilon
+
+
+ class EncoderAE(Layer):
+     def __init__(self, encoder_net: keras.Model) -> None:
+         """
+         Encoder of AE.
+
+         Parameters
+         ----------
+         encoder_net
+             Layers for the encoder wrapped in a keras.Sequential class.
+         """
+         super().__init__(name="encoder_ae")
+         self.encoder_net = encoder_net
+
+     def call(self, x: tf.Tensor) -> tf.Tensor:
+         return cast(tf.Tensor, self.encoder_net(x))
+
+
+ class EncoderVAE(Layer):
+     def __init__(self, encoder_net: keras.Model, latent_dim: int) -> None:
+         """
+         Encoder of VAE.
+
+         Parameters
+         ----------
+         encoder_net
+             Layers for the encoder wrapped in a keras.Sequential class.
+         latent_dim
+             Dimensionality of the latent space.
+         """
+         super().__init__(name="encoder_vae")
+         self.encoder_net = encoder_net
+         self.fc_mean = Dense(latent_dim, activation=None)
+         self.fc_log_var = Dense(latent_dim, activation=None)
+         self.sampling = Sampling()
+
+     def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+         x = cast(tf.Tensor, self.encoder_net(x))
+         if len(x.shape) > 2:
+             x = cast(tf.Tensor, Flatten()(x))
+         z_mean = cast(tf.Tensor, self.fc_mean(x))
+         z_log_var = cast(tf.Tensor, self.fc_log_var(x))
+         z = cast(tf.Tensor, self.sampling((z_mean, z_log_var)))
+         return z_mean, z_log_var, z
+
+
+ class Decoder(Layer):
+     def __init__(self, decoder_net: keras.Model) -> None:
+         """
+         Decoder of AE and VAE.
+
+         Parameters
+         ----------
+         decoder_net
+             Layers for the decoder wrapped in a keras.Sequential class.
+         """
+         super().__init__(name="decoder")
+         self.decoder_net = decoder_net
+
+     def call(self, x: tf.Tensor) -> tf.Tensor:
+         return cast(tf.Tensor, self.decoder_net(x))
+
+
+ class AE(keras.Model):
+     """
+     Combine encoder and decoder in AE.
+
+     Parameters
+     ----------
+     encoder_net
+         Layers for the encoder wrapped in a keras.Sequential class.
+     decoder_net
+         Layers for the decoder wrapped in a keras.Sequential class.
+     """
+
+     def __init__(self, encoder_net: keras.Model, decoder_net: keras.Model) -> None:
+         super().__init__(name="ae")
+         self.encoder = EncoderAE(encoder_net)
+         self.decoder = Decoder(decoder_net)
+
+     def call(self, x: tf.Tensor) -> tf.Tensor:
+         z = cast(tf.Tensor, self.encoder(x))
+         x_recon = cast(tf.Tensor, self.decoder(z))
+         return x_recon
+
+
+ class VAE(keras.Model):
+     """
+     Combine encoder and decoder in VAE.
+
+     Parameters
+     ----------
+     encoder_net
+         Layers for the encoder wrapped in a keras.Sequential class.
+     decoder_net
+         Layers for the decoder wrapped in a keras.Sequential class.
+     latent_dim
+         Dimensionality of the latent space.
+     beta
+         Beta parameter for KL-divergence loss term.
+     """
+
+     def __init__(self, encoder_net: keras.Model, decoder_net: keras.Model, latent_dim: int, beta: float = 1.0) -> None:
+         super().__init__(name="vae_model")
+         self.encoder = EncoderVAE(encoder_net, latent_dim)
+         self.decoder = Decoder(decoder_net)
+         self.beta = beta
+         self.latent_dim = latent_dim
+
+     def call(self, x: tf.Tensor) -> tf.Tensor:
+         z_mean, z_log_var, z = cast(Tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(x))
+         x_recon = self.decoder(z)
+         # add KL divergence loss term
+         kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
+         self.add_loss(self.beta * kl_loss)
+         return cast(tf.Tensor, x_recon)
+
+
+ class AEGMM(keras.Model):
+     """
+     Deep Autoencoding Gaussian Mixture Model.
+
+     Parameters
+     ----------
+     encoder_net
+         Layers for the encoder wrapped in a keras.Sequential class.
+     decoder_net
+         Layers for the decoder wrapped in a keras.Sequential class.
+     gmm_density_net
+         Layers for the GMM network wrapped in a keras.Sequential class.
+     n_gmm
+         Number of components in GMM.
+     recon_features
+         Function to extract features from the reconstructed instance by the decoder.
+     """
+
+     def __init__(
+         self,
+         encoder_net: keras.Model,
+         decoder_net: keras.Model,
+         gmm_density_net: keras.Model,
+         n_gmm: int,
+         recon_features: Callable = eucl_cosim_features,
+     ) -> None:
+         super().__init__(name="aegmm")
+         self.encoder = encoder_net
+         self.decoder = decoder_net
+         self.gmm_density = gmm_density_net
+         self.n_gmm = n_gmm
+         self.recon_features = recon_features
+
+     def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+         enc = self.encoder(x)
+         x_recon = cast(tf.Tensor, self.decoder(enc))
+         recon_features = self.recon_features(x, x_recon)
+         z = cast(tf.Tensor, tf.concat([enc, recon_features], -1))
+         gamma = cast(tf.Tensor, self.gmm_density(z))
+         return x_recon, z, gamma
+
+
+ class VAEGMM(keras.Model):
+     """
+     Variational Autoencoding Gaussian Mixture Model.
+
+     Parameters
+     ----------
+     encoder_net
+         Layers for the encoder wrapped in a keras.Sequential class.
+     decoder_net
+         Layers for the decoder wrapped in a keras.Sequential class.
+     gmm_density_net
+         Layers for the GMM network wrapped in a keras.Sequential class.
+     n_gmm
+         Number of components in GMM.
+     latent_dim
+         Dimensionality of the latent space.
+     recon_features
+         Function to extract features from the reconstructed instance by the decoder.
+     beta
+         Beta parameter for KL-divergence loss term.
+     """
+
+     def __init__(
+         self,
+         encoder_net: keras.Model,
+         decoder_net: keras.Model,
+         gmm_density_net: keras.Model,
+         n_gmm: int,
+         latent_dim: int,
+         recon_features: Callable = eucl_cosim_features,
+         beta: float = 1.0,
+     ) -> None:
+         super().__init__(name="vaegmm")
+         self.encoder = EncoderVAE(encoder_net, latent_dim)
+         self.decoder = decoder_net
+         self.gmm_density = gmm_density_net
+         self.n_gmm = n_gmm
+         self.latent_dim = latent_dim
+         self.recon_features = recon_features
+         self.beta = beta
+
+     def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+         enc_mean, enc_log_var, enc = cast(Tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(x))
+         x_recon = cast(tf.Tensor, self.decoder(enc))
+         recon_features = self.recon_features(x, x_recon)
+         z = cast(tf.Tensor, tf.concat([enc, recon_features], -1))
+         gamma = cast(tf.Tensor, self.gmm_density(z))
+         # add KL divergence loss term
+         kl_loss = -0.5 * tf.reduce_mean(enc_log_var - tf.square(enc_mean) - tf.exp(enc_log_var) + 1)
+         self.add_loss(self.beta * kl_loss)
+         return x_recon, z, gamma
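
For orientation before the next file, here is a minimal usage sketch of the VAE class above. It is not part of the package; the toy encoder/decoder architectures, the (32, 32, 3) input shape, and the latent size of 8 are illustrative assumptions, and it relies on TF2 eager execution (Sampling calls .numpy()).

import keras
import tensorflow as tf

input_shape = (32, 32, 3)  # assumed for this sketch only
latent_dim = 8

# toy encoder: maps images to a flat feature vector
encoder_net = keras.Sequential(
    [
        keras.layers.InputLayer(input_shape=input_shape),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation="relu"),
    ]
)
# toy decoder: maps a latent vector back to image shape
decoder_net = keras.Sequential(
    [
        keras.layers.InputLayer(input_shape=(latent_dim,)),
        keras.layers.Dense(32 * 32 * 3),
        keras.layers.Reshape(input_shape),
    ]
)

vae = VAE(encoder_net, decoder_net, latent_dim=latent_dim, beta=1.0)
x = tf.random.normal((16,) + input_shape)  # dummy batch
x_recon = vae(x)  # forward pass; the KL term is registered via add_loss
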
@@ -0,0 +1,115 @@
+ """
+ Source code derived from Alibi-Detect 0.11.4
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ from typing import NamedTuple, Tuple
+
+ import numpy as np
+ import tensorflow as tf
+
+
+ class GaussianMixtureModelParams(NamedTuple):
+     """
+     phi : tf.Tensor
+         Mixture component distribution weights.
+     mu : tf.Tensor
+         Mixture means.
+     cov : tf.Tensor
+         Mixture covariance.
+     L : tf.Tensor
+         Cholesky decomposition of `cov`.
+     log_det_cov : tf.Tensor
+         Log of the determinant of `cov`.
+     """
+
+     phi: tf.Tensor
+     mu: tf.Tensor
+     cov: tf.Tensor
+     L: tf.Tensor
+     log_det_cov: tf.Tensor
+
+
+ def gmm_params(z: tf.Tensor, gamma: tf.Tensor) -> GaussianMixtureModelParams:
+     """
+     Compute parameters of Gaussian Mixture Model.
+
+     Parameters
+     ----------
+     z : tf.Tensor
+         Observations.
+     gamma : tf.Tensor
+         Mixture probabilities to derive mixture distribution weights from.
+
+     Returns
+     -------
+     GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
+         The parameters used to calculate energy.
+     """
+     # compute gmm parameters phi, mu and cov
+     N = gamma.shape[0]  # number of samples in batch
+     sum_gamma = tf.reduce_sum(gamma, 0)  # K
+     phi = sum_gamma / N  # K
+     mu = tf.reduce_sum(tf.expand_dims(gamma, -1) * tf.expand_dims(z, 1), 0) / tf.expand_dims(
+         sum_gamma, -1
+     )  # K x D (D = latent_dim)
+     z_mu = tf.expand_dims(z, 1) - tf.expand_dims(mu, 0)  # N x K x D
+     z_mu_outer = tf.expand_dims(z_mu, -1) * tf.expand_dims(z_mu, -2)  # N x K x D x D
+     cov = tf.reduce_sum(tf.expand_dims(tf.expand_dims(gamma, -1), -1) * z_mu_outer, 0) / tf.expand_dims(
+         tf.expand_dims(sum_gamma, -1), -1
+     )  # K x D x D
+
+     # cholesky decomposition of covariance and determinant derivation
+     D = tf.shape(cov)[1]  # type: ignore
+     eps = 1e-6
+     L = tf.linalg.cholesky(cov + tf.eye(D) * eps)  # K x D x D
+     log_det_cov = 2.0 * tf.reduce_sum(tf.math.log(tf.linalg.diag_part(L)), 1)  # K
+
+     return GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
+
+
+ def gmm_energy(
+     z: tf.Tensor,
+     params: GaussianMixtureModelParams,
+     return_mean: bool = True,
+ ) -> Tuple[tf.Tensor, tf.Tensor]:
+     """
+     Compute sample energy from Gaussian Mixture Model.
+
+     Parameters
+     ----------
+     z : tf.Tensor
+         Observations.
+     params : GaussianMixtureModelParams
+         The Gaussian mixture model parameters.
+     return_mean : bool, default True
+         Take mean across all sample energies in a batch.
+
+     Returns
+     -------
+     sample_energy
+         The sample energy of the GMM.
+     cov_diag
+         The inverse sum of the diagonal components of the covariance matrix.
+     """
+     D = tf.shape(params.cov)[1]  # type: ignore
+     z_mu = tf.expand_dims(z, 1) - tf.expand_dims(params.mu, 0)  # N x K x D
+     z_mu_T = tf.transpose(z_mu, perm=[1, 2, 0])  # K x D x N
+     v = tf.linalg.triangular_solve(params.L, z_mu_T, lower=True)  # K x D x N
+
+     # rewrite sample energy in logsumexp format for numerical stability
+     logits = tf.math.log(tf.expand_dims(params.phi, -1)) - 0.5 * (
+         tf.reduce_sum(tf.square(v), 1)
+         + tf.cast(D, tf.float32) * tf.math.log(2.0 * np.pi)  # type: ignore py38
+         + tf.expand_dims(params.log_det_cov, -1)
+     )  # K x N
+     sample_energy = -tf.reduce_logsumexp(logits, axis=0)  # N
+
+     if return_mean:
+         sample_energy = tf.reduce_mean(sample_energy)
+
+     # inverse sum of variances
+     cov_diag = tf.reduce_sum(tf.divide(1, tf.linalg.diag_part(params.cov)))
+
+     return sample_energy, cov_diag
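
A quick shape-check sketch for gmm_params and gmm_energy above (not part of the package; random inputs, with N=256 samples, K=4 components, and D=10 latent dimensions as made-up values):

import tensorflow as tf

N, K, D = 256, 4, 10
z = tf.random.normal((N, D))  # observations in latent space
gamma = tf.nn.softmax(tf.random.normal((N, K)), axis=-1)  # soft mixture memberships

params = gmm_params(z, gamma)  # phi: (K,), mu: (K, D), cov and L: (K, D, D)
sample_energy, cov_diag = gmm_energy(z, params, return_mean=True)  # both scalars
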
@@ -0,0 +1,107 @@
+ """
+ Source code derived from Alibi-Detect 0.11.4
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ from typing import Literal, Optional, Union, cast
+
+ import numpy as np
+ import tensorflow as tf
+ from keras.layers import Flatten
+ from tensorflow_probability.python.distributions.mvn_diag import MultivariateNormalDiag
+ from tensorflow_probability.python.distributions.mvn_tril import MultivariateNormalTriL
+ from tensorflow_probability.python.stats import covariance
+
+ from dataeval._internal.models.tensorflow.gmm import gmm_energy, gmm_params
+
+
+ class Elbo:
+     """
+     Compute ELBO loss. The covariance matrix can be specified by passing the full covariance matrix, the matrix
+     diagonal, or a scale identity multiplier. Only one of these should be specified. If none are specified, the
+     identity matrix is used.
+
+     Parameters
+     ----------
+     cov_type
+         Full covariance matrix ("cov_full"), diagonal variance matrix ("cov_diag"), or scale identity multiplier.
+     x
+         Dataset used to calculate the covariance matrix. Required for full and diagonal covariance matrix types.
+     """
+
+     def __init__(
+         self,
+         cov_type: Union[Literal["cov_full", "cov_diag"], float] = 1.0,
+         x: Optional[Union[tf.Tensor, np.ndarray]] = None,
+     ):
+         if isinstance(cov_type, float):
+             self.cov = ("sim", cov_type)
+         elif cov_type in ["cov_full", "cov_diag"]:
+             x_np: np.ndarray = x.numpy() if tf.is_tensor(x) else x  # type: ignore
+             cov = covariance(x_np.reshape(x_np.shape[0], -1))  # type: ignore py38
+             if cov_type == "cov_diag":  # infer standard deviation from covariance matrix
+                 cov = tf.math.sqrt(tf.linalg.diag_part(cov))
+             self.cov = (cov_type, cov)
+         else:
+             raise ValueError("Only 'cov_full', 'cov_diag' or a scale identity multiplier should be specified.")
+
+     def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
+         y_pred_flat = cast(tf.Tensor, Flatten()(y_pred))
+
+         if self.cov[0] == "cov_full":
+             y_mn = MultivariateNormalTriL(y_pred_flat, scale_tril=tf.linalg.cholesky(self.cov[1]))
+         else:  # cov_diag and sim
+             cov_diag = self.cov[1] if self.cov[0] == "cov_diag" else self.cov[1] * tf.ones(y_pred_flat.shape[-1])
+             y_mn = MultivariateNormalDiag(y_pred_flat, scale_diag=cov_diag)
+
+         loss = -tf.reduce_mean(y_mn.log_prob(Flatten()(y_true)))
+         return loss
+
+
+ class LossGMM:
+     """
+     Loss function used for AE and VAE with GMM.
+
+     Parameters
+     ----------
+     w_recon
+         Weight on the ELBO reconstruction loss term.
+     w_energy
+         Weight on the sample energy loss term.
+     w_cov_diag
+         Weight on the covariance regularizing loss term.
+     elbo
+         ELBO loss function used to calculate the reconstruction term.
+     """
+
+     def __init__(
+         self,
+         w_recon: float = 1e-7,
+         w_energy: float = 0.1,
+         w_cov_diag: float = 0.005,
+         elbo: Optional[Elbo] = None,
+     ):
+         self.w_recon = w_recon
+         self.w_energy = w_energy
+         self.w_cov_diag = w_cov_diag
+         self.elbo = elbo
+
+     def __call__(
+         self,
+         x_true: tf.Tensor,
+         x_pred: tf.Tensor,
+         z: tf.Tensor,
+         gamma: tf.Tensor,
+     ) -> tf.Tensor:
+         w_recon = (
+             tf.reduce_mean(tf.subtract(x_true, x_pred) ** 2)
+             if self.elbo is None
+             else tf.multiply(self.w_recon, self.elbo(x_true, x_pred))
+         )
+         sample_energy, cov_diag = gmm_energy(z, gmm_params(z, gamma))
+         w_energy = tf.multiply(self.w_energy, sample_energy)
+         w_cov_diag = tf.multiply(self.w_cov_diag, cov_diag)
+         return w_recon + w_energy + w_cov_diag
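
Finally, a hedged sketch of Elbo and LossGMM in isolation (not part of the package; the 16-feature data, batch size of 8, and the 5-dimensional z with 4-component gamma are illustrative assumptions):

import numpy as np
import tensorflow as tf

x = np.random.rand(64, 16).astype(np.float32)  # data used to estimate the covariance
elbo = Elbo(cov_type="cov_diag", x=x)  # diagonal scale inferred from x

x_true = tf.constant(x[:8])
x_pred = x_true + 0.01 * tf.random.normal((8, 16))
recon = elbo(x_true, x_pred)  # negative mean log-probability under the Gaussian

loss_fn = LossGMM(elbo=elbo)
z = tf.random.normal((8, 5))  # latent vector plus reconstruction features
gamma = tf.nn.softmax(tf.random.normal((8, 4)), axis=-1)  # mixture memberships
loss = loss_fn(x_true, x_pred, z, gamma)  # weighted sum of recon, energy, cov terms
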