dataeval 0.74.0__py3-none-any.whl → 0.74.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -9
- dataeval/detectors/__init__.py +2 -10
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/mmd.py +1 -1
- dataeval/detectors/linters/clusterer.py +3 -3
- dataeval/detectors/linters/duplicates.py +4 -4
- dataeval/detectors/linters/outliers.py +4 -4
- dataeval/detectors/ood/__init__.py +5 -12
- dataeval/detectors/ood/base.py +5 -5
- dataeval/detectors/ood/metadata_ks_compare.py +12 -13
- dataeval/interop.py +1 -1
- dataeval/metrics/bias/balance.py +3 -3
- dataeval/metrics/bias/coverage.py +3 -3
- dataeval/metrics/bias/diversity.py +3 -3
- dataeval/metrics/bias/metadata_preprocessing.py +3 -3
- dataeval/metrics/bias/parity.py +4 -4
- dataeval/metrics/estimators/ber.py +3 -3
- dataeval/metrics/estimators/divergence.py +3 -3
- dataeval/metrics/estimators/uap.py +3 -3
- dataeval/metrics/stats/base.py +2 -2
- dataeval/metrics/stats/boxratiostats.py +1 -1
- dataeval/metrics/stats/datasetstats.py +6 -6
- dataeval/metrics/stats/dimensionstats.py +1 -1
- dataeval/metrics/stats/hashstats.py +1 -1
- dataeval/metrics/stats/labelstats.py +3 -3
- dataeval/metrics/stats/pixelstats.py +1 -1
- dataeval/metrics/stats/visualstats.py +1 -1
- dataeval/output.py +77 -53
- dataeval/utils/__init__.py +1 -7
- dataeval/workflows/sufficiency.py +4 -4
- {dataeval-0.74.0.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -8
- dataeval-0.74.1.dist-info/RECORD +65 -0
- dataeval/detectors/ood/ae.py +0 -76
- dataeval/detectors/ood/aegmm.py +0 -67
- dataeval/detectors/ood/base_tf.py +0 -109
- dataeval/detectors/ood/llr.py +0 -302
- dataeval/detectors/ood/vae.py +0 -98
- dataeval/detectors/ood/vaegmm.py +0 -76
- dataeval/utils/lazy.py +0 -26
- dataeval/utils/tensorflow/__init__.py +0 -19
- dataeval/utils/tensorflow/_internal/gmm.py +0 -103
- dataeval/utils/tensorflow/_internal/loss.py +0 -121
- dataeval/utils/tensorflow/_internal/models.py +0 -1394
- dataeval/utils/tensorflow/_internal/trainer.py +0 -114
- dataeval/utils/tensorflow/_internal/utils.py +0 -256
- dataeval/utils/tensorflow/loss/__init__.py +0 -11
- dataeval-0.74.0.dist-info/RECORD +0 -79
- {dataeval-0.74.0.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.74.0.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
dataeval/utils/tensorflow/_internal/models.py
@@ -1,1394 +0,0 @@
|
|
1
|
-
# type: ignore
|
2
|
-
|
3
|
-
"""
|
4
|
-
Source code derived from Alibi-Detect 0.11.4
|
5
|
-
https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
|
6
|
-
|
7
|
-
Original code Copyright (c) 2023 Seldon Technologies Ltd
|
8
|
-
Licensed under Apache Software License (Apache 2.0)
|
9
|
-
"""
|
10
|
-
|
11
|
-
from __future__ import annotations
|
12
|
-
|
13
|
-
import functools
|
14
|
-
import warnings
|
15
|
-
from typing import TYPE_CHECKING, cast
|
16
|
-
|
17
|
-
import numpy as np
|
18
|
-
|
19
|
-
from dataeval.utils.lazy import lazyload
|
20
|
-
|
21
|
-
if TYPE_CHECKING:
|
22
|
-
import tensorflow as tf
|
23
|
-
import tensorflow_probability.python.bijectors as bijectors
|
24
|
-
import tensorflow_probability.python.distributions as distributions
|
25
|
-
import tensorflow_probability.python.internal as tfp_internal
|
26
|
-
import tf_keras as keras
|
27
|
-
else:
|
28
|
-
tf = lazyload("tensorflow")
|
29
|
-
bijectors = lazyload("tensorflow_probability.python.bijectors")
|
30
|
-
distributions = lazyload("tensorflow_probability.python.distributions")
|
31
|
-
tfp_internal = lazyload("tensorflow_probability.python.internal")
|
32
|
-
keras = lazyload("tf_keras")
|
33
|
-
|
34
|
-
|
35
|
-
def relative_euclidean_distance(x: tf.Tensor, y: tf.Tensor, eps: float = 1e-12, axis: int = -1) -> tf.Tensor:
|
36
|
-
"""
|
37
|
-
Relative Euclidean distance.
|
38
|
-
|
39
|
-
Parameters
|
40
|
-
----------
|
41
|
-
x
|
42
|
-
Tensor used in distance computation.
|
43
|
-
y
|
44
|
-
Tensor used in distance computation.
|
45
|
-
eps
|
46
|
-
Epsilon added to denominator for numerical stability.
|
47
|
-
axis
|
48
|
-
Axis used to compute distance.
|
49
|
-
|
50
|
-
Returns
|
51
|
-
-------
|
52
|
-
Tensor with relative Euclidean distance across specified axis.
|
53
|
-
"""
|
54
|
-
denom = tf.concat(
|
55
|
-
[
|
56
|
-
tf.reshape(tf.norm(x, ord=2, axis=axis), (-1, 1)), # type: ignore
|
57
|
-
tf.reshape(tf.norm(y, ord=2, axis=axis), (-1, 1)), # type: ignore
|
58
|
-
],
|
59
|
-
axis=1,
|
60
|
-
)
|
61
|
-
dist = tf.norm(tf.math.subtract(x, y), ord=2, axis=axis) / (tf.reduce_min(denom, axis=axis) + eps) # type: ignore
|
62
|
-
return dist
|
63
|
-
|
64
|
-
|
65
|
-
def eucl_cosim_features(x: tf.Tensor, y: tf.Tensor, max_eucl: float = 1e2) -> tf.Tensor:
|
66
|
-
"""
|
67
|
-
Compute features extracted from the reconstructed instance using the
|
68
|
-
relative Euclidean distance and cosine similarity between 2 tensors.
|
69
|
-
|
70
|
-
Parameters
|
71
|
-
----------
|
72
|
-
x : tf.Tensor
|
73
|
-
Tensor used in feature computation.
|
74
|
-
y : tf.Tensor
|
75
|
-
Tensor used in feature computation.
|
76
|
-
max_eucl : float, default 1e2
|
77
|
-
Maximum value to clip relative Euclidean distance by.
|
78
|
-
|
79
|
-
Returns
|
80
|
-
-------
|
81
|
-
tf.Tensor
|
82
|
-
Tensor concatenating the relative Euclidean distance and cosine similarity features.
|
83
|
-
"""
|
84
|
-
if len(x.shape) > 2 or len(y.shape) > 2:
|
85
|
-
x = cast(tf.Tensor, keras.layers.Flatten()(x))
|
86
|
-
y = cast(tf.Tensor, keras.layers.Flatten()(y))
|
87
|
-
rec_cos = tf.reshape(keras.losses.cosine_similarity(y, x, -1), (-1, 1))
|
88
|
-
rec_euc = tf.reshape(relative_euclidean_distance(y, x, -1), (-1, 1))
|
89
|
-
# rec_euc could become very large so should be clipped
|
90
|
-
rec_euc = tf.clip_by_value(rec_euc, 0, max_eucl)
|
91
|
-
return cast(tf.Tensor, tf.concat([rec_cos, rec_euc], -1))
|
92
|
-
|
93
|
-
|
94
|
-
class Sampling(keras.layers.Layer):
|
95
|
-
"""Reparametrization trick - Uses (z_mean, z_log_var) to sample the latent vector z."""
|
96
|
-
|
97
|
-
def call(self, inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
|
98
|
-
"""
|
99
|
-
Sample z.
|
100
|
-
|
101
|
-
Parameters
|
102
|
-
----------
|
103
|
-
inputs
|
104
|
-
Tuple with mean and log :term:`variance<Variance>`.
|
105
|
-
|
106
|
-
Returns
|
107
|
-
-------
|
108
|
-
Sampled vector z.
|
109
|
-
"""
|
110
|
-
z_mean, z_log_var = inputs
|
111
|
-
batch, dim = tuple(tf.shape(z_mean).numpy().ravel()[:2]) # type: ignore
|
112
|
-
epsilon = cast(tf.Tensor, keras.backend.random_normal(shape=(batch, dim)))
|
113
|
-
return z_mean + tf.exp(tf.math.multiply(0.5, z_log_var)) * epsilon
|
114
|
-
|
115
|
-
|
116
|
-
class EncoderAE(keras.layers.Layer):
|
117
|
-
def __init__(self, encoder_net: keras.Sequential) -> None:
|
118
|
-
"""
|
119
|
-
Encoder of AE.
|
120
|
-
|
121
|
-
Parameters
|
122
|
-
----------
|
123
|
-
encoder_net
|
124
|
-
Layers for the encoder wrapped in a keras.keras.Sequential class.
|
125
|
-
name
|
126
|
-
Name of encoder.
|
127
|
-
"""
|
128
|
-
super().__init__(name="encoder_ae")
|
129
|
-
self.encoder_net: keras.Sequential = encoder_net
|
130
|
-
|
131
|
-
def call(self, x: tf.Tensor) -> tf.Tensor:
|
132
|
-
return cast(tf.Tensor, self.encoder_net(x))
|
133
|
-
|
134
|
-
|
135
|
-
class EncoderVAE(keras.layers.Layer):
|
136
|
-
def __init__(self, encoder_net: keras.Sequential, latent_dim: int) -> None:
|
137
|
-
"""
|
138
|
-
Encoder of VAE.
|
139
|
-
|
140
|
-
Parameters
|
141
|
-
----------
|
142
|
-
encoder_net
|
143
|
-
Layers for the encoder wrapped in a keras.keras.Sequential class.
|
144
|
-
latent_dim
|
145
|
-
Dimensionality of the :term:`latent space<Latent Space>`.
|
146
|
-
name
|
147
|
-
Name of encoder.
|
148
|
-
"""
|
149
|
-
super().__init__(name="encoder_vae")
|
150
|
-
self.encoder_net: keras.Sequential = encoder_net
|
151
|
-
self._fc_mean = keras.layers.Dense(latent_dim, activation=None)
|
152
|
-
self._fc_log_var = keras.layers.Dense(latent_dim, activation=None)
|
153
|
-
self._sampling = Sampling()
|
154
|
-
|
155
|
-
def call(self, x: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
156
|
-
x = cast(tf.Tensor, self.encoder_net(x))
|
157
|
-
if len(x.shape) > 2:
|
158
|
-
x = cast(tf.Tensor, keras.layers.Flatten()(x))
|
159
|
-
z_mean = cast(tf.Tensor, self._fc_mean(x))
|
160
|
-
z_log_var = cast(tf.Tensor, self._fc_log_var(x))
|
161
|
-
z = cast(tf.Tensor, self._sampling((z_mean, z_log_var)))
|
162
|
-
return z_mean, z_log_var, z
|
163
|
-
|
164
|
-
|
165
|
-
class Decoder(keras.layers.Layer):
|
166
|
-
def __init__(self, decoder_net: keras.Sequential) -> None:
|
167
|
-
"""
|
168
|
-
Decoder of AE and VAE.
|
169
|
-
|
170
|
-
Parameters
|
171
|
-
----------
|
172
|
-
decoder_net
|
173
|
-
Layers for the decoder wrapped in a keras.keras.Sequential class.
|
174
|
-
name
|
175
|
-
Name of decoder.
|
176
|
-
"""
|
177
|
-
super().__init__(name="decoder")
|
178
|
-
self.decoder_net: keras.Sequential = decoder_net
|
179
|
-
|
180
|
-
def call(self, inputs: tf.Tensor) -> tf.Tensor:
|
181
|
-
return cast(tf.Tensor, self.decoder_net(inputs))
|
182
|
-
|
183
|
-
|
184
|
-
class AE(keras.Model):
|
185
|
-
"""
|
186
|
-
Combine encoder and decoder in AE.
|
187
|
-
|
188
|
-
Parameters
|
189
|
-
----------
|
190
|
-
encoder_net : keras.Sequential
|
191
|
-
Layers for the encoder wrapped in a keras.keras.Sequential class.
|
192
|
-
decoder_net : keras.Sequential
|
193
|
-
Layers for the decoder wrapped in a keras.keras.Sequential class.
|
194
|
-
"""
|
195
|
-
|
196
|
-
def __init__(self, encoder_net: keras.Sequential, decoder_net: keras.Sequential) -> None:
|
197
|
-
super().__init__(name="ae")
|
198
|
-
self.encoder: keras.layers.Layer = EncoderAE(encoder_net)
|
199
|
-
self.decoder: keras.layers.Layer = Decoder(decoder_net)
|
200
|
-
|
201
|
-
def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
|
202
|
-
z = cast(tf.Tensor, self.encoder(inputs))
|
203
|
-
x_recon = cast(tf.Tensor, self.decoder(z))
|
204
|
-
return x_recon
|
205
|
-
|
206
|
-
|
207
|
-
class VAE(keras.Model):
|
208
|
-
"""
|
209
|
-
Combine encoder and decoder in VAE.
|
210
|
-
|
211
|
-
Parameters
|
212
|
-
----------
|
213
|
-
encoder_net : keras.Sequential
|
214
|
-
Layers for the encoder wrapped in a keras.keras.Sequential class.
|
215
|
-
decoder_net : keras.Sequential
|
216
|
-
Layers for the decoder wrapped in a keras.keras.Sequential class.
|
217
|
-
latent_dim : int
|
218
|
-
Dimensionality of the :term:`latent space<Latent Space>`.
|
219
|
-
beta : float, default 1.0
|
220
|
-
Beta parameter for KL-divergence loss term.
|
221
|
-
"""
|
222
|
-
|
223
|
-
def __init__(
|
224
|
-
self, encoder_net: keras.Sequential, decoder_net: keras.Sequential, latent_dim: int, beta: float = 1.0
|
225
|
-
) -> None:
|
226
|
-
super().__init__(name="vae_model")
|
227
|
-
self.encoder: keras.layers.Layer = EncoderVAE(encoder_net, latent_dim)
|
228
|
-
self.decoder: keras.layers.Layer = Decoder(decoder_net)
|
229
|
-
self.beta: float = beta
|
230
|
-
self.latent_dim: int = latent_dim
|
231
|
-
|
232
|
-
def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
|
233
|
-
z_mean, z_log_var, z = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(inputs))
|
234
|
-
x_recon = self.decoder(z)
|
235
|
-
# add KL divergence loss term
|
236
|
-
kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
|
237
|
-
self.add_loss(self.beta * kl_loss)
|
238
|
-
return cast(tf.Tensor, x_recon)
|
239
|
-
|
240
|
-
|
241
|
-
class AEGMM(keras.Model):
|
242
|
-
"""
|
243
|
-
Deep Autoencoding Gaussian Mixture Model.
|
244
|
-
|
245
|
-
Parameters
|
246
|
-
----------
|
247
|
-
encoder_net : keras.Sequential
|
248
|
-
Layers for the encoder wrapped in a keras.keras.Sequential class.
|
249
|
-
decoder_net : keras.Sequential
|
250
|
-
Layers for the decoder wrapped in a keras.keras.Sequential class.
|
251
|
-
gmm_density_net : keras.Sequential
|
252
|
-
Layers for the GMM network wrapped in a keras.keras.Sequential class.
|
253
|
-
n_gmm : int
|
254
|
-
Number of components in GMM.
|
255
|
-
"""
|
256
|
-
|
257
|
-
def __init__(
|
258
|
-
self,
|
259
|
-
encoder_net: keras.Sequential,
|
260
|
-
decoder_net: keras.Sequential,
|
261
|
-
gmm_density_net: keras.Sequential,
|
262
|
-
n_gmm: int,
|
263
|
-
) -> None:
|
264
|
-
super().__init__("aegmm")
|
265
|
-
self.encoder = encoder_net
|
266
|
-
self.decoder = decoder_net
|
267
|
-
self.gmm_density = gmm_density_net
|
268
|
-
self.n_gmm = n_gmm
|
269
|
-
|
270
|
-
def call(
|
271
|
-
self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None
|
272
|
-
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
273
|
-
enc = self.encoder(inputs)
|
274
|
-
x_recon = cast(tf.Tensor, self.decoder(enc))
|
275
|
-
recon_features = eucl_cosim_features(inputs, x_recon)
|
276
|
-
z = cast(tf.Tensor, tf.concat([enc, recon_features], -1))
|
277
|
-
gamma = cast(tf.Tensor, self.gmm_density(z))
|
278
|
-
return x_recon, z, gamma
|
279
|
-
|
280
|
-
|
281
|
-
class VAEGMM(keras.Model):
|
282
|
-
"""
|
283
|
-
Variational Autoencoding Gaussian Mixture Model.
|
284
|
-
|
285
|
-
Parameters
|
286
|
-
----------
|
287
|
-
encoder_net : keras.Sequential
|
288
|
-
Layers for the encoder wrapped in a keras.keras.Sequential class.
|
289
|
-
decoder_net : keras.Sequential
|
290
|
-
Layers for the decoder wrapped in a keras.keras.Sequential class.
|
291
|
-
gmm_density_net : keras.Sequential
|
292
|
-
Layers for the GMM network wrapped in a keras.keras.Sequential class.
|
293
|
-
n_gmm : int
|
294
|
-
Number of components in GMM.
|
295
|
-
latent_dim : int
|
296
|
-
Dimensionality of the :term:`latent space<Latent Space>`.
|
297
|
-
beta : float, default 1.0
|
298
|
-
Beta parameter for KL-divergence loss term.
|
299
|
-
"""
|
300
|
-
|
301
|
-
def __init__(
|
302
|
-
self,
|
303
|
-
encoder_net: keras.Sequential,
|
304
|
-
decoder_net: keras.Sequential,
|
305
|
-
gmm_density_net: keras.Sequential,
|
306
|
-
n_gmm: int,
|
307
|
-
latent_dim: int,
|
308
|
-
beta: float = 1.0,
|
309
|
-
) -> None:
|
310
|
-
super().__init__(name="vaegmm")
|
311
|
-
self.encoder: keras.Sequential = EncoderVAE(encoder_net, latent_dim)
|
312
|
-
self.decoder: keras.Sequential = decoder_net
|
313
|
-
self.gmm_density: keras.Sequential = gmm_density_net
|
314
|
-
self.n_gmm: int = n_gmm
|
315
|
-
self.latent_dim: int = latent_dim
|
316
|
-
self.beta = beta
|
317
|
-
|
318
|
-
def call(
|
319
|
-
self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None
|
320
|
-
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
|
321
|
-
enc_mean, enc_log_var, enc = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(inputs))
|
322
|
-
x_recon = cast(tf.Tensor, self.decoder(enc))
|
323
|
-
recon_features = eucl_cosim_features(inputs, x_recon)
|
324
|
-
z = cast(tf.Tensor, tf.concat([enc, recon_features], -1))
|
325
|
-
gamma = cast(tf.Tensor, self.gmm_density(z))
|
326
|
-
# add KL divergence loss term
|
327
|
-
kl_loss = -0.5 * tf.reduce_mean(enc_log_var - tf.square(enc_mean) - tf.exp(enc_log_var) + 1)
|
328
|
-
self.add_loss(self.beta * kl_loss)
|
329
|
-
return x_recon, z, gamma
|
330
|
-
|
331
|
-
|
332
|
-
class WeightNorm(keras.layers.Wrapper):
|
333
|
-
def __init__(self, layer, data_init: bool = True, **kwargs) -> None:
|
334
|
-
"""Layer wrapper to decouple magnitude and direction of the layer's weights.
|
335
|
-
|
336
|
-
This wrapper reparameterizes a layer by decoupling the weight's
|
337
|
-
magnitude and direction. This speeds up convergence by improving the
|
338
|
-
conditioning of the optimization problem. It has an optional data-dependent
|
339
|
-
initialization scheme, in which initial values of weights are set as functions
|
340
|
-
of the first minibatch of data. Both the weight normalization and data-
|
341
|
-
dependent initialization are described in [Salimans and Kingma (2016)][1].
|
342
|
-
|
343
|
-
Parameters
|
344
|
-
----------
|
345
|
-
layer
|
346
|
-
A `keras.layers.Layer` instance. Supported layer types are
|
347
|
-
`Dense`, `Conv2D`, and `Conv2DTranspose`. Layers with multiple inputs
|
348
|
-
are not supported.
|
349
|
-
data_init
|
350
|
-
If `True` use data dependent variable initialization.
|
351
|
-
**kwargs
|
352
|
-
Additional keyword args passed to `keras.layers.Wrapper`.
|
353
|
-
|
354
|
-
Raises
|
355
|
-
------
|
356
|
-
ValueError
|
357
|
-
If `layer` is not a `keras.layers.Layer` instance.
|
358
|
-
"""
|
359
|
-
if not isinstance(layer, keras.layers.Layer):
|
360
|
-
raise ValueError(
|
361
|
-
"Please initialize `WeightNorm` layer with a `keras.layers.Layer` " f"instance. You passed: {layer}"
|
362
|
-
)
|
363
|
-
|
364
|
-
layer_type = type(layer).__name__
|
365
|
-
if layer_type not in ["Dense", "Conv2D", "Conv2DTranspose"]:
|
366
|
-
warnings.warn(
|
367
|
-
"`WeightNorm` is tested only for `Dense`, `Conv2D`, and "
|
368
|
-
f"`Conv2DTranspose` layers. You passed a layer of type `{layer_type}`"
|
369
|
-
)
|
370
|
-
|
371
|
-
super().__init__(layer, **kwargs)
|
372
|
-
|
373
|
-
self.data_init = data_init
|
374
|
-
self._track_trackable(layer, name="layer")
|
375
|
-
self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1
|
376
|
-
|
377
|
-
def _compute_weights(self):
|
378
|
-
"""Generate weights with normalization."""
|
379
|
-
# Determine the axis along which to expand `g` so that `g` broadcasts to
|
380
|
-
# the shape of `v`.
|
381
|
-
new_axis = -self.filter_axis - 3
|
382
|
-
|
383
|
-
self.layer.kernel = tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * tf.expand_dims(self.g, new_axis)
|
384
|
-
|
385
|
-
def _init_norm(self):
|
386
|
-
"""Set the norm of the weight vector."""
|
387
|
-
kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes))
|
388
|
-
self.g.assign(kernel_norm)
|
389
|
-
|
390
|
-
def _data_dep_init(self, inputs):
|
391
|
-
"""Data dependent initialization."""
|
392
|
-
# Normalize kernel first so that calling the layer calculates
|
393
|
-
# `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
|
394
|
-
self._compute_weights()
|
395
|
-
|
396
|
-
activation = self.layer.activation
|
397
|
-
self.layer.activation = None
|
398
|
-
|
399
|
-
use_bias = self.layer.bias is not None
|
400
|
-
if use_bias:
|
401
|
-
bias = self.layer.bias
|
402
|
-
self.layer.bias = tf.zeros_like(bias)
|
403
|
-
|
404
|
-
# Since the bias is initialized as zero, setting the activation to zero and
|
405
|
-
# calling the initialized layer (with normalized kernel) yields the correct
|
406
|
-
# computation ((5) in Salimans and Kingma (2016))
|
407
|
-
x_init = self.layer(inputs)
|
408
|
-
norm_axes_out = list(range(x_init.shape.rank - 1))
|
409
|
-
m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
|
410
|
-
scale_init = 1.0 / tf.sqrt(v_init + 1e-10)
|
411
|
-
|
412
|
-
self.g.assign(self.g * scale_init)
|
413
|
-
if use_bias:
|
414
|
-
self.layer.bias = bias
|
415
|
-
self.layer.bias.assign(-m_init * scale_init)
|
416
|
-
self.layer.activation = activation
|
417
|
-
|
418
|
-
def build(self, input_shape=None):
|
419
|
-
"""Build `Layer`.
|
420
|
-
|
421
|
-
Parameters
|
422
|
-
----------
|
423
|
-
input_shape
|
424
|
-
The shape of the input to `self.layer`.
|
425
|
-
|
426
|
-
Raises
|
427
|
-
------
|
428
|
-
ValueError
|
429
|
-
If `Layer` does not contain a `kernel` of weights.
|
430
|
-
"""
|
431
|
-
input_shape = tf.TensorShape(input_shape).as_list()
|
432
|
-
input_shape[0] = None
|
433
|
-
self.input_spec = keras.layers.InputSpec(shape=input_shape)
|
434
|
-
|
435
|
-
if not self.layer.built:
|
436
|
-
self.layer.build(input_shape)
|
437
|
-
|
438
|
-
if not hasattr(self.layer, "kernel"):
|
439
|
-
raise ValueError("`WeightNorm` must wrap a layer that contains a `kernel` for weights")
|
440
|
-
|
441
|
-
self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
|
442
|
-
self.kernel_norm_axes.pop(self.filter_axis)
|
443
|
-
|
444
|
-
self.v = self.layer.kernel
|
445
|
-
|
446
|
-
# to avoid a duplicate `kernel` variable after `build` is called
|
447
|
-
self.layer.kernel = None
|
448
|
-
self.g = self.add_weight(
|
449
|
-
name="g",
|
450
|
-
shape=(int(self.v.shape[self.filter_axis]),),
|
451
|
-
initializer="ones",
|
452
|
-
dtype=self.v.dtype,
|
453
|
-
trainable=True,
|
454
|
-
)
|
455
|
-
self.initialized = self.add_weight(name="initialized", dtype=tf.bool, trainable=False)
|
456
|
-
self.initialized.assign(False)
|
457
|
-
|
458
|
-
super().build()
|
459
|
-
|
460
|
-
@tf.function
|
461
|
-
def call(self, inputs):
|
462
|
-
"""Call `Layer`."""
|
463
|
-
if not self.initialized:
|
464
|
-
if self.data_init:
|
465
|
-
self._data_dep_init(inputs)
|
466
|
-
else: # initialize `g` as the norm of the initialized kernel
|
467
|
-
self._init_norm()
|
468
|
-
|
469
|
-
self.initialized.assign(True)
|
470
|
-
|
471
|
-
self._compute_weights()
|
472
|
-
output = self.layer(inputs)
|
473
|
-
return output
|
474
|
-
|
475
|
-
def compute_output_shape(self, input_shape):
|
476
|
-
return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())
|
477
|
-
|
478
|
-
|
479
|
-
class Shift(bijectors.Bijector):
|
480
|
-
def __init__(self, shift, validate_args=False, name="shift") -> None:
|
481
|
-
"""Instantiates the `Shift` bijector which computes `Y = g(X; shift) = X + shift`
|
482
|
-
where `shift` is a numeric `Tensor`.
|
483
|
-
|
484
|
-
Parameters
|
485
|
-
----------
|
486
|
-
shift
|
487
|
-
Floating-point `Tensor`.
|
488
|
-
validate_args
|
489
|
-
Python `bool` indicating whether arguments should be checked for correctness.
|
490
|
-
name
|
491
|
-
Python `str` name given to ops managed by this object.
|
492
|
-
"""
|
493
|
-
with tf.name_scope(name) as name:
|
494
|
-
dtype = tfp_internal.dtype_util.common_dtype([shift], dtype_hint=tf.float32)
|
495
|
-
self._shift = tfp_internal.tensor_util.convert_nonref_to_tensor(shift, dtype=dtype, name="shift")
|
496
|
-
super().__init__(
|
497
|
-
forward_min_event_ndims=0,
|
498
|
-
is_constant_jacobian=True,
|
499
|
-
dtype=dtype,
|
500
|
-
validate_args=validate_args,
|
501
|
-
name=name,
|
502
|
-
)
|
503
|
-
|
504
|
-
@property
|
505
|
-
def shift(self):
|
506
|
-
"""The `shift` `Tensor` in `Y = X + shift`."""
|
507
|
-
return self._shift
|
508
|
-
|
509
|
-
@classmethod
|
510
|
-
def _is_increasing(cls):
|
511
|
-
return True
|
512
|
-
|
513
|
-
def _forward(self, x):
|
514
|
-
return x + self.shift
|
515
|
-
|
516
|
-
def _inverse(self, y):
|
517
|
-
return y - self.shift
|
518
|
-
|
519
|
-
def _forward_log_det_jacobian(self, x):
|
520
|
-
# is_constant_jacobian = True for this bijector, hence the
|
521
|
-
# `log_det_jacobian` need only be specified for a single input, as this will
|
522
|
-
# be tiled to match `event_ndims`.
|
523
|
-
return tf.zeros([], dtype=tfp_internal.dtype_util.base_dtype(x.dtype))
|
524
|
-
|
525
|
-
|
526
|
-
class PixelCNN(distributions.distribution.Distribution):
|
527
|
-
"""
|
528
|
-
Construct Pixel CNN++ distributions.distribution.
|
529
|
-
|
530
|
-
Parameters
|
531
|
-
----------
|
532
|
-
image_shape : tuple
|
533
|
-
3D `TensorShape` or tuple for the `[height, width, channels]` dimensions of the image.
|
534
|
-
conditional_shape : tuple, optional - default None
|
535
|
-
`TensorShape` or tuple for the shape of the conditional input, or `None` if there is no conditional input.
|
536
|
-
num_resnet : int, default 5
|
537
|
-
The number of layers (shown in Figure 2 of [2]) within each highest-level block of Figure 2 of [1].
|
538
|
-
num_hierarchies : int, default 3
|
539
|
-
The number of highest-level blocks (separated by expansions/contractions of dimensions in Figure 2 of [1].)
|
540
|
-
num_filters : int, default 160
|
541
|
-
The number of convolutional filters.
|
542
|
-
num_logistic_mix : int, default 10
|
543
|
-
Number of components in the distributions.logistic mixture distributions.distribution.
|
544
|
-
receptive_field_dims tuple, default (3, 3)
|
545
|
-
Height and width in pixels of the receptive field of the convolutional layers above and to the left
|
546
|
-
of a given pixel. The width (second element of the tuple) should be odd. Figure 1 (middle) of [2]
|
547
|
-
shows a receptive field of (3, 5) (the row containing the current pixel is included in the height).
|
548
|
-
The default of (3, 3) was used to produce the results in [1].
|
549
|
-
dropout_p : float, default 0.0
|
550
|
-
The dropout probability. Should be between 0 and 1.
|
551
|
-
resnet_activation : str, default "concat_elu"
|
552
|
-
The type of activation to use in the resnet blocks. May be 'concat_elu', 'elu', or 'relu'.
|
553
|
-
l2_weight : float, default 0.0
|
554
|
-
The L2 regularization weight.
|
555
|
-
use_weight_norm : bool, default True
|
556
|
-
If `True` then use weight normalization (works only in Eager mode).
|
557
|
-
use_data_init : bool, default True
|
558
|
-
If `True` then use data-dependent initialization (has no effect if `use_weight_norm` is `False`).
|
559
|
-
high : int, default 255
|
560
|
-
The maximum value of the input data (255 for an 8-bit image).
|
561
|
-
low : int, default 0
|
562
|
-
The minimum value of the input data.
|
563
|
-
dtype : tensorflow dtype, default tf.float32
|
564
|
-
Data type of the `Distribution`.
|
565
|
-
"""
|
566
|
-
|
567
|
-
def __init__(
|
568
|
-
self,
|
569
|
-
image_shape: tuple[int, int, int],
|
570
|
-
conditional_shape: tuple[int, ...] | None = None,
|
571
|
-
num_resnet: int = 5,
|
572
|
-
num_hierarchies: int = 3,
|
573
|
-
num_filters: int = 160,
|
574
|
-
num_logistic_mix: int = 10,
|
575
|
-
receptive_field_dims: tuple[int, int] = (3, 3),
|
576
|
-
dropout_p: float = 0.5,
|
577
|
-
resnet_activation: str = "concat_elu",
|
578
|
-
l2_weight: float = 0.0,
|
579
|
-
use_weight_norm: bool = True,
|
580
|
-
use_data_init: bool = True,
|
581
|
-
high: int = 255,
|
582
|
-
low: int = 0,
|
583
|
-
dtype: tf.DType = tf.float32,
|
584
|
-
) -> None:
|
585
|
-
parameters = dict(locals())
|
586
|
-
with tf.name_scope("PixelCNN") as name:
|
587
|
-
super().__init__(
|
588
|
-
dtype=dtype,
|
589
|
-
reparameterization_type=tfp_internal.reparameterization.NOT_REPARAMETERIZED,
|
590
|
-
validate_args=False,
|
591
|
-
allow_nan_stats=True,
|
592
|
-
parameters=parameters,
|
593
|
-
name=name,
|
594
|
-
)
|
595
|
-
|
596
|
-
if not tfp_internal.tensorshape_util.is_fully_defined(image_shape):
|
597
|
-
raise ValueError("`image_shape` must be fully defined.")
|
598
|
-
|
599
|
-
if conditional_shape is not None and not tfp_internal.tensorshape_util.is_fully_defined(conditional_shape):
|
600
|
-
raise ValueError("`conditional_shape` must be fully defined.")
|
601
|
-
|
602
|
-
if tfp_internal.tensorshape_util.rank(image_shape) != 3:
|
603
|
-
raise ValueError("`image_shape` must have length 3, representing [height, width, channels] dimensions.")
|
604
|
-
|
605
|
-
self._high = tf.cast(high, self.dtype)
|
606
|
-
self._low = tf.cast(low, self.dtype)
|
607
|
-
self._num_logistic_mix = num_logistic_mix
|
608
|
-
self._network = PixelCNNNetwork(
|
609
|
-
dropout_p=dropout_p,
|
610
|
-
num_resnet=num_resnet,
|
611
|
-
num_hierarchies=num_hierarchies,
|
612
|
-
num_filters=num_filters,
|
613
|
-
num_logistic_mix=num_logistic_mix,
|
614
|
-
receptive_field_dims=receptive_field_dims,
|
615
|
-
resnet_activation=resnet_activation,
|
616
|
-
l2_weight=l2_weight,
|
617
|
-
use_weight_norm=use_weight_norm,
|
618
|
-
use_data_init=use_data_init,
|
619
|
-
dtype=dtype,
|
620
|
-
)
|
621
|
-
|
622
|
-
image_input_shape = tfp_internal.tensorshape_util.concatenate([None], image_shape)
|
623
|
-
if conditional_shape is None:
|
624
|
-
input_shape = image_input_shape
|
625
|
-
else:
|
626
|
-
conditional_input_shape = tfp_internal.tensorshape_util.concatenate([None], conditional_shape)
|
627
|
-
input_shape = [image_input_shape, conditional_input_shape]
|
628
|
-
|
629
|
-
self.image_shape = image_shape
|
630
|
-
self.conditional_shape = conditional_shape
|
631
|
-
self._network.build(input_shape)
|
632
|
-
|
633
|
-
def _make_mixture_dist(self, component_logits, locs, scales, return_per_feature: bool = False):
|
634
|
-
"""Builds a mixture of quantized distributions.logistic distributions.
|
635
|
-
|
636
|
-
Parameters
|
637
|
-
----------
|
638
|
-
component_logits
|
639
|
-
4D `Tensor` of logits for the Categorical distributions.distribution
|
640
|
-
over Quantized Logistic mixture components. Dimensions are `[batch_size,
|
641
|
-
height, width, num_logistic_mix]`.
|
642
|
-
locs
|
643
|
-
4D `Tensor` of location parameters for the Quantized Logistic
|
644
|
-
mixture components. Dimensions are `[batch_size, height, width,
|
645
|
-
num_logistic_mix, num_channels]`.
|
646
|
-
scales
|
647
|
-
4D `Tensor` of location parameters for the Quantized Logistic
|
648
|
-
mixture components. Dimensions are `[batch_size, height, width,
|
649
|
-
num_logistic_mix, num_channels]`.
|
650
|
-
return_per_feature
|
651
|
-
If True, return per pixel level log prob.
|
652
|
-
|
653
|
-
Returns
|
654
|
-
-------
|
655
|
-
dist
|
656
|
-
A quantized distributions.logistic mixture `tfp.distributions.distribution` over the input data.
|
657
|
-
"""
|
658
|
-
mixture_distribution = distributions.categorical.Categorical(logits=component_logits)
|
659
|
-
|
660
|
-
# Convert distributions.distribution parameters for pixel values in
|
661
|
-
# `[self._low, self._high]` for use with `QuantizedDistribution`
|
662
|
-
locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.0)
|
663
|
-
scales *= 0.5 * (self._high - self._low)
|
664
|
-
logistic_dist = distributions.quantized_distribution.QuantizedDistribution(
|
665
|
-
distribution=distributions.transformed_distribution.TransformedDistribution(
|
666
|
-
distribution=distributions.logistic.Logistic(loc=locs, scale=scales),
|
667
|
-
bijector=Shift(shift=tf.cast(-0.5, self.dtype)),
|
668
|
-
),
|
669
|
-
low=self._low,
|
670
|
-
high=self._high,
|
671
|
-
)
|
672
|
-
|
673
|
-
# mixture with logistics for the loc and scale on each pixel for each component
|
674
|
-
dist = distributions.mixture_same_family.MixtureSameFamily(
|
675
|
-
mixture_distribution=mixture_distribution,
|
676
|
-
components_distribution=distributions.independent.Independent(logistic_dist, reinterpreted_batch_ndims=1),
|
677
|
-
)
|
678
|
-
if return_per_feature:
|
679
|
-
return dist
|
680
|
-
else:
|
681
|
-
return distributions.independent.Independent(dist, reinterpreted_batch_ndims=2)
|
682
|
-
|
683
|
-
def _log_prob(self, value, conditional_input=None, training=None, return_per_feature=False):
|
684
|
-
"""Log probability function with optional conditional input.
|
685
|
-
|
686
|
-
Calculates the log probability of a batch of data under the modeled
|
687
|
-
distributions.distribution (or conditional distributions.distribution, if conditional input is
|
688
|
-
provided).
|
689
|
-
|
690
|
-
Parameters
|
691
|
-
----------
|
692
|
-
value
|
693
|
-
`Tensor` or :term:`NumPy` array of image data. May have leading batch
|
694
|
-
dimension(s), which must broadcast to the leading batch dimensions of
|
695
|
-
`conditional_input`.
|
696
|
-
conditional_input
|
697
|
-
`Tensor` on which to condition the distributions.distribution (e.g.
|
698
|
-
class labels), or `None`. May have leading batch dimension(s), which
|
699
|
-
must broadcast to the leading batch dimensions of `value`.
|
700
|
-
training
|
701
|
-
`bool` or `None`. If `bool`, it controls the dropout layer,
|
702
|
-
where `True` implies dropout is active. If `None`, it defaults to
|
703
|
-
`keras.backend.learning_phase()`.
|
704
|
-
return_per_feature
|
705
|
-
`bool`. If True, return per pixel level log prob.
|
706
|
-
|
707
|
-
Returns
|
708
|
-
-------
|
709
|
-
log_prob_values: `Tensor`.
|
710
|
-
"""
|
711
|
-
# Determine the batch shape of the input images
|
712
|
-
image_batch_shape = tfp_internal.prefer_static.shape(value)[:-3]
|
713
|
-
|
714
|
-
# Broadcast `value` and `conditional_input` to the same batch_shape
|
715
|
-
if conditional_input is None:
|
716
|
-
image_batch_and_conditional_shape = image_batch_shape
|
717
|
-
else:
|
718
|
-
conditional_input = tf.convert_to_tensor(conditional_input)
|
719
|
-
conditional_input_shape = tfp_internal.prefer_static.shape(conditional_input)
|
720
|
-
conditional_batch_rank = tfp_internal.prefer_static.rank(
|
721
|
-
conditional_input
|
722
|
-
) - tfp_internal.tensorshape_util.rank(self.conditional_shape)
|
723
|
-
conditional_batch_shape = conditional_input_shape[:conditional_batch_rank]
|
724
|
-
|
725
|
-
image_batch_and_conditional_shape = tfp_internal.prefer_static.broadcast_shape(
|
726
|
-
image_batch_shape, conditional_batch_shape
|
727
|
-
)
|
728
|
-
conditional_input = tf.broadcast_to(
|
729
|
-
conditional_input,
|
730
|
-
tfp_internal.prefer_static.concat([image_batch_and_conditional_shape, self.conditional_shape], axis=0),
|
731
|
-
)
|
732
|
-
value = tf.broadcast_to(
|
733
|
-
value,
|
734
|
-
tfp_internal.prefer_static.concat([image_batch_and_conditional_shape, self.event_shape], axis=0),
|
735
|
-
)
|
736
|
-
|
737
|
-
# Flatten batch dimension for input to Keras model
|
738
|
-
conditional_input = tf.reshape(
|
739
|
-
conditional_input,
|
740
|
-
tfp_internal.prefer_static.concat([(-1,), self.conditional_shape], axis=0),
|
741
|
-
)
|
742
|
-
|
743
|
-
value = tf.reshape(value, tfp_internal.prefer_static.concat([(-1,), self.event_shape], axis=0))
|
744
|
-
|
745
|
-
transformed_value = (2.0 * (value - self._low) / (self._high - self._low)) - 1.0
|
746
|
-
inputs = transformed_value if conditional_input is None else [transformed_value, conditional_input]
|
747
|
-
|
748
|
-
params = self._network(inputs, training=training)
|
749
|
-
|
750
|
-
num_channels = self.event_shape[-1]
|
751
|
-
if num_channels == 1:
|
752
|
-
component_logits, locs, scales = params
|
753
|
-
else:
|
754
|
-
# If there is more than one channel, we create a linear autoregressive
|
755
|
-
# dependency among the location parameters of the channels of a single
|
756
|
-
# pixel (the scale parameters within a pixel are distributions.independent). For a pixel
|
757
|
-
# with R/G/B channels, the `r`, `g`, and `b` saturation values are
|
758
|
-
# distributed as:
|
759
|
-
#
|
760
|
-
# r ~ Logistic(loc_r, scale_r)
|
761
|
-
# g ~ Logistic(coef_rg * r + loc_g, scale_g)
|
762
|
-
# b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
|
763
|
-
# on the coefficients instead of split/multiply/concat
|
764
|
-
component_logits, locs, scales, coeffs = params
|
765
|
-
num_coeffs = num_channels * (num_channels - 1) // 2
|
766
|
-
loc_tensors = tf.split(locs, num_channels, axis=-1)
|
767
|
-
coef_tensors = tf.split(coeffs, num_coeffs, axis=-1)
|
768
|
-
channel_tensors = tf.split(value, num_channels, axis=-1)
|
769
|
-
|
770
|
-
coef_count = 0
|
771
|
-
for i in range(num_channels):
|
772
|
-
channel_tensors[i] = channel_tensors[i][..., tf.newaxis, :]
|
773
|
-
for j in range(i):
|
774
|
-
loc_tensors[i] += channel_tensors[j] * coef_tensors[coef_count]
|
775
|
-
coef_count += 1
|
776
|
-
locs = tf.concat(loc_tensors, axis=-1)
|
777
|
-
|
778
|
-
dist = self._make_mixture_dist(component_logits, locs, scales, return_per_feature=return_per_feature)
|
779
|
-
log_px = dist.log_prob(value)
|
780
|
-
if return_per_feature:
|
781
|
-
return log_px
|
782
|
-
else:
|
783
|
-
return tf.reshape(log_px, image_batch_and_conditional_shape)
|
784
|
-
|
785
|
-
def _sample_n(self, n, seed=None, conditional_input=None, training=False):
|
786
|
-
"""Samples from the distributions.distribution, with optional conditional input.
|
787
|
-
|
788
|
-
Parameters
|
789
|
-
----------
|
790
|
-
n
|
791
|
-
`int`, number of samples desired.
|
792
|
-
seed
|
793
|
-
`int`, seed for RNG. Setting a random seed enforces reproducibility
|
794
|
-
of the samples between sessions (not within a single session).
|
795
|
-
conditional_input
|
796
|
-
`Tensor` on which to condition the distributions.distribution (e.g.
|
797
|
-
class labels), or `None`.
|
798
|
-
training
|
799
|
-
`bool` or `None`. If `bool`, it controls the dropout layer,
|
800
|
-
where `True` implies dropout is active. If `None`, it defers to Keras'
|
801
|
-
handling of train/eval status.
|
802
|
-
|
803
|
-
Returns
|
804
|
-
-------
|
805
|
-
samples
|
806
|
-
a `Tensor` of shape `[n, height, width, num_channels]`.
|
807
|
-
"""
|
808
|
-
if conditional_input is not None:
|
809
|
-
conditional_input = tf.convert_to_tensor(conditional_input, dtype=self.dtype)
|
810
|
-
conditional_event_rank = tfp_internal.tensorshape_util.rank(self.conditional_shape)
|
811
|
-
conditional_input_shape = tfp_internal.prefer_static.shape(conditional_input)
|
812
|
-
conditional_sample_rank = tfp_internal.prefer_static.rank(conditional_input) - conditional_event_rank
|
813
|
-
|
814
|
-
# If `conditional_input` has no sample dimensions, prepend a sample
|
815
|
-
# dimension
|
816
|
-
if conditional_sample_rank == 0:
|
817
|
-
conditional_input = conditional_input[tf.newaxis, ...]
|
818
|
-
conditional_sample_rank = 1
|
819
|
-
|
820
|
-
# Assert that the conditional event shape in the `PixelCnnNetwork` is the
|
821
|
-
# same as that implied by `conditional_input`.
|
822
|
-
conditional_event_shape = conditional_input_shape[conditional_sample_rank:]
|
823
|
-
with tf.control_dependencies([tf.assert_equal(self.conditional_shape, conditional_event_shape)]):
|
824
|
-
conditional_sample_shape = conditional_input_shape[:conditional_sample_rank]
|
825
|
-
repeat = n // tfp_internal.prefer_static.reduce_prod(conditional_sample_shape)
|
826
|
-
h = tf.reshape(
|
827
|
-
conditional_input,
|
828
|
-
tfp_internal.prefer_static.concat([(-1,), self.conditional_shape], axis=0),
|
829
|
-
)
|
830
|
-
h = tf.tile(
|
831
|
-
h,
|
832
|
-
tfp_internal.prefer_static.pad(
|
833
|
-
[repeat],
|
834
|
-
paddings=[[0, conditional_event_rank]],
|
835
|
-
constant_values=1,
|
836
|
-
),
|
837
|
-
)
|
838
|
-
|
839
|
-
samples_0 = tf.random.uniform(
|
840
|
-
tfp_internal.prefer_static.concat([(n,), self.event_shape], axis=0),
|
841
|
-
minval=-1.0,
|
842
|
-
maxval=1.0,
|
843
|
-
dtype=self.dtype,
|
844
|
-
seed=seed,
|
845
|
-
)
|
846
|
-
inputs = samples_0 if conditional_input is None else [samples_0, h]
|
847
|
-
params_0 = self._network(inputs, training=training)
|
848
|
-
samples_0 = self._sample_channels(*params_0, seed=seed)
|
849
|
-
|
850
|
-
image_height, image_width, _ = tfp_internal.tensorshape_util.as_list(self.event_shape)
|
851
|
-
|
852
|
-
def loop_body(index, samples):
|
853
|
-
"""Loop for iterative pixel sampling.
|
854
|
-
|
855
|
-
Parameters
|
856
|
-
----------
|
857
|
-
index
|
858
|
-
0D `Tensor` of type `int32`. Index of the current pixel.
|
859
|
-
samples
|
860
|
-
4D `Tensor`. Images with pixels sampled in raster order, up to
|
861
|
-
pixel `[index]`, with dimensions `[batch_size, height, width,
|
862
|
-
num_channels]`.
|
863
|
-
|
864
|
-
Returns
|
865
|
-
-------
|
866
|
-
samples
|
867
|
-
4D `Tensor`. Images with pixels sampled in raster order, up to \
|
868
|
-
and including pixel `[index]`, with dimensions `[batch_size, height, \
|
869
|
-
width, num_channels]`.
|
870
|
-
"""
|
871
|
-
inputs = samples if conditional_input is None else [samples, h]
|
872
|
-
params = self._network(inputs, training=training)
|
873
|
-
samples_new = self._sample_channels(*params, seed=seed)
|
874
|
-
|
875
|
-
# Update the current pixel
|
876
|
-
samples = tf.transpose(samples, [1, 2, 3, 0])
|
877
|
-
samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
|
878
|
-
row, col = index // image_width, index % image_width
|
879
|
-
updates = samples_new[row, col, ...][tf.newaxis, ...]
|
880
|
-
samples = tf.tensor_scatter_nd_update(samples, [[row, col]], updates)
|
881
|
-
samples = tf.transpose(samples, [3, 0, 1, 2])
|
882
|
-
|
883
|
-
return index + 1, samples
|
884
|
-
|
885
|
-
index0 = tf.zeros([], dtype=tf.int32)
|
886
|
-
|
887
|
-
# Construct the while loop for sampling
|
888
|
-
total_pixels = image_height * image_width
|
889
|
-
loop_cond = lambda ind, _: tf.less(ind, total_pixels) # noqa: E731
|
890
|
-
init_vars = (index0, samples_0)
|
891
|
-
_, samples = tf.while_loop(loop_cond, loop_body, init_vars, parallel_iterations=1)
|
892
|
-
|
893
|
-
transformed_samples = self._low + 0.5 * (self._high - self._low) * (samples + 1.0)
|
894
|
-
return tf.round(transformed_samples)
|
895
|
-
|
896
|
-
def _sample_channels(self, component_logits, locs, scales, coeffs=None, seed=None):
|
897
|
-
"""Sample a single pixel-iteration and apply channel conditioning.
|
898
|
-
|
899
|
-
Parameters
|
900
|
-
----------
|
901
|
-
component_logits
|
902
|
-
4D `Tensor` of logits for the Categorical distributions.distribution
|
903
|
-
over Quantized Logistic mixture components. Dimensions are `[batch_size,
|
904
|
-
height, width, num_logistic_mix]`.
|
905
|
-
locs
|
906
|
-
4D `Tensor` of location parameters for the Quantized Logistic
|
907
|
-
mixture components. Dimensions are `[batch_size, height, width,
|
908
|
-
num_logistic_mix, num_channels]`.
|
909
|
-
scales
|
910
|
-
4D `Tensor` of location parameters for the Quantized Logistic
|
911
|
-
mixture components. Dimensions are `[batch_size, height, width,
|
912
|
-
num_logistic_mix, num_channels]`.
|
913
|
-
coeffs
|
914
|
-
4D `Tensor` of coefficients for the linear dependence among color
|
915
|
-
channels, or `None` if there is only one channel. Dimensions are
|
916
|
-
`[batch_size, height, width, num_logistic_mix, num_coeffs]`, where
|
917
|
-
`num_coeffs = num_channels * (num_channels - 1) // 2`.
|
918
|
-
seed
|
919
|
-
`int`, random seed.
|
920
|
-
|
921
|
-
Returns
|
922
|
-
-------
|
923
|
-
samples
|
924
|
-
4D `Tensor` of sampled image data with autoregression among \
|
925
|
-
channels. Dimensions are `[batch_size, height, width, num_channels]`.
|
926
|
-
"""
|
927
|
-
num_channels = self.event_shape[-1]
|
928
|
-
|
929
|
-
# sample mixture components once for the entire pixel
|
930
|
-
component_dist = distributions.categorical.Categorical(logits=component_logits)
|
931
|
-
mask = tf.one_hot(indices=component_dist.sample(seed=seed), depth=self._num_logistic_mix)
|
932
|
-
mask = tf.cast(mask[..., tf.newaxis], self.dtype)
|
933
|
-
|
934
|
-
# apply mixture component mask and separate out RGB parameters
|
935
|
-
masked_locs = tf.reduce_sum(locs * mask, axis=-2)
|
936
|
-
loc_tensors = tf.split(masked_locs, num_channels, axis=-1)
|
937
|
-
masked_scales = tf.reduce_sum(scales * mask, axis=-2)
|
938
|
-
scale_tensors = tf.split(masked_scales, num_channels, axis=-1)
|
939
|
-
|
940
|
-
if coeffs is not None:
|
941
|
-
num_coeffs = num_channels * (num_channels - 1) // 2
|
942
|
-
masked_coeffs = tf.reduce_sum(coeffs * mask, axis=-2)
|
943
|
-
coef_tensors = tf.split(masked_coeffs, num_coeffs, axis=-1)
|
944
|
-
|
945
|
-
channel_samples = []
|
946
|
-
coef_count = 0
|
947
|
-
for i in range(num_channels):
|
948
|
-
loc = loc_tensors[i]
|
949
|
-
for c in channel_samples:
|
950
|
-
loc += c * coef_tensors[coef_count]
|
951
|
-
coef_count += 1
|
952
|
-
|
953
|
-
logistic_samp = distributions.logistic.Logistic(loc=loc, scale=scale_tensors[i]).sample(seed=seed)
|
954
|
-
logistic_samp = tf.clip_by_value(logistic_samp, -1.0, 1.0)
|
955
|
-
channel_samples.append(logistic_samp)
|
956
|
-
|
957
|
-
return tf.concat(channel_samples, axis=-1)
|
958
|
-
|
959
|
-
def _batch_shape(self):
|
960
|
-
return tf.TensorShape([])
|
961
|
-
|
962
|
-
def _event_shape(self):
|
963
|
-
return tf.TensorShape(self.image_shape)
|
964
|
-
|
965
|
-
|
966
|
-
class PixelCNNNetwork(keras.layers.Layer):
|
967
|
-
"""Keras `Layer` to parameterize a Pixel CNN++ distributions.distribution.
|
968
|
-
This is a Keras implementation of the Pixel CNN++ network, as described in
|
969
|
-
Salimans et al. (2017)[1] and van den Oord et al. (2016)[2].
|
970
|
-
(https://github.com/openai/pixel-cnn).
|
971
|
-
#### References
|
972
|
-
[1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
|
973
|
-
PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture
|
974
|
-
Likelihood and Other Modifications. In _International Conference on
|
975
|
-
Learning Representations_, 2017.
|
976
|
-
https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf
|
977
|
-
Additional details at https://github.com/openai/pixel-cnn
|
978
|
-
[2]: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt,
|
979
|
-
Alex Graves, and Koray Kavukcuoglu. Conditional Image Generation with
|
980
|
-
PixelCNN Decoders. In _30th Conference on Neural Information Processing
|
981
|
-
Systems_, 2016.
|
982
|
-
https://papers.nips.cc/paper/6527-conditional-image-generation-with-pixelcnn-decoders.pdf.
|
983
|
-
"""
|
984
|
-
|
985
|
-
def __init__(
|
986
|
-
self,
|
987
|
-
dropout_p: float = 0.5,
|
988
|
-
num_resnet: int = 5,
|
989
|
-
num_hierarchies: int = 3,
|
990
|
-
num_filters: int = 160,
|
991
|
-
num_logistic_mix: int = 10,
|
992
|
-
receptive_field_dims: tuple[int, int] = (3, 3),
|
993
|
-
resnet_activation: str = "concat_elu",
|
994
|
-
l2_weight: float = 0.0,
|
995
|
-
use_weight_norm: bool = True,
|
996
|
-
use_data_init: bool = True,
|
997
|
-
dtype: tf.DType = tf.float32,
|
998
|
-
) -> None:
|
999
|
-
"""Initialize the :term:`neural network<Neural Network>` for the Pixel CNN++ distributions.distribution.
|
1000
|
-
|
1001
|
-
Parameters
|
1002
|
-
----------
|
1003
|
-
dropout_p
|
1004
|
-
`float`, the dropout probability. Should be between 0 and 1.
|
1005
|
-
num_resnet
|
1006
|
-
`int`, the number of layers (shown in Figure 2 of [2]) within
|
1007
|
-
each highest-level block of Figure 2 of [1].
|
1008
|
-
num_hierarchies
|
1009
|
-
`int`, the number of hightest-level blocks (separated by
|
1010
|
-
expansions/contractions of dimensions in Figure 2 of [1].)
|
1011
|
-
num_filters
|
1012
|
-
`int`, the number of convolutional filters.
|
1013
|
-
num_logistic_mix
|
1014
|
-
`int`, number of components in the distributions.logistic mixture
|
1015
|
-
distributions.distribution.
|
1016
|
-
receptive_field_dims
|
1017
|
-
`tuple`, height and width in pixels of the receptive
|
1018
|
-
field of the convolutional layers above and to the left of a given
|
1019
|
-
pixel. The width (second element of the tuple) should be odd. Figure 1
|
1020
|
-
(middle) of [2] shows a receptive field of (3, 5) (the row containing
|
1021
|
-
the current pixel is included in the height). The default of (3, 3) was
|
1022
|
-
used to produce the results in [1].
|
1023
|
-
resnet_activation
|
1024
|
-
`string`, the type of activation to use in the resnet
|
1025
|
-
blocks. May be 'concat_elu', 'elu', or 'relu'.
|
1026
|
-
l2_weight
|
1027
|
-
`float`, the L2 regularization weight.
|
1028
|
-
use_weight_norm
|
1029
|
-
`bool`, if `True` then use weight normalization.
|
1030
|
-
use_data_init
|
1031
|
-
`bool`, if `True` then use data-dependent initialization
|
1032
|
-
(has no effect if `use_weight_norm` is `False`).
|
1033
|
-
dtype
|
1034
|
-
Data type of the layer.
|
1035
|
-
"""
|
1036
|
-
super().__init__(dtype=dtype)
|
1037
|
-
self._dropout_p = dropout_p
|
1038
|
-
self._num_resnet = num_resnet
|
1039
|
-
self._num_hierarchies = num_hierarchies
|
1040
|
-
self._num_filters = num_filters
|
1041
|
-
self._num_logistic_mix = num_logistic_mix
|
1042
|
-
self._receptive_field_dims = receptive_field_dims # first set desired receptive field, then infer kernel
|
1043
|
-
self._resnet_activation = resnet_activation
|
1044
|
-
self._l2_weight = l2_weight
|
1045
|
-
|
1046
|
-
if use_weight_norm:
|
1047
|
-
|
1048
|
-
def layer_wrapper(layer):
|
1049
|
-
def wrapped_layer(*args, **kwargs):
|
1050
|
-
return WeightNorm(layer(*args, **kwargs), data_init=use_data_init)
|
1051
|
-
|
1052
|
-
return wrapped_layer
|
1053
|
-
|
1054
|
-
self._layer_wrapper = layer_wrapper
|
1055
|
-
else:
|
1056
|
-
self._layer_wrapper = lambda layer: layer
|
1057
|
-
|
1058
|
-
def build(self, input_shape: tuple[int, ...]) -> None:
|
1059
|
-
dtype = self.dtype
|
1060
|
-
if len(input_shape) == 2:
|
1061
|
-
batch_image_shape, batch_conditional_shape = input_shape
|
1062
|
-
conditional_input = keras.layers.Input(shape=batch_conditional_shape[1:], dtype=dtype)
|
1063
|
-
else:
|
1064
|
-
batch_image_shape = input_shape
|
1065
|
-
conditional_input = None
|
1066
|
-
|
1067
|
-
image_shape = batch_image_shape[1:]
|
1068
|
-
image_input = keras.layers.Input(shape=image_shape, dtype=dtype)
|
1069
|
-
|
1070
|
-
if self._resnet_activation == "concat_elu":
|
1071
|
-
activation = keras.layers.Lambda(lambda x: tf.nn.elu(tf.concat([x, -x], axis=-1)), dtype=dtype)
|
1072
|
-
else:
|
1073
|
-
activation = keras.activations.get(self._resnet_activation)
|
1074
|
-
|
1075
|
-
# Define layers with default inputs and layer wrapper applied
|
1076
|
-
Conv2D = functools.partial( # pylint:disable=invalid-name
|
1077
|
-
self._layer_wrapper(keras.layers.Convolution2D),
|
1078
|
-
filters=self._num_filters,
|
1079
|
-
padding="same",
|
1080
|
-
kernel_regularizer=keras.regularizers.l2(self._l2_weight),
|
1081
|
-
dtype=dtype,
|
1082
|
-
)
|
1083
|
-
|
1084
|
-
Dense = functools.partial( # pylint:disable=invalid-name
|
1085
|
-
self._layer_wrapper(keras.layers.Dense),
|
1086
|
-
kernel_regularizer=keras.regularizers.l2(self._l2_weight),
|
1087
|
-
dtype=dtype,
|
1088
|
-
)
|
1089
|
-
|
1090
|
-
Conv2DTranspose = functools.partial( # pylint:disable=invalid-name
|
1091
|
-
self._layer_wrapper(keras.layers.Conv2DTranspose),
|
1092
|
-
filters=self._num_filters,
|
1093
|
-
padding="same",
|
1094
|
-
strides=(2, 2),
|
1095
|
-
kernel_regularizer=keras.regularizers.l2(self._l2_weight),
|
1096
|
-
dtype=dtype,
|
1097
|
-
)
|
1098
|
-
|
1099
|
-
rows, cols = self._receptive_field_dims
|
1100
|
-
|
1101
|
-
# Define the dimensions of the valid (unmasked) areas of the layer kernels
|
1102
|
-
# for stride 1 convolutions in the internal layers.
|
1103
|
-
kernel_valid_dims = {
|
1104
|
-
"vertical": (rows - 1, cols), # vertical stack
|
1105
|
-
"horizontal": (2, cols // 2 + 1),
|
1106
|
-
} # horizontal stack
|
1107
|
-
|
1108
|
-
# Define the size of the kernel necessary to center the current pixel
|
1109
|
-
# correctly for stride 1 convolutions in the internal layers.
|
1110
|
-
kernel_sizes = {"vertical": (2 * rows - 3, cols), "horizontal": (3, cols)}
|
1111
|
-
|
1112
|
-
# Make the kernel constraint functions for stride 1 convolutions in internal
|
1113
|
-
# layers.
|
1114
|
-
kernel_constraints = {
|
1115
|
-
k: _make_kernel_constraint(kernel_sizes[k], (0, v[0]), (0, v[1])) for k, v in kernel_valid_dims.items()
|
1116
|
-
}
|
1117
|
-
|
1118
|
-
# Build the initial vertical stack/horizontal stack convolutional layers,
|
1119
|
-
# as shown in Figure 1 of [2]. The receptive field of the initial vertical
|
1120
|
-
# stack layer is a rectangular area centered above the current pixel.
|
1121
|
-
vertical_stack_init = Conv2D(
|
1122
|
-
kernel_size=(2 * rows - 1, cols),
|
1123
|
-
kernel_constraint=_make_kernel_constraint((2 * rows - 1, cols), (0, rows - 1), (0, cols)),
|
1124
|
-
)(image_input)
|
1125
|
-
|
1126
|
-
# In Figure 1 [2], the receptive field of the horizontal stack is
|
1127
|
-
# illustrated as the pixels in the same row and to the left of the current
|
1128
|
-
# pixel. [1] increases the height of this receptive field from one pixel to
|
1129
|
-
# two (`horizontal_stack_left`) and additionally includes a subset of the
|
1130
|
-
# row of pixels centered above the current pixel (`horizontal_stack_up`).
|
1131
|
-
horizontal_stack_up = Conv2D(
|
1132
|
-
kernel_size=(3, cols),
|
1133
|
-
kernel_constraint=_make_kernel_constraint((3, cols), (0, 1), (0, cols)),
|
1134
|
-
)(image_input)
|
1135
|
-
|
1136
|
-
horizontal_stack_left = Conv2D(
|
1137
|
-
kernel_size=(3, cols),
|
1138
|
-
kernel_constraint=_make_kernel_constraint((3, cols), (0, 2), (0, cols // 2)),
|
1139
|
-
)(image_input)
|
1140
|
-
|
1141
|
-
horizontal_stack_init = keras.layers.add([horizontal_stack_up, horizontal_stack_left], dtype=dtype)
|
1142
|
-
|
1143
|
-
layer_stacks = {
|
1144
|
-
"vertical": [vertical_stack_init],
|
1145
|
-
"horizontal": [horizontal_stack_init],
|
1146
|
-
}
|
1147
|
-
|
1148
|
-
# Build the downward pass of the U-net (left-hand half of Figure 2 of [1]).
|
1149
|
-
# Each `i` iteration builds one of the highest-level blocks (identified as
|
1150
|
-
# 'Sequence of 6 layers' in the figure, consisting of `num_resnet=5` stride-
|
1151
|
-
# 1 layers, and one stride-2 layer that contracts the height/width
|
1152
|
-
# dimensions). The `_` iterations build the stride 1 layers. The layers of
|
1153
|
-
# the downward pass are stored in lists, since we'll later need them to make
|
1154
|
-
# skip-connections to layers in the upward pass of the U-net (the skip-
|
1155
|
-
# connections are represented by curved lines in Figure 2 [1]).
|
1156
|
-
for i in range(self._num_hierarchies):
|
1157
|
-
for _ in range(self._num_resnet):
|
1158
|
-
# Build a layer shown in Figure 2 of [2]. The 'vertical' iteration
|
1159
|
-
# builds the layers in the left half of the figure, and the 'horizontal'
|
1160
|
-
# iteration builds the layers in the right half.
|
1161
|
-
for stack in ["vertical", "horizontal"]:
|
1162
|
-
input_x = layer_stacks[stack][-1]
|
1163
|
-
x = activation(input_x)
|
1164
|
-
x = Conv2D(
|
1165
|
-
kernel_size=kernel_sizes[stack],
|
1166
|
-
kernel_constraint=kernel_constraints[stack],
|
1167
|
-
)(x)
|
1168
|
-
|
1169
|
-
# Add the vertical-stack layer to the horizontal-stack layer
|
1170
|
-
if stack == "horizontal":
|
1171
|
-
h = activation(layer_stacks["vertical"][-1])
|
1172
|
-
h = Dense(self._num_filters)(h)
|
1173
|
-
x = keras.layers.add([h, x], dtype=dtype)
|
1174
|
-
|
1175
|
-
x = activation(x)
|
1176
|
-
x = keras.layers.Dropout(self._dropout_p, dtype=dtype)(x)
|
1177
|
-
x = Conv2D(
|
1178
|
-
filters=2 * self._num_filters,
|
1179
|
-
kernel_size=kernel_sizes[stack],
|
1180
|
-
kernel_constraint=kernel_constraints[stack],
|
1181
|
-
)(x)
|
1182
|
-
|
1183
|
-
if conditional_input is not None:
|
1184
|
-
h_projection = _build_and_apply_h_projection(conditional_input, self._num_filters, dtype=dtype)
|
1185
|
-
x = keras.layers.add([x, h_projection], dtype=dtype)
|
1186
|
-
|
1187
|
-
x = _apply_sigmoid_gating(x)
|
1188
|
-
|
1189
|
-
# Add a residual connection from the layer's input.
|
1190
|
-
out = keras.layers.add([input_x, x], dtype=dtype)
|
1191
|
-
layer_stacks[stack].append(out)
|
1192
|
-
|
1193
|
-
if i < self._num_hierarchies - 1:
|
1194
|
-
# Build convolutional layers that contract the height/width dimensions
|
1195
|
-
# on the downward pass between each set of layers (e.g. contracting from
|
1196
|
-
# 32x32 to 16x16 in Figure 2 of [1]).
|
1197
|
-
for stack in ["vertical", "horizontal"]:
|
1198
|
-
# Define kernel dimensions/masking to maintain the autoregressive property.
|
1199
|
-
x = layer_stacks[stack][-1]
|
1200
|
-
h, w = kernel_valid_dims[stack]
|
1201
|
-
kernel_height = 2 * h
|
1202
|
-
kernel_width = w + 1 if stack == "vertical" else 2 * w
|
1203
|
-
kernel_size = (kernel_height, kernel_width)
|
1204
|
-
kernel_constraint = _make_kernel_constraint(kernel_size, (0, h), (0, w))
|
1205
|
-
x = Conv2D(
|
1206
|
-
strides=(2, 2),
|
1207
|
-
kernel_size=kernel_size,
|
1208
|
-
kernel_constraint=kernel_constraint,
|
1209
|
-
)(x)
|
1210
|
-
layer_stacks[stack].append(x)
|
1211
|
-
|
1212
|
-
# Upward pass of the U-net (right-hand half of Figure 2 of [1]). We stored
|
1213
|
-
# the layers of the downward pass in a list, in order to access them to make
|
1214
|
-
# skip-connections to the upward pass. For the upward pass, we need to keep
|
1215
|
-
# track of only the current layer, so we maintain a reference to the
|
1216
|
-
# current layer of the horizontal/vertical stack in the `upward_pass` dict.
|
1217
|
-
# The upward pass begins with the last layer of the downward pass.
|
1218
|
-
upward_pass = {key: stack.pop() for key, stack in layer_stacks.items()}
|
1219
|
-
|
1220
|
-
-        # As with the downward pass, each `i` iteration builds a highest level block
-        # in Figure 2 [1], and the `_` iterations build individual layers within the
-        # block.
-        for i in range(self._num_hierarchies):
-            num_resnet = self._num_resnet if i == 0 else self._num_resnet + 1
-
-            for _ in range(num_resnet):
-                # Build a layer as shown in Figure 2 of [2], with a skip-connection
-                # from the symmetric layer in the downward pass.
-                for stack in ["vertical", "horizontal"]:
-                    input_x = upward_pass[stack]
-                    x_symmetric = layer_stacks[stack].pop()
-
-                    x = activation(input_x)
-                    x = Conv2D(
-                        kernel_size=kernel_sizes[stack],
-                        kernel_constraint=kernel_constraints[stack],
-                    )(x)
-
-                    # Include the vertical-stack layer of the upward pass in the layers
-                    # to be added to the horizontal layer.
-                    if stack == "horizontal":
-                        x_symmetric = keras.layers.Concatenate(axis=-1, dtype=dtype)(
-                            [upward_pass["vertical"], x_symmetric]
-                        )
-
-                    # Add a skip-connection from the symmetric layer in the downward
-                    # pass to the layer `x` in the upward pass.
-                    h = activation(x_symmetric)
-                    h = Dense(self._num_filters)(h)
-                    x = keras.layers.add([h, x], dtype=dtype)
-
-                    x = activation(x)
-                    x = keras.layers.Dropout(self._dropout_p, dtype=dtype)(x)
-                    x = Conv2D(
-                        filters=2 * self._num_filters,
-                        kernel_size=kernel_sizes[stack],
-                        kernel_constraint=kernel_constraints[stack],
-                    )(x)
-
-                    if conditional_input is not None:
-                        h_projection = _build_and_apply_h_projection(conditional_input, self._num_filters, dtype=dtype)
-                        x = keras.layers.add([x, h_projection], dtype=dtype)
-
-                    x = _apply_sigmoid_gating(x)
-                    upward_pass[stack] = keras.layers.add([input_x, x], dtype=dtype)
-
-            # Define deconvolutional layers that expand height/width dimensions on the
-            # upward pass (e.g. expanding from 8x8 to 16x16 in Figure 2 of [1]), with
-            # the correct kernel dimensions/masking to maintain the autoregressive
-            # property.
-            if i < self._num_hierarchies - 1:
-                for stack in ["vertical", "horizontal"]:
-                    h, w = kernel_valid_dims[stack]
-                    kernel_height = 2 * h - 2
-                    if stack == "vertical":
-                        kernel_width = w + 1
-                        kernel_constraint = _make_kernel_constraint(
-                            (kernel_height, kernel_width),
-                            (h - 2, kernel_height),
-                            (0, w),
-                        )
-                    else:
-                        kernel_width = 2 * w - 2
-                        kernel_constraint = _make_kernel_constraint(
-                            (kernel_height, kernel_width),
-                            (h - 2, kernel_height),
-                            (w - 2, kernel_width),
-                        )
-
-                    x = upward_pass[stack]
-                    x = Conv2DTranspose(
-                        kernel_size=(kernel_height, kernel_width),
-                        kernel_constraint=kernel_constraint,
-                    )(x)
-                    upward_pass[stack] = x
-
-        x_out = keras.layers.ELU(dtype=dtype)(upward_pass["horizontal"])
-
-        # Build final Dense/Reshape layers to output the correct number of
-        # parameters per pixel.
-        num_channels = tfp_internal.tensorshape_util.as_list(image_shape)[-1]
-        num_coeffs = num_channels * (num_channels - 1) // 2  # alpha, beta, gamma in eq.3 of paper
-        num_out = num_channels * 2 + num_coeffs + 1  # mu, s + alpha, beta, gamma + 1 (mixture weight)
-        num_out_total = num_out * self._num_logistic_mix
-        params = Dense(num_out_total)(x_out)
-        params = tf.reshape(
-            params,
-            tfp_internal.prefer_static.concat(  # [-1,H,W,nb mixtures, params per mixture]
-                [[-1], image_shape[:-1], [self._num_logistic_mix, num_out]], axis=0
-            ),
-        )
-
-        # If there is one color channel, split the parameters into a list of three
-        # output `Tensor`s: (1) component logits for the Quantized Logistic mixture
-        # distribution, (2) location parameters for each component, and (3) scale
-        # parameters for each component. If there is more than one color channel,
-        # return a fourth `Tensor` for the coefficients for the linear dependence
-        # among color channels (e.g. alpha, beta, gamma).
-        # [logits, mu, s, linear dependence]
-        splits = 3 if num_channels == 1 else [1, num_channels, num_channels, num_coeffs]
-        outputs = tf.split(params, splits, axis=-1)
-
-        # Squeeze singleton dimension from component logits
-        outputs[0] = tf.squeeze(outputs[0], axis=-1)
-
-        # Ensure scales are positive and do not collapse to near-zero
-        outputs[2] = tf.nn.softplus(outputs[2]) + tf.cast(tf.exp(-7.0), self.dtype)
-
-        inputs = image_input if conditional_input is None else [image_input, conditional_input]
-        self._network = keras.Model(inputs=inputs, outputs=outputs)
-        super().build(input_shape)
-
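The arithmetic of the output head removed above is easy to sanity-check by hand. A minimal standalone sketch for an RGB input, with `num_logistic_mix = 5` assumed purely for illustration (it is a constructor parameter of this network, not a fixed value):

```python
# Worked example of the per-pixel parameter count in the removed build() code.
num_channels = 3  # RGB image (assumed for illustration)
num_logistic_mix = 5  # assumed mixture size

num_coeffs = num_channels * (num_channels - 1) // 2  # 3 pairwise channel coefficients
num_out = num_channels * 2 + num_coeffs + 1  # 3 means + 3 scales + 3 coeffs + 1 mixture logit = 10
num_out_total = num_out * num_logistic_mix  # 50 values predicted per pixel

# The final tf.split then carves each mixture component's 10 values into
# [1 logit, 3 locations, 3 scales, 3 coefficients].
splits = [1, num_channels, num_channels, num_coeffs]
assert sum(splits) == num_out == 10
print(num_out_total)  # 50
```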
-    def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
-        """Call the Pixel CNN network model.
-
-        Parameters
-        ----------
-        inputs
-            4D `Tensor` of image data with dimensions [batch size, height,
-            width, channels] or a 2-element `list`. If `list`, the first element is
-            the 4D image `Tensor` and the second element is a `Tensor` with
-            conditional input data (e.g. VAE encodings or class labels) with the
-            same leading batch dimension as the image `Tensor`.
-        training
-            `bool` or `None`. If `bool`, it controls the dropout layer,
-            where `True` implies dropout is active. If `None`, it defaults to
-            `keras.backend.learning_phase()`.
-
-        Returns
-        -------
-        outputs
-            a 3- or 4-element `list` of `Tensor`s in the following order: \
-            component_logits: 4D `Tensor` of logits for the Categorical distribution \
-            over Quantized Logistic mixture components. Dimensions are \
-            `[batch_size, height, width, num_logistic_mix]`.
-        locs
-            4D `Tensor` of location parameters for the Quantized Logistic \
-            mixture components. Dimensions are `[batch_size, height, width, \
-            num_logistic_mix, num_channels]`.
-        scales
-            4D `Tensor` of scale parameters for the Quantized Logistic \
-            mixture components. Dimensions are `[batch_size, height, width, \
-            num_logistic_mix, num_channels]`.
-        coeffs
-            4D `Tensor` of coefficients for the linear dependence among \
-            color channels, included only if the image has more than one channel. \
-            Dimensions are `[batch_size, height, width, num_logistic_mix, \
-            num_coeffs]`, where `num_coeffs = num_channels * (num_channels - 1) // 2`.
-        """
-        return self._network(inputs, training=training)
-
-
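The `inputs` contract of `call` allows either a single image tensor or a two-element list with conditional data. A hypothetical driver illustrating the two calling conventions and the documented output shapes; the `network` instance, the `(32, 32, 3)` image shape, the 10-class conditional input, and `num_logistic_mix = 5` are all assumptions for illustration:

```python
import tensorflow as tf

# Hypothetical stand-in for a built instance of the network removed in this diff.
# network = ...

images = tf.zeros([8, 32, 32, 3])  # [batch, height, width, channels]
labels = tf.one_hot(tf.zeros([8], tf.int32), depth=10)  # conditional input

# Unconditional call: a single 4D image tensor.
# component_logits, locs, scales, coeffs = network(images, training=False)

# Conditional call: a 2-element list sharing the leading batch dimension.
# component_logits, locs, scales, coeffs = network([images, labels], training=False)

# Expected shapes per the docstring (with num_logistic_mix = 5 assumed):
# component_logits: [8, 32, 32, 5]
# locs, scales:     [8, 32, 32, 5, 3]
# coeffs:           [8, 32, 32, 5, 3]
```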
-def _make_kernel_constraint(kernel_size, valid_rows, valid_columns):
-    """Make the masking function for layer kernels."""
-    mask = np.zeros(kernel_size)
-    lower, upper = valid_rows
-    left, right = valid_columns
-    mask[lower:upper, left:right] = 1.0
-    mask = mask[:, :, np.newaxis, np.newaxis]
-    return lambda x: x * mask
-
-
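The constraint returned here multiplies a kernel by a fixed 0/1 mask after every optimizer step, which is how the convolutions above preserve the autoregressive property. A standalone check that reproduces the helper verbatim; the 4x2 kernel and the valid-region bounds are assumed values for illustration:

```python
import numpy as np

def _make_kernel_constraint(kernel_size, valid_rows, valid_columns):
    """Build a constraint that zeroes kernel weights outside the valid region."""
    mask = np.zeros(kernel_size)
    lower, upper = valid_rows
    left, right = valid_columns
    mask[lower:upper, left:right] = 1.0
    mask = mask[:, :, np.newaxis, np.newaxis]  # broadcast over in/out channels
    return lambda x: x * mask

# Illustrative values: a 4x2 kernel whose top-left 2x1 region stays trainable.
constraint = _make_kernel_constraint((4, 2), (0, 2), (0, 1))
kernel = np.ones((4, 2, 1, 1))  # [height, width, in_channels, out_channels]
print(constraint(kernel)[:, :, 0, 0])
# [[1. 0.]
#  [1. 0.]
#  [0. 0.]
#  [0. 0.]]
```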
-def _build_and_apply_h_projection(h, num_filters, dtype):
-    """Project the conditional input."""
-    h = keras.layers.Flatten(dtype=dtype)(h)
-    h_projection = keras.layers.Dense(2 * num_filters, kernel_initializer="random_normal", dtype=dtype)(h)
-    return h_projection[..., tf.newaxis, tf.newaxis, :]
-
-
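The two `tf.newaxis` insertions are what let a single projected conditioning vector per example broadcast across every spatial position when it is added to the feature maps inside the resnet layers. A shape walk-through using `tf.keras` directly, with all sizes assumed for illustration:

```python
import tensorflow as tf

batch, num_filters = 8, 32
h = tf.zeros([batch, 10])  # e.g. one-hot class labels (assumed conditional input)

flat = tf.keras.layers.Flatten()(h)                   # [8, 10]
proj = tf.keras.layers.Dense(2 * num_filters)(flat)   # [8, 64]
proj = proj[..., tf.newaxis, tf.newaxis, :]           # [8, 1, 1, 64]

# Broadcasts against per-pixel features when added in the gated resnet layers:
features = tf.zeros([batch, 16, 16, 2 * num_filters])  # [8, 16, 16, 64]
print((features + proj).shape)                         # (8, 16, 16, 64)
```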
-def _apply_sigmoid_gating(x):
-    """Apply the sigmoid gating in Figure 2 of [2]."""
-    activation_tensor, gate_tensor = tf.split(x, 2, axis=-1)
-    sigmoid_gate = tf.sigmoid(gate_tensor)
-    return keras.layers.multiply([sigmoid_gate, activation_tensor], dtype=x.dtype)
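This gating is the reason the resnet convolutions above emit `2 * self._num_filters` channels: one half becomes the activation, the other half the gate, and the gated output returns to `num_filters` channels. A minimal numeric sketch, with shapes assumed for illustration:

```python
import tensorflow as tf

# The gated unit halves the channel count: 2 * num_filters in, num_filters out.
x = tf.random.normal([8, 16, 16, 64])  # output of a Conv2D with filters=2 * 32

activation_tensor, gate_tensor = tf.split(x, 2, axis=-1)  # two [8, 16, 16, 32] halves
gated = tf.sigmoid(gate_tensor) * activation_tensor       # elementwise gate in (0, 1)

print(gated.shape)  # (8, 16, 16, 32)
```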