dataeval 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- dataeval/__init__.py +4 -4
- dataeval/detectors/__init__.py +4 -3
- dataeval/detectors/drift/__init__.py +9 -10
- dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
- dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
- dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
- dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
- dataeval/detectors/drift/updates.py +61 -0
- dataeval/detectors/linters/__init__.py +3 -3
- dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
- dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
- dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
- dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
- dataeval/detectors/ood/__init__.py +6 -6
- dataeval/{_internal/detectors → detectors}/ood/ae.py +20 -12
- dataeval/detectors/ood/aegmm.py +66 -0
- dataeval/{_internal/detectors → detectors}/ood/base.py +33 -21
- dataeval/{_internal/detectors → detectors}/ood/llr.py +43 -33
- dataeval/detectors/ood/metadata_ks_compare.py +99 -0
- dataeval/detectors/ood/metadata_least_likely.py +119 -0
- dataeval/detectors/ood/metadata_ood_mi.py +92 -0
- dataeval/{_internal/detectors → detectors}/ood/vae.py +23 -17
- dataeval/detectors/ood/vaegmm.py +75 -0
- dataeval/interop.py +56 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +4 -4
- dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -13
- dataeval/{_internal/metrics → metrics/bias}/coverage.py +41 -7
- dataeval/{_internal/metrics → metrics/bias}/diversity.py +75 -18
- dataeval/metrics/bias/metadata.py +358 -0
- dataeval/{_internal/metrics → metrics/bias}/parity.py +54 -44
- dataeval/metrics/estimators/__init__.py +3 -3
- dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
- dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
- dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
- dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
- dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
- dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
- dataeval/metrics/stats/hashstats.py +156 -0
- dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
- dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
- dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
- dataeval/{_internal/output.py → output.py} +26 -6
- dataeval/utils/__init__.py +8 -3
- dataeval/utils/image.py +71 -0
- dataeval/utils/lazy.py +26 -0
- dataeval/utils/metadata.py +258 -0
- dataeval/utils/shared.py +151 -0
- dataeval/{_internal → utils}/split_dataset.py +98 -33
- dataeval/utils/tensorflow/__init__.py +7 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +8 -2
- dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +28 -18
- dataeval/{_internal/models/tensorflow/pixelcnn.py → utils/tensorflow/_internal/models.py} +387 -97
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +15 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +84 -85
- dataeval/utils/tensorflow/loss/__init__.py +6 -2
- dataeval/utils/torch/__init__.py +7 -3
- dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
- dataeval/{_internal → utils/torch}/datasets.py +48 -42
- dataeval/utils/torch/models.py +138 -0
- dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
- dataeval/{_internal → utils/torch}/utils.py +3 -1
- dataeval/workflows/__init__.py +1 -1
- dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/METADATA +4 -3
- dataeval-0.73.0.dist-info/RECORD +73 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/detectors/ood/aegmm.py +0 -78
- dataeval/_internal/detectors/ood/vaegmm.py +0 -89
- dataeval/_internal/interop.py +0 -49
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/stats/hashstats.py +0 -75
- dataeval/_internal/metrics/utils.py +0 -447
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/utils.py +0 -67
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/models/tensorflow/autoencoder.py +0 -320
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/detectors/drift/kernels/__init__.py +0 -10
- dataeval/detectors/drift/updates/__init__.py +0 -8
- dataeval/utils/tensorflow/models/__init__.py +0 -9
- dataeval/utils/tensorflow/recon/__init__.py +0 -3
- dataeval/utils/torch/datasets/__init__.py +0 -12
- dataeval/utils/torch/models/__init__.py +0 -11
- dataeval/utils/torch/trainer/__init__.py +0 -7
- dataeval-0.72.1.dist-info/RECORD +0 -81
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
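The moves above retire the private `dataeval._internal` tree: detectors, metrics, stats, and the torch/tensorflow model utilities now live under public subpackages, with the TensorFlow internals tucked into `dataeval/utils/tensorflow/_internal/`. A minimal migration sketch for downstream code follows; the re-exported symbol name is an inference from the file moves and the changed `__init__.py` files, which are not shown in this diff:

# dataeval 0.72.1: implementation lived under the private package
# from dataeval._internal.detectors.duplicates import Duplicates

# dataeval 0.73.0: the same module now sits under a public subpackage
from dataeval.detectors.linters import Duplicates

detector = Duplicates()  # hypothetical usage; a file move by itself does not change the constructor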
@@ -12,35 +12,325 @@ from __future__ import annotations
 
 import functools
 import warnings
+from typing import TYPE_CHECKING, cast
 
 import numpy as np
-
-
-
-
-
-
-
-
-
-
-)
-
-
-
-
-
-
-
-
-
-
+
+from dataeval.utils.lazy import lazyload
+
+if TYPE_CHECKING:
+    import tensorflow as tf
+    import tensorflow_probability.python.bijectors as bijectors
+    import tensorflow_probability.python.distributions as distributions
+    import tensorflow_probability.python.internal as tfp_internal
+    import tf_keras as keras
+else:
+    tf = lazyload("tensorflow")
+    bijectors = lazyload("tensorflow_probability.python.bijectors")
+    distributions = lazyload("tensorflow_probability.python.distributions")
+    tfp_internal = lazyload("tensorflow_probability.python.internal")
+    keras = lazyload("tf_keras")
+
+
+def relative_euclidean_distance(x: tf.Tensor, y: tf.Tensor, eps: float = 1e-12, axis: int = -1) -> tf.Tensor:
+    """
+    Relative Euclidean distance.
+
+    Parameters
+    ----------
+    x
+        Tensor used in distance computation.
+    y
+        Tensor used in distance computation.
+    eps
+        Epsilon added to denominator for numerical stability.
+    axis
+        Axis used to compute distance.
+
+    Returns
+    -------
+    Tensor with relative Euclidean distance across specified axis.
+    """
+    denom = tf.concat(
+        [
+            tf.reshape(tf.norm(x, ord=2, axis=axis), (-1, 1)),  # type: ignore
+            tf.reshape(tf.norm(y, ord=2, axis=axis), (-1, 1)),  # type: ignore
+        ],
+        axis=1,
+    )
+    dist = tf.norm(tf.math.subtract(x, y), ord=2, axis=axis) / (tf.reduce_min(denom, axis=axis) + eps)  # type: ignore
+    return dist
+
+
+def eucl_cosim_features(x: tf.Tensor, y: tf.Tensor, max_eucl: float = 1e2) -> tf.Tensor:
+    """
+    Compute features extracted from the reconstructed instance using the
+    relative Euclidean distance and cosine similarity between 2 tensors.
+
+    Parameters
+    ----------
+    x : tf.Tensor
+        Tensor used in feature computation.
+    y : tf.Tensor
+        Tensor used in feature computation.
+    max_eucl : float, default 1e2
+        Maximum value to clip relative Euclidean distance by.
+
+    Returns
+    -------
+    tf.Tensor
+        Tensor concatenating the relative Euclidean distance and cosine similarity features.
+    """
+    if len(x.shape) > 2 or len(y.shape) > 2:
+        x = cast(tf.Tensor, keras.layers.Flatten()(x))
+        y = cast(tf.Tensor, keras.layers.Flatten()(y))
+    rec_cos = tf.reshape(keras.losses.cosine_similarity(y, x, -1), (-1, 1))
+    rec_euc = tf.reshape(relative_euclidean_distance(y, x, -1), (-1, 1))
+    # rec_euc could become very large so should be clipped
+    rec_euc = tf.clip_by_value(rec_euc, 0, max_eucl)
+    return cast(tf.Tensor, tf.concat([rec_cos, rec_euc], -1))
+
+
+class Sampling(keras.layers.Layer):
+    """Reparametrization trick - Uses (z_mean, z_log_var) to sample the latent vector z."""
+
+    def call(self, inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
+        """
+        Sample z.
+
+        Parameters
+        ----------
+        inputs
+            Tuple with mean and log :term:`variance<Variance>`.
+
+        Returns
+        -------
+        Sampled vector z.
+        """
+        z_mean, z_log_var = inputs
+        batch, dim = tuple(tf.shape(z_mean).numpy().ravel()[:2])  # type: ignore
+        epsilon = cast(tf.Tensor, keras.backend.random_normal(shape=(batch, dim)))
+        return z_mean + tf.exp(tf.math.multiply(0.5, z_log_var)) * epsilon
+
+
+class EncoderAE(keras.layers.Layer):
+    def __init__(self, encoder_net: keras.Sequential) -> None:
+        """
+        Encoder of AE.
+
+        Parameters
+        ----------
+        encoder_net
+            Layers for the encoder wrapped in a keras.keras.Sequential class.
+        name
+            Name of encoder.
+        """
+        super().__init__(name="encoder_ae")
+        self.encoder_net: keras.Sequential = encoder_net
+
+    def call(self, x: tf.Tensor) -> tf.Tensor:
+        return cast(tf.Tensor, self.encoder_net(x))
+
+
+class EncoderVAE(keras.layers.Layer):
+    def __init__(self, encoder_net: keras.Sequential, latent_dim: int) -> None:
+        """
+        Encoder of VAE.
+
+        Parameters
+        ----------
+        encoder_net
+            Layers for the encoder wrapped in a keras.keras.Sequential class.
+        latent_dim
+            Dimensionality of the :term:`latent space<Latent Space>`.
+        name
+            Name of encoder.
+        """
+        super().__init__(name="encoder_vae")
+        self.encoder_net: keras.Sequential = encoder_net
+        self._fc_mean = keras.layers.Dense(latent_dim, activation=None)
+        self._fc_log_var = keras.layers.Dense(latent_dim, activation=None)
+        self._sampling = Sampling()
+
+    def call(self, x: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        x = cast(tf.Tensor, self.encoder_net(x))
+        if len(x.shape) > 2:
+            x = cast(tf.Tensor, keras.layers.Flatten()(x))
+        z_mean = cast(tf.Tensor, self._fc_mean(x))
+        z_log_var = cast(tf.Tensor, self._fc_log_var(x))
+        z = cast(tf.Tensor, self._sampling((z_mean, z_log_var)))
+        return z_mean, z_log_var, z
+
+
+class Decoder(keras.layers.Layer):
+    def __init__(self, decoder_net: keras.Sequential) -> None:
+        """
+        Decoder of AE and VAE.
+
+        Parameters
+        ----------
+        decoder_net
+            Layers for the decoder wrapped in a keras.keras.Sequential class.
+        name
+            Name of decoder.
+        """
+        super().__init__(name="decoder")
+        self.decoder_net: keras.Sequential = decoder_net
+
+    def call(self, inputs: tf.Tensor) -> tf.Tensor:
+        return cast(tf.Tensor, self.decoder_net(inputs))
+
+
+class AE(keras.Model):
+    """
+    Combine encoder and decoder in AE.
+
+    Parameters
+    ----------
+    encoder_net : keras.Sequential
+        Layers for the encoder wrapped in a keras.keras.Sequential class.
+    decoder_net : keras.Sequential
+        Layers for the decoder wrapped in a keras.keras.Sequential class.
+    """
+
+    def __init__(self, encoder_net: keras.Sequential, decoder_net: keras.Sequential) -> None:
+        super().__init__(name="ae")
+        self.encoder: keras.layers.Layer = EncoderAE(encoder_net)
+        self.decoder: keras.layers.Layer = Decoder(decoder_net)
+
+    def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
+        z = cast(tf.Tensor, self.encoder(inputs))
+        x_recon = cast(tf.Tensor, self.decoder(z))
+        return x_recon
+
+
+class VAE(keras.Model):
+    """
+    Combine encoder and decoder in VAE.
+
+    Parameters
+    ----------
+    encoder_net : keras.Sequential
+        Layers for the encoder wrapped in a keras.keras.Sequential class.
+    decoder_net : keras.Sequential
+        Layers for the decoder wrapped in a keras.keras.Sequential class.
+    latent_dim : int
+        Dimensionality of the :term:`latent space<Latent Space>`.
+    beta : float, default 1.0
+        Beta parameter for KL-divergence loss term.
+    """
+
+    def __init__(
+        self, encoder_net: keras.Sequential, decoder_net: keras.Sequential, latent_dim: int, beta: float = 1.0
+    ) -> None:
+        super().__init__(name="vae_model")
+        self.encoder: keras.layers.Layer = EncoderVAE(encoder_net, latent_dim)
+        self.decoder: keras.layers.Layer = Decoder(decoder_net)
+        self.beta: float = beta
+        self.latent_dim: int = latent_dim
+
+    def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
+        z_mean, z_log_var, z = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(inputs))
+        x_recon = self.decoder(z)
+        # add KL divergence loss term
+        kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
+        self.add_loss(self.beta * kl_loss)
+        return cast(tf.Tensor, x_recon)
+
+
+class AEGMM(keras.Model):
+    """
+    Deep Autoencoding Gaussian Mixture Model.
+
+    Parameters
+    ----------
+    encoder_net : keras.Sequential
+        Layers for the encoder wrapped in a keras.keras.Sequential class.
+    decoder_net : keras.Sequential
+        Layers for the decoder wrapped in a keras.keras.Sequential class.
+    gmm_density_net : keras.Sequential
+        Layers for the GMM network wrapped in a keras.keras.Sequential class.
+    n_gmm : int
+        Number of components in GMM.
+    """
+
+    def __init__(
+        self,
+        encoder_net: keras.Sequential,
+        decoder_net: keras.Sequential,
+        gmm_density_net: keras.Sequential,
+        n_gmm: int,
+    ) -> None:
+        super().__init__("aegmm")
+        self.encoder = encoder_net
+        self.decoder = decoder_net
+        self.gmm_density = gmm_density_net
+        self.n_gmm = n_gmm
+
+    def call(
+        self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        enc = self.encoder(inputs)
+        x_recon = cast(tf.Tensor, self.decoder(enc))
+        recon_features = eucl_cosim_features(inputs, x_recon)
+        z = cast(tf.Tensor, tf.concat([enc, recon_features], -1))
+        gamma = cast(tf.Tensor, self.gmm_density(z))
+        return x_recon, z, gamma
+
+
+class VAEGMM(keras.Model):
+    """
+    Variational Autoencoding Gaussian Mixture Model.
+
+    Parameters
+    ----------
+    encoder_net : keras.Sequential
+        Layers for the encoder wrapped in a keras.keras.Sequential class.
+    decoder_net : keras.Sequential
+        Layers for the decoder wrapped in a keras.keras.Sequential class.
+    gmm_density_net : keras.Sequential
+        Layers for the GMM network wrapped in a keras.keras.Sequential class.
+    n_gmm : int
+        Number of components in GMM.
+    latent_dim : int
+        Dimensionality of the :term:`latent space<Latent Space>`.
+    beta : float, default 1.0
+        Beta parameter for KL-divergence loss term.
+    """
+
+    def __init__(
+        self,
+        encoder_net: keras.Sequential,
+        decoder_net: keras.Sequential,
+        gmm_density_net: keras.Sequential,
+        n_gmm: int,
+        latent_dim: int,
+        beta: float = 1.0,
+    ) -> None:
+        super().__init__(name="vaegmm")
+        self.encoder: keras.Sequential = EncoderVAE(encoder_net, latent_dim)
+        self.decoder: keras.Sequential = decoder_net
+        self.gmm_density: keras.Sequential = gmm_density_net
+        self.n_gmm: int = n_gmm
+        self.latent_dim: int = latent_dim
+        self.beta = beta
+
+    def call(
+        self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None
+    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        enc_mean, enc_log_var, enc = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(inputs))
+        x_recon = cast(tf.Tensor, self.decoder(enc))
+        recon_features = eucl_cosim_features(inputs, x_recon)
+        z = cast(tf.Tensor, tf.concat([enc, recon_features], -1))
+        gamma = cast(tf.Tensor, self.gmm_density(z))
+        # add KL divergence loss term
+        kl_loss = -0.5 * tf.reduce_mean(enc_log_var - tf.square(enc_mean) - tf.exp(enc_log_var) + 1)
+        self.add_loss(self.beta * kl_loss)
+        return x_recon, z, gamma
 
 
 class WeightNorm(keras.layers.Wrapper):
-    def __init__(self, layer, data_init: bool = True, **kwargs):
+    def __init__(self, layer, data_init: bool = True, **kwargs) -> None:
         """Layer wrapper to decouple magnitude and direction of the layer's weights.
 
         This wrapper reparameterizes a layer by decoupling the weight's
@@ -186,8 +476,8 @@ class WeightNorm(keras.layers.Wrapper):
         return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())
 
 
-class Shift(bijector.Bijector):
-    def __init__(self, shift, validate_args=False, name="shift"):
+class Shift(bijectors.Bijector):
+    def __init__(self, shift, validate_args=False, name="shift") -> None:
         """Instantiates the `Shift` bijector which computes `Y = g(X; shift) = X + shift`
         where `shift` is a numeric `Tensor`.
 
@@ -201,8 +491,8 @@ class Shift(bijector.Bijector):
             Python `str` name given to ops managed by this object.
         """
         with tf.name_scope(name) as name:
-            dtype = dtype_util.common_dtype([shift], dtype_hint=tf.float32)
-            self._shift = tensor_util.convert_nonref_to_tensor(shift, dtype=dtype, name="shift")
+            dtype = tfp_internal.dtype_util.common_dtype([shift], dtype_hint=tf.float32)
+            self._shift = tfp_internal.tensor_util.convert_nonref_to_tensor(shift, dtype=dtype, name="shift")
             super().__init__(
                 forward_min_event_ndims=0,
                 is_constant_jacobian=True,
@@ -230,12 +520,12 @@ class Shift(bijector.Bijector):
         # is_constant_jacobian = True for this bijector, hence the
         # `log_det_jacobian` need only be specified for a single input, as this will
         # be tiled to match `event_ndims`.
-        return tf.zeros([], dtype=dtype_util.base_dtype(x.dtype))
+        return tf.zeros([], dtype=tfp_internal.dtype_util.base_dtype(x.dtype))
 
 
-class PixelCNN(distribution.Distribution):
+class PixelCNN(distributions.distribution.Distribution):
     """
-    Construct Pixel CNN++ distribution.
+    Construct Pixel CNN++ distributions.distribution.
 
     Parameters
     ----------
@@ -250,7 +540,7 @@ class PixelCNN(distribution.Distribution):
     num_filters : int, default 160
         The number of convolutional filters.
     num_logistic_mix : int, default 10
-        Number of components in the logistic mixture distribution.
+        Number of components in the distributions.logistic mixture distributions.distribution.
     receptive_field_dims tuple, default (3, 3)
         Height and width in pixels of the receptive field of the convolutional layers above and to the left
         of a given pixel. The width (second element of the tuple) should be odd. Figure 1 (middle) of [2]
@@ -276,13 +566,13 @@ class PixelCNN(distribution.Distribution):
 
     def __init__(
         self,
-        image_shape: tuple,
-        conditional_shape: tuple | None = None,
+        image_shape: tuple[int, int, int],
+        conditional_shape: tuple[int, ...] | None = None,
         num_resnet: int = 5,
         num_hierarchies: int = 3,
         num_filters: int = 160,
         num_logistic_mix: int = 10,
-        receptive_field_dims: tuple = (3, 3),
+        receptive_field_dims: tuple[int, int] = (3, 3),
         dropout_p: float = 0.5,
         resnet_activation: str = "concat_elu",
         l2_weight: float = 0.0,
@@ -290,32 +580,32 @@ class PixelCNN(distribution.Distribution):
         use_data_init: bool = True,
         high: int = 255,
         low: int = 0,
-        dtype=tf.float32,
+        dtype: tf.DType = tf.float32,
     ) -> None:
         parameters = dict(locals())
         with tf.name_scope("PixelCNN") as name:
             super().__init__(
                 dtype=dtype,
-                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
+                reparameterization_type=tfp_internal.reparameterization.NOT_REPARAMETERIZED,
                 validate_args=False,
                 allow_nan_stats=True,
                 parameters=parameters,
                 name=name,
             )
 
-        if not tensorshape_util.is_fully_defined(image_shape):
+        if not tfp_internal.tensorshape_util.is_fully_defined(image_shape):
             raise ValueError("`image_shape` must be fully defined.")
 
-        if conditional_shape is not None and not tensorshape_util.is_fully_defined(conditional_shape):
+        if conditional_shape is not None and not tfp_internal.tensorshape_util.is_fully_defined(conditional_shape):
             raise ValueError("`conditional_shape` must be fully defined.")
 
-        if tensorshape_util.rank(image_shape) != 3:
+        if tfp_internal.tensorshape_util.rank(image_shape) != 3:
             raise ValueError("`image_shape` must have length 3, representing [height, width, channels] dimensions.")
 
         self._high = tf.cast(high, self.dtype)
         self._low = tf.cast(low, self.dtype)
         self._num_logistic_mix = num_logistic_mix
-        self.network = _PixelCNNNetwork(
+        self._network = PixelCNNNetwork(
             dropout_p=dropout_p,
             num_resnet=num_resnet,
             num_hierarchies=num_hierarchies,
@@ -329,24 +619,24 @@ class PixelCNN(distribution.Distribution):
             dtype=dtype,
         )
 
-        image_input_shape = tensorshape_util.concatenate([None], image_shape)
+        image_input_shape = tfp_internal.tensorshape_util.concatenate([None], image_shape)
         if conditional_shape is None:
             input_shape = image_input_shape
         else:
-            conditional_input_shape = tensorshape_util.concatenate([None], conditional_shape)
+            conditional_input_shape = tfp_internal.tensorshape_util.concatenate([None], conditional_shape)
             input_shape = [image_input_shape, conditional_input_shape]
 
         self.image_shape = image_shape
         self.conditional_shape = conditional_shape
-        self.network.build(input_shape)
+        self._network.build(input_shape)
 
     def _make_mixture_dist(self, component_logits, locs, scales, return_per_feature: bool = False):
-        """Builds a mixture of quantized logistic distributions.
+        """Builds a mixture of quantized distributions.logistic distributions.
 
         Parameters
         ----------
         component_logits
-            4D `Tensor` of logits for the Categorical distribution
+            4D `Tensor` of logits for the Categorical distributions.distribution
             over Quantized Logistic mixture components. Dimensions are `[batch_size,
             height, width, num_logistic_mix]`.
         locs
@@ -363,17 +653,17 @@ class PixelCNN(distribution.Distribution):
         Returns
         -------
         dist
-            A quantized logistic mixture `tfp.distribution` over the input data.
+            A quantized distributions.logistic mixture `tfp.distributions.distribution` over the input data.
         """
-        mixture_distribution = categorical.Categorical(logits=component_logits)
+        mixture_distribution = distributions.categorical.Categorical(logits=component_logits)
 
-        # Convert distribution parameters for pixel values in
+        # Convert distributions.distribution parameters for pixel values in
         # `[self._low, self._high]` for use with `QuantizedDistribution`
         locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.0)
         scales *= 0.5 * (self._high - self._low)
-        logistic_dist = quantized_distribution.QuantizedDistribution(
-            distribution=transformed_distribution.TransformedDistribution(
-                distribution=logistic.Logistic(loc=locs, scale=scales),
+        logistic_dist = distributions.quantized_distribution.QuantizedDistribution(
+            distribution=distributions.transformed_distribution.TransformedDistribution(
+                distribution=distributions.logistic.Logistic(loc=locs, scale=scales),
                 bijector=Shift(shift=tf.cast(-0.5, self.dtype)),
             ),
             low=self._low,
@@ -381,20 +671,20 @@ class PixelCNN(distribution.Distribution):
         )
 
         # mixture with logistics for the loc and scale on each pixel for each component
-        dist = mixture_same_family.MixtureSameFamily(
+        dist = distributions.mixture_same_family.MixtureSameFamily(
             mixture_distribution=mixture_distribution,
-            components_distribution=independent.Independent(logistic_dist, reinterpreted_batch_ndims=1),
+            components_distribution=distributions.independent.Independent(logistic_dist, reinterpreted_batch_ndims=1),
         )
         if return_per_feature:
             return dist
         else:
-            return independent.Independent(dist, reinterpreted_batch_ndims=2)
+            return distributions.independent.Independent(dist, reinterpreted_batch_ndims=2)
 
     def _log_prob(self, value, conditional_input=None, training=None, return_per_feature=False):
         """Log probability function with optional conditional input.
 
         Calculates the log probability of a batch of data under the modeled
-        distribution (or conditional distribution, if conditional input is
+        distributions.distribution (or conditional distributions.distribution, if conditional input is
         provided).
 
         Parameters
@@ -404,7 +694,7 @@ class PixelCNN(distribution.Distribution):
             dimension(s), which must broadcast to the leading batch dimensions of
             `conditional_input`.
         conditional_input
-            `Tensor` on which to condition the distribution (e.g.
+            `Tensor` on which to condition the distributions.distribution (e.g.
             class labels), or `None`. May have leading batch dimension(s), which
             must broadcast to the leading batch dimensions of `value`.
         training
@@ -419,43 +709,43 @@ class PixelCNN(distribution.Distribution):
         log_prob_values: `Tensor`.
         """
         # Determine the batch shape of the input images
-        image_batch_shape = prefer_static.shape(value)[:-3]
+        image_batch_shape = tfp_internal.prefer_static.shape(value)[:-3]
 
         # Broadcast `value` and `conditional_input` to the same batch_shape
         if conditional_input is None:
             image_batch_and_conditional_shape = image_batch_shape
         else:
             conditional_input = tf.convert_to_tensor(conditional_input)
-            conditional_input_shape = prefer_static.shape(conditional_input)
-            conditional_batch_rank = prefer_static.rank(
-                conditional_input
-            ) - tensorshape_util.rank(self.conditional_shape)
+            conditional_input_shape = tfp_internal.prefer_static.shape(conditional_input)
+            conditional_batch_rank = tfp_internal.prefer_static.rank(
+                conditional_input
+            ) - tfp_internal.tensorshape_util.rank(self.conditional_shape)
             conditional_batch_shape = conditional_input_shape[:conditional_batch_rank]
 
-            image_batch_and_conditional_shape = prefer_static.broadcast_shape(
+            image_batch_and_conditional_shape = tfp_internal.prefer_static.broadcast_shape(
                 image_batch_shape, conditional_batch_shape
             )
             conditional_input = tf.broadcast_to(
                 conditional_input,
-                prefer_static.concat([image_batch_and_conditional_shape, self.conditional_shape], axis=0),
+                tfp_internal.prefer_static.concat([image_batch_and_conditional_shape, self.conditional_shape], axis=0),
             )
             value = tf.broadcast_to(
                 value,
-                prefer_static.concat([image_batch_and_conditional_shape, self.event_shape], axis=0),
+                tfp_internal.prefer_static.concat([image_batch_and_conditional_shape, self.event_shape], axis=0),
             )
 
             # Flatten batch dimension for input to Keras model
             conditional_input = tf.reshape(
                 conditional_input,
-                prefer_static.concat([(-1,), self.conditional_shape], axis=0),
+                tfp_internal.prefer_static.concat([(-1,), self.conditional_shape], axis=0),
             )
 
-        value = tf.reshape(value, prefer_static.concat([(-1,), self.event_shape], axis=0))
+        value = tf.reshape(value, tfp_internal.prefer_static.concat([(-1,), self.event_shape], axis=0))
 
         transformed_value = (2.0 * (value - self._low) / (self._high - self._low)) - 1.0
         inputs = transformed_value if conditional_input is None else [transformed_value, conditional_input]
 
-        params = self.network(inputs, training=training)
+        params = self._network(inputs, training=training)
 
         num_channels = self.event_shape[-1]
         if num_channels == 1:
@@ -463,7 +753,7 @@ class PixelCNN(distribution.Distribution):
         else:
             # If there is more than one channel, we create a linear autoregressive
             # dependency among the location parameters of the channels of a single
-            # pixel (the scale parameters within a pixel are independent). For a pixel
+            # pixel (the scale parameters within a pixel are distributions.independent). For a pixel
             # with R/G/B channels, the `r`, `g`, and `b` saturation values are
             # distributed as:
             #
@@ -493,7 +783,7 @@ class PixelCNN(distribution.Distribution):
         return tf.reshape(log_px, image_batch_and_conditional_shape)
 
     def _sample_n(self, n, seed=None, conditional_input=None, training=False):
-        """Samples from the distribution, with optional conditional input.
+        """Samples from the distributions.distribution, with optional conditional input.
 
         Parameters
         ----------
@@ -503,7 +793,7 @@ class PixelCNN(distribution.Distribution):
             `int`, seed for RNG. Setting a random seed enforces reproducibility
             of the samples between sessions (not within a single session).
         conditional_input
-            `Tensor` on which to condition the distribution (e.g.
+            `Tensor` on which to condition the distributions.distribution (e.g.
             class labels), or `None`.
         training
             `bool` or `None`. If `bool`, it controls the dropout layer,
@@ -517,9 +807,9 @@ class PixelCNN(distribution.Distribution):
         """
         if conditional_input is not None:
             conditional_input = tf.convert_to_tensor(conditional_input, dtype=self.dtype)
-            conditional_event_rank = tensorshape_util.rank(self.conditional_shape)
-            conditional_input_shape = prefer_static.shape(conditional_input)
-            conditional_sample_rank = prefer_static.rank(conditional_input) - conditional_event_rank
+            conditional_event_rank = tfp_internal.tensorshape_util.rank(self.conditional_shape)
+            conditional_input_shape = tfp_internal.prefer_static.shape(conditional_input)
+            conditional_sample_rank = tfp_internal.prefer_static.rank(conditional_input) - conditional_event_rank
 
             # If `conditional_input` has no sample dimensions, prepend a sample
             # dimension
@@ -532,14 +822,14 @@ class PixelCNN(distribution.Distribution):
             conditional_event_shape = conditional_input_shape[conditional_sample_rank:]
             with tf.control_dependencies([tf.assert_equal(self.conditional_shape, conditional_event_shape)]):
                 conditional_sample_shape = conditional_input_shape[:conditional_sample_rank]
-                repeat = n // prefer_static.reduce_prod(conditional_sample_shape)
+                repeat = n // tfp_internal.prefer_static.reduce_prod(conditional_sample_shape)
                 h = tf.reshape(
                     conditional_input,
-                    prefer_static.concat([(-1,), self.conditional_shape], axis=0),
+                    tfp_internal.prefer_static.concat([(-1,), self.conditional_shape], axis=0),
                 )
                 h = tf.tile(
                     h,
-                    prefer_static.pad(
+                    tfp_internal.prefer_static.pad(
                         [repeat],
                         paddings=[[0, conditional_event_rank]],
                         constant_values=1,
@@ -547,17 +837,17 @@ class PixelCNN(distribution.Distribution):
                 )
 
         samples_0 = tf.random.uniform(
-            prefer_static.concat([(n,), self.event_shape], axis=0),
+            tfp_internal.prefer_static.concat([(n,), self.event_shape], axis=0),
             minval=-1.0,
             maxval=1.0,
             dtype=self.dtype,
             seed=seed,
         )
         inputs = samples_0 if conditional_input is None else [samples_0, h]
-        params_0 = self.network(inputs, training=training)
+        params_0 = self._network(inputs, training=training)
         samples_0 = self._sample_channels(*params_0, seed=seed)
 
-        image_height, image_width, _ = tensorshape_util.as_list(self.event_shape)
+        image_height, image_width, _ = tfp_internal.tensorshape_util.as_list(self.event_shape)
 
         def loop_body(index, samples):
             """Loop for iterative pixel sampling.
@@ -579,7 +869,7 @@ class PixelCNN(distribution.Distribution):
                 width, num_channels]`.
             """
             inputs = samples if conditional_input is None else [samples, h]
-            params = self.network(inputs, training=training)
+            params = self._network(inputs, training=training)
             samples_new = self._sample_channels(*params, seed=seed)
 
             # Update the current pixel
@@ -609,7 +899,7 @@ class PixelCNN(distribution.Distribution):
         Parameters
         ----------
         component_logits
-            4D `Tensor` of logits for the Categorical distribution
+            4D `Tensor` of logits for the Categorical distributions.distribution
            over Quantized Logistic mixture components. Dimensions are `[batch_size,
             height, width, num_logistic_mix]`.
         locs
@@ -637,7 +927,7 @@ class PixelCNN(distribution.Distribution):
         num_channels = self.event_shape[-1]
 
         # sample mixture components once for the entire pixel
-        component_dist = categorical.Categorical(logits=component_logits)
+        component_dist = distributions.categorical.Categorical(logits=component_logits)
         mask = tf.one_hot(indices=component_dist.sample(seed=seed), depth=self._num_logistic_mix)
         mask = tf.cast(mask[..., tf.newaxis], self.dtype)
 
@@ -660,7 +950,7 @@ class PixelCNN(distribution.Distribution):
                 loc += c * coef_tensors[coef_count]
                 coef_count += 1
 
-            logistic_samp = logistic.Logistic(loc=loc, scale=scale_tensors[i]).sample(seed=seed)
+            logistic_samp = distributions.logistic.Logistic(loc=loc, scale=scale_tensors[i]).sample(seed=seed)
             logistic_samp = tf.clip_by_value(logistic_samp, -1.0, 1.0)
             channel_samples.append(logistic_samp)
 
@@ -673,8 +963,8 @@ class PixelCNN(distribution.Distribution):
         return tf.TensorShape(self.image_shape)
 
 
-class _PixelCNNNetwork(keras.layers.Layer):
-    """Keras `Layer` to parameterize a Pixel CNN++ distribution.
+class PixelCNNNetwork(keras.layers.Layer):
+    """Keras `Layer` to parameterize a Pixel CNN++ distributions.distribution.
     This is a Keras implementation of the Pixel CNN++ network, as described in
     Salimans et al. (2017)[1] and van den Oord et al. (2016)[2].
     (https://github.com/openai/pixel-cnn).
@@ -699,14 +989,14 @@ class _PixelCNNNetwork(keras.layers.Layer):
         num_hierarchies: int = 3,
         num_filters: int = 160,
         num_logistic_mix: int = 10,
-        receptive_field_dims: tuple = (3, 3),
+        receptive_field_dims: tuple[int, int] = (3, 3),
         resnet_activation: str = "concat_elu",
         l2_weight: float = 0.0,
         use_weight_norm: bool = True,
         use_data_init: bool = True,
-        dtype=tf.float32,
+        dtype: tf.DType = tf.float32,
     ) -> None:
-        """Initialize the :term:`neural network<Neural Network>` for the Pixel CNN++ distribution.
+        """Initialize the :term:`neural network<Neural Network>` for the Pixel CNN++ distributions.distribution.
 
         Parameters
         ----------
@@ -721,8 +1011,8 @@ class _PixelCNNNetwork(keras.layers.Layer):
         num_filters
             `int`, the number of convolutional filters.
         num_logistic_mix
-            `int`, number of components in the logistic mixture
-            distribution.
+            `int`, number of components in the distributions.logistic mixture
+            distributions.distribution.
         receptive_field_dims
             `tuple`, height and width in pixels of the receptive
             field of the convolutional layers above and to the left of a given
@@ -765,7 +1055,7 @@ class _PixelCNNNetwork(keras.layers.Layer):
         else:
             self._layer_wrapper = lambda layer: layer
 
-    def build(self, input_shape):
+    def build(self, input_shape: tuple[int, ...]) -> None:
         dtype = self.dtype
         if len(input_shape) == 2:
             batch_image_shape, batch_conditional_shape = input_shape
@@ -1008,21 +1298,21 @@ class _PixelCNNNetwork(keras.layers.Layer):
 
         # Build final Dense/Reshape layers to output the correct number of
         # parameters per pixel.
-        num_channels = tensorshape_util.as_list(image_shape)[-1]
+        num_channels = tfp_internal.tensorshape_util.as_list(image_shape)[-1]
         num_coeffs = num_channels * (num_channels - 1) // 2  # alpha, beta, gamma in eq.3 of paper
         num_out = num_channels * 2 + num_coeffs + 1  # mu, s + alpha, beta, gamma + 1 (mixture weight)
         num_out_total = num_out * self._num_logistic_mix
         params = Dense(num_out_total)(x_out)
         params = tf.reshape(
             params,
-            prefer_static.concat(  # [-1,H,W,nb mixtures, params per mixture]
+            tfp_internal.prefer_static.concat(  # [-1,H,W,nb mixtures, params per mixture]
                 [[-1], image_shape[:-1], [self._num_logistic_mix, num_out]], axis=0
             ),
         )
 
         # If there is one color channel, split the parameters into a list of three
         # output `Tensor`s: (1) component logits for the Quantized Logistic mixture
-        # distribution, (2) location parameters for each component, and (3) scale
+        # distributions.distribution, (2) location parameters for each component, and (3) scale
         # parameters for each component. If there is more than one color channel,
         # return a fourth `Tensor` for the coefficients for the linear dependence
         # among color channels (e.g. alpha, beta, gamma).
@@ -1040,7 +1330,7 @@ class _PixelCNNNetwork(keras.layers.Layer):
         self._network = keras.Model(inputs=inputs, outputs=outputs)
         super().build(input_shape)
 
-    def call(self, inputs, training=None):
+    def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
         """Call the Pixel CNN network model.
 
         Parameters
@@ -1060,7 +1350,7 @@ class _PixelCNNNetwork(keras.layers.Layer):
         -------
         outputs
             a 3- or 4-element `list` of `Tensor`s in the following order: \
-            component_logits: 4D `Tensor` of logits for the Categorical distribution \
+            component_logits: 4D `Tensor` of logits for the Categorical distributions.distribution \
             over Quantized Logistic mixture components. Dimensions are \
             `[batch_size, height, width, num_logistic_mix]`.
         locs