dataeval 0.74.0__py3-none-any.whl → 0.74.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. dataeval/__init__.py +3 -9
  2. dataeval/detectors/__init__.py +2 -10
  3. dataeval/detectors/drift/base.py +3 -3
  4. dataeval/detectors/drift/mmd.py +1 -1
  5. dataeval/detectors/linters/clusterer.py +3 -3
  6. dataeval/detectors/linters/duplicates.py +4 -4
  7. dataeval/detectors/linters/outliers.py +4 -4
  8. dataeval/detectors/ood/__init__.py +5 -12
  9. dataeval/detectors/ood/base.py +5 -5
  10. dataeval/detectors/ood/metadata_ks_compare.py +12 -13
  11. dataeval/interop.py +1 -1
  12. dataeval/metrics/bias/balance.py +3 -3
  13. dataeval/metrics/bias/coverage.py +3 -3
  14. dataeval/metrics/bias/diversity.py +3 -3
  15. dataeval/metrics/bias/metadata_preprocessing.py +3 -3
  16. dataeval/metrics/bias/parity.py +4 -4
  17. dataeval/metrics/estimators/ber.py +3 -3
  18. dataeval/metrics/estimators/divergence.py +3 -3
  19. dataeval/metrics/estimators/uap.py +3 -3
  20. dataeval/metrics/stats/base.py +2 -2
  21. dataeval/metrics/stats/boxratiostats.py +1 -1
  22. dataeval/metrics/stats/datasetstats.py +6 -6
  23. dataeval/metrics/stats/dimensionstats.py +1 -1
  24. dataeval/metrics/stats/hashstats.py +1 -1
  25. dataeval/metrics/stats/labelstats.py +3 -3
  26. dataeval/metrics/stats/pixelstats.py +1 -1
  27. dataeval/metrics/stats/visualstats.py +1 -1
  28. dataeval/output.py +77 -53
  29. dataeval/utils/__init__.py +1 -7
  30. dataeval/workflows/sufficiency.py +4 -4
  31. {dataeval-0.74.0.dist-info → dataeval-0.74.1.dist-info}/METADATA +3 -8
  32. dataeval-0.74.1.dist-info/RECORD +65 -0
  33. dataeval/detectors/ood/ae.py +0 -76
  34. dataeval/detectors/ood/aegmm.py +0 -67
  35. dataeval/detectors/ood/base_tf.py +0 -109
  36. dataeval/detectors/ood/llr.py +0 -302
  37. dataeval/detectors/ood/vae.py +0 -98
  38. dataeval/detectors/ood/vaegmm.py +0 -76
  39. dataeval/utils/lazy.py +0 -26
  40. dataeval/utils/tensorflow/__init__.py +0 -19
  41. dataeval/utils/tensorflow/_internal/gmm.py +0 -103
  42. dataeval/utils/tensorflow/_internal/loss.py +0 -121
  43. dataeval/utils/tensorflow/_internal/models.py +0 -1394
  44. dataeval/utils/tensorflow/_internal/trainer.py +0 -114
  45. dataeval/utils/tensorflow/_internal/utils.py +0 -256
  46. dataeval/utils/tensorflow/loss/__init__.py +0 -11
  47. dataeval-0.74.0.dist-info/RECORD +0 -79
  48. {dataeval-0.74.0.dist-info → dataeval-0.74.1.dist-info}/LICENSE.txt +0 -0
  49. {dataeval-0.74.0.dist-info → dataeval-0.74.1.dist-info}/WHEEL +0 -0
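--- a/dataeval/utils/tensorflow/_internal/models.py
+++ /dev/null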
@@ -1,1394 +0,0 @@
- # type: ignore
-
- """
- Source code derived from Alibi-Detect 0.11.4
- https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
-
- Original code Copyright (c) 2023 Seldon Technologies Ltd
- Licensed under Apache Software License (Apache 2.0)
- """
-
- from __future__ import annotations
-
- import functools
- import warnings
- from typing import TYPE_CHECKING, cast
-
- import numpy as np
-
- from dataeval.utils.lazy import lazyload
-
- if TYPE_CHECKING:
-     import tensorflow as tf
-     import tensorflow_probability.python.bijectors as bijectors
-     import tensorflow_probability.python.distributions as distributions
-     import tensorflow_probability.python.internal as tfp_internal
-     import tf_keras as keras
- else:
-     tf = lazyload("tensorflow")
-     bijectors = lazyload("tensorflow_probability.python.bijectors")
-     distributions = lazyload("tensorflow_probability.python.distributions")
-     tfp_internal = lazyload("tensorflow_probability.python.internal")
-     keras = lazyload("tf_keras")
-
-
- def relative_euclidean_distance(x: tf.Tensor, y: tf.Tensor, eps: float = 1e-12, axis: int = -1) -> tf.Tensor:
-     """
-     Relative Euclidean distance.
-
-     Parameters
-     ----------
-     x
-         Tensor used in distance computation.
-     y
-         Tensor used in distance computation.
-     eps
-         Epsilon added to denominator for numerical stability.
-     axis
-         Axis used to compute distance.
-
-     Returns
-     -------
-     Tensor with relative Euclidean distance across specified axis.
-     """
-     denom = tf.concat(
-         [
-             tf.reshape(tf.norm(x, ord=2, axis=axis), (-1, 1)),  # type: ignore
-             tf.reshape(tf.norm(y, ord=2, axis=axis), (-1, 1)),  # type: ignore
-         ],
-         axis=1,
-     )
-     dist = tf.norm(tf.math.subtract(x, y), ord=2, axis=axis) / (tf.reduce_min(denom, axis=axis) + eps)  # type: ignore
-     return dist
-
-
- def eucl_cosim_features(x: tf.Tensor, y: tf.Tensor, max_eucl: float = 1e2) -> tf.Tensor:
-     """
-     Compute features extracted from the reconstructed instance using the
-     relative Euclidean distance and cosine similarity between 2 tensors.
-
-     Parameters
-     ----------
-     x : tf.Tensor
-         Tensor used in feature computation.
-     y : tf.Tensor
-         Tensor used in feature computation.
-     max_eucl : float, default 1e2
-         Maximum value to clip relative Euclidean distance by.
-
-     Returns
-     -------
-     tf.Tensor
-         Tensor concatenating the relative Euclidean distance and cosine similarity features.
-     """
-     if len(x.shape) > 2 or len(y.shape) > 2:
-         x = cast(tf.Tensor, keras.layers.Flatten()(x))
-         y = cast(tf.Tensor, keras.layers.Flatten()(y))
-     rec_cos = tf.reshape(keras.losses.cosine_similarity(y, x, -1), (-1, 1))
-     rec_euc = tf.reshape(relative_euclidean_distance(y, x, axis=-1), (-1, 1))
-     # rec_euc could become very large so should be clipped
-     rec_euc = tf.clip_by_value(rec_euc, 0, max_eucl)
-     return cast(tf.Tensor, tf.concat([rec_cos, rec_euc], -1))
-
-
- class Sampling(keras.layers.Layer):
-     """Reparametrization trick - Uses (z_mean, z_log_var) to sample the latent vector z."""
-
-     def call(self, inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
-         """
-         Sample z.
-
-         Parameters
-         ----------
-         inputs
-             Tuple with mean and log :term:`variance<Variance>`.
-
-         Returns
-         -------
-         Sampled vector z.
-         """
-         z_mean, z_log_var = inputs
-         batch, dim = tuple(tf.shape(z_mean).numpy().ravel()[:2])  # type: ignore
-         epsilon = cast(tf.Tensor, keras.backend.random_normal(shape=(batch, dim)))
-         return z_mean + tf.exp(tf.math.multiply(0.5, z_log_var)) * epsilon
-
-
- class EncoderAE(keras.layers.Layer):
-     def __init__(self, encoder_net: keras.Sequential) -> None:
-         """
-         Encoder of AE.
-
-         Parameters
-         ----------
-         encoder_net
-             Layers for the encoder wrapped in a keras.Sequential class.
-         name
-             Name of encoder.
-         """
-         super().__init__(name="encoder_ae")
-         self.encoder_net: keras.Sequential = encoder_net
-
-     def call(self, x: tf.Tensor) -> tf.Tensor:
-         return cast(tf.Tensor, self.encoder_net(x))
-
-
- class EncoderVAE(keras.layers.Layer):
-     def __init__(self, encoder_net: keras.Sequential, latent_dim: int) -> None:
-         """
-         Encoder of VAE.
-
-         Parameters
-         ----------
-         encoder_net
-             Layers for the encoder wrapped in a keras.Sequential class.
-         latent_dim
-             Dimensionality of the :term:`latent space<Latent Space>`.
-         name
-             Name of encoder.
-         """
-         super().__init__(name="encoder_vae")
-         self.encoder_net: keras.Sequential = encoder_net
-         self._fc_mean = keras.layers.Dense(latent_dim, activation=None)
-         self._fc_log_var = keras.layers.Dense(latent_dim, activation=None)
-         self._sampling = Sampling()
-
-     def call(self, x: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-         x = cast(tf.Tensor, self.encoder_net(x))
-         if len(x.shape) > 2:
-             x = cast(tf.Tensor, keras.layers.Flatten()(x))
-         z_mean = cast(tf.Tensor, self._fc_mean(x))
-         z_log_var = cast(tf.Tensor, self._fc_log_var(x))
-         z = cast(tf.Tensor, self._sampling((z_mean, z_log_var)))
-         return z_mean, z_log_var, z
-
-
- class Decoder(keras.layers.Layer):
-     def __init__(self, decoder_net: keras.Sequential) -> None:
-         """
-         Decoder of AE and VAE.
-
-         Parameters
-         ----------
-         decoder_net
-             Layers for the decoder wrapped in a keras.Sequential class.
-         name
-             Name of decoder.
-         """
-         super().__init__(name="decoder")
-         self.decoder_net: keras.Sequential = decoder_net
-
-     def call(self, inputs: tf.Tensor) -> tf.Tensor:
-         return cast(tf.Tensor, self.decoder_net(inputs))
-
-
- class AE(keras.Model):
-     """
-     Combine encoder and decoder in AE.
-
-     Parameters
-     ----------
-     encoder_net : keras.Sequential
-         Layers for the encoder wrapped in a keras.Sequential class.
-     decoder_net : keras.Sequential
-         Layers for the decoder wrapped in a keras.Sequential class.
-     """
-
-     def __init__(self, encoder_net: keras.Sequential, decoder_net: keras.Sequential) -> None:
-         super().__init__(name="ae")
-         self.encoder: keras.layers.Layer = EncoderAE(encoder_net)
-         self.decoder: keras.layers.Layer = Decoder(decoder_net)
-
-     def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
-         z = cast(tf.Tensor, self.encoder(inputs))
-         x_recon = cast(tf.Tensor, self.decoder(z))
-         return x_recon
-
-
- class VAE(keras.Model):
-     """
-     Combine encoder and decoder in VAE.
-
-     Parameters
-     ----------
-     encoder_net : keras.Sequential
-         Layers for the encoder wrapped in a keras.Sequential class.
-     decoder_net : keras.Sequential
-         Layers for the decoder wrapped in a keras.Sequential class.
-     latent_dim : int
-         Dimensionality of the :term:`latent space<Latent Space>`.
-     beta : float, default 1.0
-         Beta parameter for KL-divergence loss term.
-     """
-
-     def __init__(
-         self, encoder_net: keras.Sequential, decoder_net: keras.Sequential, latent_dim: int, beta: float = 1.0
-     ) -> None:
-         super().__init__(name="vae_model")
-         self.encoder: keras.layers.Layer = EncoderVAE(encoder_net, latent_dim)
-         self.decoder: keras.layers.Layer = Decoder(decoder_net)
-         self.beta: float = beta
-         self.latent_dim: int = latent_dim
-
-     def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
-         z_mean, z_log_var, z = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(inputs))
-         x_recon = self.decoder(z)
-         # add KL divergence loss term
-         kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
-         self.add_loss(self.beta * kl_loss)
-         return cast(tf.Tensor, x_recon)
-
-
- class AEGMM(keras.Model):
-     """
-     Deep Autoencoding Gaussian Mixture Model.
-
-     Parameters
-     ----------
-     encoder_net : keras.Sequential
-         Layers for the encoder wrapped in a keras.Sequential class.
-     decoder_net : keras.Sequential
-         Layers for the decoder wrapped in a keras.Sequential class.
-     gmm_density_net : keras.Sequential
-         Layers for the GMM network wrapped in a keras.Sequential class.
-     n_gmm : int
-         Number of components in GMM.
-     """
-
-     def __init__(
-         self,
-         encoder_net: keras.Sequential,
-         decoder_net: keras.Sequential,
-         gmm_density_net: keras.Sequential,
-         n_gmm: int,
-     ) -> None:
-         super().__init__("aegmm")
-         self.encoder = encoder_net
-         self.decoder = decoder_net
-         self.gmm_density = gmm_density_net
-         self.n_gmm = n_gmm
-
-     def call(
-         self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None
-     ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-         enc = self.encoder(inputs)
-         x_recon = cast(tf.Tensor, self.decoder(enc))
-         recon_features = eucl_cosim_features(inputs, x_recon)
-         z = cast(tf.Tensor, tf.concat([enc, recon_features], -1))
-         gamma = cast(tf.Tensor, self.gmm_density(z))
-         return x_recon, z, gamma
-
-
- class VAEGMM(keras.Model):
-     """
-     Variational Autoencoding Gaussian Mixture Model.
-
-     Parameters
-     ----------
-     encoder_net : keras.Sequential
-         Layers for the encoder wrapped in a keras.Sequential class.
-     decoder_net : keras.Sequential
-         Layers for the decoder wrapped in a keras.Sequential class.
-     gmm_density_net : keras.Sequential
-         Layers for the GMM network wrapped in a keras.Sequential class.
-     n_gmm : int
-         Number of components in GMM.
-     latent_dim : int
-         Dimensionality of the :term:`latent space<Latent Space>`.
-     beta : float, default 1.0
-         Beta parameter for KL-divergence loss term.
-     """
-
-     def __init__(
-         self,
-         encoder_net: keras.Sequential,
-         decoder_net: keras.Sequential,
-         gmm_density_net: keras.Sequential,
-         n_gmm: int,
-         latent_dim: int,
-         beta: float = 1.0,
-     ) -> None:
-         super().__init__(name="vaegmm")
-         self.encoder: keras.Sequential = EncoderVAE(encoder_net, latent_dim)
-         self.decoder: keras.Sequential = decoder_net
-         self.gmm_density: keras.Sequential = gmm_density_net
-         self.n_gmm: int = n_gmm
-         self.latent_dim: int = latent_dim
-         self.beta = beta
-
-     def call(
-         self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None
-     ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-         enc_mean, enc_log_var, enc = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(inputs))
-         x_recon = cast(tf.Tensor, self.decoder(enc))
-         recon_features = eucl_cosim_features(inputs, x_recon)
-         z = cast(tf.Tensor, tf.concat([enc, recon_features], -1))
-         gamma = cast(tf.Tensor, self.gmm_density(z))
-         # add KL divergence loss term
-         kl_loss = -0.5 * tf.reduce_mean(enc_log_var - tf.square(enc_mean) - tf.exp(enc_log_var) + 1)
-         self.add_loss(self.beta * kl_loss)
-         return x_recon, z, gamma
-
-
- class WeightNorm(keras.layers.Wrapper):
-     def __init__(self, layer, data_init: bool = True, **kwargs) -> None:
-         """Layer wrapper to decouple magnitude and direction of the layer's weights.
-
-         This wrapper reparameterizes a layer by decoupling the weight's
-         magnitude and direction. This speeds up convergence by improving the
-         conditioning of the optimization problem. It has an optional data-dependent
-         initialization scheme, in which initial values of weights are set as functions
-         of the first minibatch of data. Both the weight normalization and data-
-         dependent initialization are described in [Salimans and Kingma (2016)][1].
-
-         Parameters
-         ----------
-         layer
-             A `keras.layers.Layer` instance. Supported layer types are
-             `Dense`, `Conv2D`, and `Conv2DTranspose`. Layers with multiple inputs
-             are not supported.
-         data_init
-             If `True` use data dependent variable initialization.
-         **kwargs
-             Additional keyword args passed to `keras.layers.Wrapper`.
-
-         Raises
-         ------
-         ValueError
-             If `layer` is not a `keras.layers.Layer` instance.
-         """
-         if not isinstance(layer, keras.layers.Layer):
-             raise ValueError(
-                 "Please initialize `WeightNorm` layer with a `keras.layers.Layer` " f"instance. You passed: {layer}"
-             )
-
-         layer_type = type(layer).__name__
-         if layer_type not in ["Dense", "Conv2D", "Conv2DTranspose"]:
-             warnings.warn(
-                 "`WeightNorm` is tested only for `Dense`, `Conv2D`, and "
-                 f"`Conv2DTranspose` layers. You passed a layer of type `{layer_type}`"
-             )
-
-         super().__init__(layer, **kwargs)
-
-         self.data_init = data_init
-         self._track_trackable(layer, name="layer")
-         self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1
-
-     def _compute_weights(self):
-         """Generate weights with normalization."""
-         # Determine the axis along which to expand `g` so that `g` broadcasts to
-         # the shape of `v`.
-         new_axis = -self.filter_axis - 3
-
-         self.layer.kernel = tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * tf.expand_dims(self.g, new_axis)
-
-     def _init_norm(self):
-         """Set the norm of the weight vector."""
-         kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes))
-         self.g.assign(kernel_norm)
-
-     def _data_dep_init(self, inputs):
-         """Data dependent initialization."""
-         # Normalize kernel first so that calling the layer calculates
-         # `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
-         self._compute_weights()
-
-         activation = self.layer.activation
-         self.layer.activation = None
-
-         use_bias = self.layer.bias is not None
-         if use_bias:
-             bias = self.layer.bias
-             self.layer.bias = tf.zeros_like(bias)
-
-         # Since the bias is initialized as zero, setting the activation to zero and
-         # calling the initialized layer (with normalized kernel) yields the correct
-         # computation ((5) in Salimans and Kingma (2016))
-         x_init = self.layer(inputs)
-         norm_axes_out = list(range(x_init.shape.rank - 1))
-         m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
-         scale_init = 1.0 / tf.sqrt(v_init + 1e-10)
-
-         self.g.assign(self.g * scale_init)
-         if use_bias:
-             self.layer.bias = bias
-             self.layer.bias.assign(-m_init * scale_init)
-         self.layer.activation = activation
-
-     def build(self, input_shape=None):
-         """Build `Layer`.
-
-         Parameters
-         ----------
-         input_shape
-             The shape of the input to `self.layer`.
-
-         Raises
-         ------
-         ValueError
-             If `Layer` does not contain a `kernel` of weights.
-         """
-         input_shape = tf.TensorShape(input_shape).as_list()
-         input_shape[0] = None
-         self.input_spec = keras.layers.InputSpec(shape=input_shape)
-
-         if not self.layer.built:
-             self.layer.build(input_shape)
-
-         if not hasattr(self.layer, "kernel"):
-             raise ValueError("`WeightNorm` must wrap a layer that contains a `kernel` for weights")
-
-         self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
-         self.kernel_norm_axes.pop(self.filter_axis)
-
-         self.v = self.layer.kernel
-
-         # to avoid a duplicate `kernel` variable after `build` is called
-         self.layer.kernel = None
-         self.g = self.add_weight(
-             name="g",
-             shape=(int(self.v.shape[self.filter_axis]),),
-             initializer="ones",
-             dtype=self.v.dtype,
-             trainable=True,
-         )
-         self.initialized = self.add_weight(name="initialized", dtype=tf.bool, trainable=False)
-         self.initialized.assign(False)
-
-         super().build()
-
-     @tf.function
-     def call(self, inputs):
-         """Call `Layer`."""
-         if not self.initialized:
-             if self.data_init:
-                 self._data_dep_init(inputs)
-             else:  # initialize `g` as the norm of the initialized kernel
-                 self._init_norm()
-
-             self.initialized.assign(True)
-
-         self._compute_weights()
-         output = self.layer(inputs)
-         return output
-
-     def compute_output_shape(self, input_shape):
-         return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())
-
-
- class Shift(bijectors.Bijector):
-     def __init__(self, shift, validate_args=False, name="shift") -> None:
-         """Instantiates the `Shift` bijector which computes `Y = g(X; shift) = X + shift`
-         where `shift` is a numeric `Tensor`.
-
-         Parameters
-         ----------
-         shift
-             Floating-point `Tensor`.
-         validate_args
-             Python `bool` indicating whether arguments should be checked for correctness.
-         name
-             Python `str` name given to ops managed by this object.
-         """
-         with tf.name_scope(name) as name:
-             dtype = tfp_internal.dtype_util.common_dtype([shift], dtype_hint=tf.float32)
-             self._shift = tfp_internal.tensor_util.convert_nonref_to_tensor(shift, dtype=dtype, name="shift")
-             super().__init__(
-                 forward_min_event_ndims=0,
-                 is_constant_jacobian=True,
-                 dtype=dtype,
-                 validate_args=validate_args,
-                 name=name,
-             )
-
-     @property
-     def shift(self):
-         """The `shift` `Tensor` in `Y = X + shift`."""
-         return self._shift
-
-     @classmethod
-     def _is_increasing(cls):
-         return True
-
-     def _forward(self, x):
-         return x + self.shift
-
-     def _inverse(self, y):
-         return y - self.shift
-
-     def _forward_log_det_jacobian(self, x):
-         # is_constant_jacobian = True for this bijector, hence the
-         # `log_det_jacobian` need only be specified for a single input, as this will
-         # be tiled to match `event_ndims`.
-         return tf.zeros([], dtype=tfp_internal.dtype_util.base_dtype(x.dtype))
-
-
- class PixelCNN(distributions.distribution.Distribution):
-     """
-     Construct Pixel CNN++ distribution.
-
-     Parameters
-     ----------
-     image_shape : tuple
-         3D `TensorShape` or tuple for the `[height, width, channels]` dimensions of the image.
-     conditional_shape : tuple, optional - default None
-         `TensorShape` or tuple for the shape of the conditional input, or `None` if there is no conditional input.
-     num_resnet : int, default 5
-         The number of layers (shown in Figure 2 of [2]) within each highest-level block of Figure 2 of [1].
-     num_hierarchies : int, default 3
-         The number of highest-level blocks (separated by expansions/contractions of dimensions in Figure 2 of [1].)
-     num_filters : int, default 160
-         The number of convolutional filters.
-     num_logistic_mix : int, default 10
-         Number of components in the logistic mixture distribution.
-     receptive_field_dims : tuple, default (3, 3)
-         Height and width in pixels of the receptive field of the convolutional layers above and to the left
-         of a given pixel. The width (second element of the tuple) should be odd. Figure 1 (middle) of [2]
-         shows a receptive field of (3, 5) (the row containing the current pixel is included in the height).
-         The default of (3, 3) was used to produce the results in [1].
-     dropout_p : float, default 0.5
-         The dropout probability. Should be between 0 and 1.
-     resnet_activation : str, default "concat_elu"
-         The type of activation to use in the resnet blocks. May be 'concat_elu', 'elu', or 'relu'.
-     l2_weight : float, default 0.0
-         The L2 regularization weight.
-     use_weight_norm : bool, default True
-         If `True` then use weight normalization (works only in Eager mode).
-     use_data_init : bool, default True
-         If `True` then use data-dependent initialization (has no effect if `use_weight_norm` is `False`).
-     high : int, default 255
-         The maximum value of the input data (255 for an 8-bit image).
-     low : int, default 0
-         The minimum value of the input data.
-     dtype : tensorflow dtype, default tf.float32
-         Data type of the `Distribution`.
-     """
-
-     def __init__(
-         self,
-         image_shape: tuple[int, int, int],
-         conditional_shape: tuple[int, ...] | None = None,
-         num_resnet: int = 5,
-         num_hierarchies: int = 3,
-         num_filters: int = 160,
-         num_logistic_mix: int = 10,
-         receptive_field_dims: tuple[int, int] = (3, 3),
-         dropout_p: float = 0.5,
-         resnet_activation: str = "concat_elu",
-         l2_weight: float = 0.0,
-         use_weight_norm: bool = True,
-         use_data_init: bool = True,
-         high: int = 255,
-         low: int = 0,
-         dtype: tf.DType = tf.float32,
-     ) -> None:
-         parameters = dict(locals())
-         with tf.name_scope("PixelCNN") as name:
-             super().__init__(
-                 dtype=dtype,
-                 reparameterization_type=tfp_internal.reparameterization.NOT_REPARAMETERIZED,
-                 validate_args=False,
-                 allow_nan_stats=True,
-                 parameters=parameters,
-                 name=name,
-             )
-
-             if not tfp_internal.tensorshape_util.is_fully_defined(image_shape):
-                 raise ValueError("`image_shape` must be fully defined.")
-
-             if conditional_shape is not None and not tfp_internal.tensorshape_util.is_fully_defined(conditional_shape):
-                 raise ValueError("`conditional_shape` must be fully defined.")
-
-             if tfp_internal.tensorshape_util.rank(image_shape) != 3:
-                 raise ValueError("`image_shape` must have length 3, representing [height, width, channels] dimensions.")
-
-             self._high = tf.cast(high, self.dtype)
-             self._low = tf.cast(low, self.dtype)
-             self._num_logistic_mix = num_logistic_mix
-             self._network = PixelCNNNetwork(
-                 dropout_p=dropout_p,
-                 num_resnet=num_resnet,
-                 num_hierarchies=num_hierarchies,
-                 num_filters=num_filters,
-                 num_logistic_mix=num_logistic_mix,
-                 receptive_field_dims=receptive_field_dims,
-                 resnet_activation=resnet_activation,
-                 l2_weight=l2_weight,
-                 use_weight_norm=use_weight_norm,
-                 use_data_init=use_data_init,
-                 dtype=dtype,
-             )
-
-             image_input_shape = tfp_internal.tensorshape_util.concatenate([None], image_shape)
-             if conditional_shape is None:
-                 input_shape = image_input_shape
-             else:
-                 conditional_input_shape = tfp_internal.tensorshape_util.concatenate([None], conditional_shape)
-                 input_shape = [image_input_shape, conditional_input_shape]
-
-             self.image_shape = image_shape
-             self.conditional_shape = conditional_shape
-             self._network.build(input_shape)
-
-     def _make_mixture_dist(self, component_logits, locs, scales, return_per_feature: bool = False):
-         """Builds a mixture of quantized logistic distributions.
-
-         Parameters
-         ----------
-         component_logits
-             4D `Tensor` of logits for the Categorical distribution
-             over Quantized Logistic mixture components. Dimensions are `[batch_size,
-             height, width, num_logistic_mix]`.
-         locs
-             4D `Tensor` of location parameters for the Quantized Logistic
-             mixture components. Dimensions are `[batch_size, height, width,
-             num_logistic_mix, num_channels]`.
-         scales
-             4D `Tensor` of scale parameters for the Quantized Logistic
-             mixture components. Dimensions are `[batch_size, height, width,
-             num_logistic_mix, num_channels]`.
-         return_per_feature
-             If True, return per pixel level log prob.
-
-         Returns
-         -------
-         dist
-             A quantized logistic mixture `tfp.distributions.Distribution` over the input data.
-         """
-         mixture_distribution = distributions.categorical.Categorical(logits=component_logits)
-
-         # Convert distribution parameters for pixel values in
-         # `[self._low, self._high]` for use with `QuantizedDistribution`
-         locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.0)
-         scales *= 0.5 * (self._high - self._low)
-         logistic_dist = distributions.quantized_distribution.QuantizedDistribution(
-             distribution=distributions.transformed_distribution.TransformedDistribution(
-                 distribution=distributions.logistic.Logistic(loc=locs, scale=scales),
-                 bijector=Shift(shift=tf.cast(-0.5, self.dtype)),
-             ),
-             low=self._low,
-             high=self._high,
-         )
-
-         # mixture with logistics for the loc and scale on each pixel for each component
-         dist = distributions.mixture_same_family.MixtureSameFamily(
-             mixture_distribution=mixture_distribution,
-             components_distribution=distributions.independent.Independent(logistic_dist, reinterpreted_batch_ndims=1),
-         )
-         if return_per_feature:
-             return dist
-         else:
-             return distributions.independent.Independent(dist, reinterpreted_batch_ndims=2)
-
-     def _log_prob(self, value, conditional_input=None, training=None, return_per_feature=False):
-         """Log probability function with optional conditional input.
-
-         Calculates the log probability of a batch of data under the modeled
-         distribution (or conditional distribution, if conditional input is
-         provided).
-
-         Parameters
-         ----------
-         value
-             `Tensor` or :term:`NumPy` array of image data. May have leading batch
-             dimension(s), which must broadcast to the leading batch dimensions of
-             `conditional_input`.
-         conditional_input
-             `Tensor` on which to condition the distribution (e.g.
-             class labels), or `None`. May have leading batch dimension(s), which
-             must broadcast to the leading batch dimensions of `value`.
-         training
-             `bool` or `None`. If `bool`, it controls the dropout layer,
-             where `True` implies dropout is active. If `None`, it defaults to
-             `keras.backend.learning_phase()`.
-         return_per_feature
-             `bool`. If True, return per pixel level log prob.
-
-         Returns
-         -------
-         log_prob_values: `Tensor`.
-         """
-         # Determine the batch shape of the input images
-         image_batch_shape = tfp_internal.prefer_static.shape(value)[:-3]
-
-         # Broadcast `value` and `conditional_input` to the same batch_shape
-         if conditional_input is None:
-             image_batch_and_conditional_shape = image_batch_shape
-         else:
-             conditional_input = tf.convert_to_tensor(conditional_input)
-             conditional_input_shape = tfp_internal.prefer_static.shape(conditional_input)
-             conditional_batch_rank = tfp_internal.prefer_static.rank(
-                 conditional_input
-             ) - tfp_internal.tensorshape_util.rank(self.conditional_shape)
-             conditional_batch_shape = conditional_input_shape[:conditional_batch_rank]
-
-             image_batch_and_conditional_shape = tfp_internal.prefer_static.broadcast_shape(
-                 image_batch_shape, conditional_batch_shape
-             )
-             conditional_input = tf.broadcast_to(
-                 conditional_input,
-                 tfp_internal.prefer_static.concat([image_batch_and_conditional_shape, self.conditional_shape], axis=0),
-             )
-             value = tf.broadcast_to(
-                 value,
-                 tfp_internal.prefer_static.concat([image_batch_and_conditional_shape, self.event_shape], axis=0),
-             )
-
-             # Flatten batch dimension for input to Keras model
-             conditional_input = tf.reshape(
-                 conditional_input,
-                 tfp_internal.prefer_static.concat([(-1,), self.conditional_shape], axis=0),
-             )
-
-         value = tf.reshape(value, tfp_internal.prefer_static.concat([(-1,), self.event_shape], axis=0))
-
-         transformed_value = (2.0 * (value - self._low) / (self._high - self._low)) - 1.0
-         inputs = transformed_value if conditional_input is None else [transformed_value, conditional_input]
-
-         params = self._network(inputs, training=training)
-
-         num_channels = self.event_shape[-1]
-         if num_channels == 1:
-             component_logits, locs, scales = params
-         else:
-             # If there is more than one channel, we create a linear autoregressive
-             # dependency among the location parameters of the channels of a single
-             # pixel (the scale parameters within a pixel are independent). For a pixel
-             # with R/G/B channels, the `r`, `g`, and `b` saturation values are
-             # distributed as:
-             #
-             # r ~ Logistic(loc_r, scale_r)
-             # g ~ Logistic(coef_rg * r + loc_g, scale_g)
-             # b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
-             # on the coefficients instead of split/multiply/concat
-             component_logits, locs, scales, coeffs = params
-             num_coeffs = num_channels * (num_channels - 1) // 2
-             loc_tensors = tf.split(locs, num_channels, axis=-1)
-             coef_tensors = tf.split(coeffs, num_coeffs, axis=-1)
-             channel_tensors = tf.split(value, num_channels, axis=-1)
-
-             coef_count = 0
-             for i in range(num_channels):
-                 channel_tensors[i] = channel_tensors[i][..., tf.newaxis, :]
-                 for j in range(i):
-                     loc_tensors[i] += channel_tensors[j] * coef_tensors[coef_count]
-                     coef_count += 1
-             locs = tf.concat(loc_tensors, axis=-1)
-
-         dist = self._make_mixture_dist(component_logits, locs, scales, return_per_feature=return_per_feature)
-         log_px = dist.log_prob(value)
-         if return_per_feature:
-             return log_px
-         else:
-             return tf.reshape(log_px, image_batch_and_conditional_shape)
-
-     def _sample_n(self, n, seed=None, conditional_input=None, training=False):
-         """Samples from the distribution, with optional conditional input.
-
-         Parameters
-         ----------
-         n
-             `int`, number of samples desired.
-         seed
-             `int`, seed for RNG. Setting a random seed enforces reproducibility
-             of the samples between sessions (not within a single session).
-         conditional_input
-             `Tensor` on which to condition the distribution (e.g.
-             class labels), or `None`.
-         training
-             `bool` or `None`. If `bool`, it controls the dropout layer,
-             where `True` implies dropout is active. If `None`, it defers to Keras'
-             handling of train/eval status.
-
-         Returns
-         -------
-         samples
-             a `Tensor` of shape `[n, height, width, num_channels]`.
-         """
-         if conditional_input is not None:
-             conditional_input = tf.convert_to_tensor(conditional_input, dtype=self.dtype)
-             conditional_event_rank = tfp_internal.tensorshape_util.rank(self.conditional_shape)
-             conditional_input_shape = tfp_internal.prefer_static.shape(conditional_input)
-             conditional_sample_rank = tfp_internal.prefer_static.rank(conditional_input) - conditional_event_rank
-
-             # If `conditional_input` has no sample dimensions, prepend a sample
-             # dimension
-             if conditional_sample_rank == 0:
-                 conditional_input = conditional_input[tf.newaxis, ...]
-                 conditional_sample_rank = 1
-
-             # Assert that the conditional event shape in the `PixelCnnNetwork` is the
-             # same as that implied by `conditional_input`.
-             conditional_event_shape = conditional_input_shape[conditional_sample_rank:]
-             with tf.control_dependencies([tf.assert_equal(self.conditional_shape, conditional_event_shape)]):
-                 conditional_sample_shape = conditional_input_shape[:conditional_sample_rank]
-                 repeat = n // tfp_internal.prefer_static.reduce_prod(conditional_sample_shape)
-                 h = tf.reshape(
-                     conditional_input,
-                     tfp_internal.prefer_static.concat([(-1,), self.conditional_shape], axis=0),
-                 )
-                 h = tf.tile(
-                     h,
-                     tfp_internal.prefer_static.pad(
-                         [repeat],
-                         paddings=[[0, conditional_event_rank]],
-                         constant_values=1,
-                     ),
-                 )
-
-         samples_0 = tf.random.uniform(
-             tfp_internal.prefer_static.concat([(n,), self.event_shape], axis=0),
-             minval=-1.0,
-             maxval=1.0,
-             dtype=self.dtype,
-             seed=seed,
-         )
-         inputs = samples_0 if conditional_input is None else [samples_0, h]
-         params_0 = self._network(inputs, training=training)
-         samples_0 = self._sample_channels(*params_0, seed=seed)
-
-         image_height, image_width, _ = tfp_internal.tensorshape_util.as_list(self.event_shape)
-
-         def loop_body(index, samples):
-             """Loop for iterative pixel sampling.
-
-             Parameters
-             ----------
-             index
-                 0D `Tensor` of type `int32`. Index of the current pixel.
-             samples
-                 4D `Tensor`. Images with pixels sampled in raster order, up to
-                 pixel `[index]`, with dimensions `[batch_size, height, width,
-                 num_channels]`.
-
-             Returns
-             -------
-             samples
-                 4D `Tensor`. Images with pixels sampled in raster order, up to
-                 and including pixel `[index]`, with dimensions `[batch_size, height,
-                 width, num_channels]`.
-             """
-             inputs = samples if conditional_input is None else [samples, h]
-             params = self._network(inputs, training=training)
-             samples_new = self._sample_channels(*params, seed=seed)
-
-             # Update the current pixel
-             samples = tf.transpose(samples, [1, 2, 3, 0])
-             samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
-             row, col = index // image_width, index % image_width
-             updates = samples_new[row, col, ...][tf.newaxis, ...]
-             samples = tf.tensor_scatter_nd_update(samples, [[row, col]], updates)
-             samples = tf.transpose(samples, [3, 0, 1, 2])
-
-             return index + 1, samples
-
-         index0 = tf.zeros([], dtype=tf.int32)
-
-         # Construct the while loop for sampling
-         total_pixels = image_height * image_width
-         loop_cond = lambda ind, _: tf.less(ind, total_pixels)  # noqa: E731
-         init_vars = (index0, samples_0)
-         _, samples = tf.while_loop(loop_cond, loop_body, init_vars, parallel_iterations=1)
-
-         transformed_samples = self._low + 0.5 * (self._high - self._low) * (samples + 1.0)
-         return tf.round(transformed_samples)
-
-     def _sample_channels(self, component_logits, locs, scales, coeffs=None, seed=None):
-         """Sample a single pixel-iteration and apply channel conditioning.
-
-         Parameters
-         ----------
-         component_logits
-             4D `Tensor` of logits for the Categorical distribution
-             over Quantized Logistic mixture components. Dimensions are `[batch_size,
-             height, width, num_logistic_mix]`.
-         locs
-             4D `Tensor` of location parameters for the Quantized Logistic
-             mixture components. Dimensions are `[batch_size, height, width,
-             num_logistic_mix, num_channels]`.
-         scales
-             4D `Tensor` of scale parameters for the Quantized Logistic
-             mixture components. Dimensions are `[batch_size, height, width,
-             num_logistic_mix, num_channels]`.
-         coeffs
-             4D `Tensor` of coefficients for the linear dependence among color
-             channels, or `None` if there is only one channel. Dimensions are
-             `[batch_size, height, width, num_logistic_mix, num_coeffs]`, where
-             `num_coeffs = num_channels * (num_channels - 1) // 2`.
-         seed
-             `int`, random seed.
-
-         Returns
-         -------
-         samples
-             4D `Tensor` of sampled image data with autoregression among
-             channels. Dimensions are `[batch_size, height, width, num_channels]`.
-         """
-         num_channels = self.event_shape[-1]
-
-         # sample mixture components once for the entire pixel
-         component_dist = distributions.categorical.Categorical(logits=component_logits)
-         mask = tf.one_hot(indices=component_dist.sample(seed=seed), depth=self._num_logistic_mix)
-         mask = tf.cast(mask[..., tf.newaxis], self.dtype)
-
-         # apply mixture component mask and separate out RGB parameters
-         masked_locs = tf.reduce_sum(locs * mask, axis=-2)
-         loc_tensors = tf.split(masked_locs, num_channels, axis=-1)
-         masked_scales = tf.reduce_sum(scales * mask, axis=-2)
-         scale_tensors = tf.split(masked_scales, num_channels, axis=-1)
-
-         if coeffs is not None:
-             num_coeffs = num_channels * (num_channels - 1) // 2
-             masked_coeffs = tf.reduce_sum(coeffs * mask, axis=-2)
-             coef_tensors = tf.split(masked_coeffs, num_coeffs, axis=-1)
-
-         channel_samples = []
-         coef_count = 0
-         for i in range(num_channels):
-             loc = loc_tensors[i]
-             for c in channel_samples:
-                 loc += c * coef_tensors[coef_count]
-                 coef_count += 1
-
-             logistic_samp = distributions.logistic.Logistic(loc=loc, scale=scale_tensors[i]).sample(seed=seed)
-             logistic_samp = tf.clip_by_value(logistic_samp, -1.0, 1.0)
-             channel_samples.append(logistic_samp)
-
-         return tf.concat(channel_samples, axis=-1)
-
-     def _batch_shape(self):
-         return tf.TensorShape([])
-
-     def _event_shape(self):
-         return tf.TensorShape(self.image_shape)
-
-
- class PixelCNNNetwork(keras.layers.Layer):
-     """Keras `Layer` to parameterize a Pixel CNN++ distribution.
-     This is a Keras implementation of the Pixel CNN++ network, as described in
-     Salimans et al. (2017)[1] and van den Oord et al. (2016)[2].
-     (https://github.com/openai/pixel-cnn).
-     #### References
-     [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
-         PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture
-         Likelihood and Other Modifications. In _International Conference on
-         Learning Representations_, 2017.
-         https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf
-         Additional details at https://github.com/openai/pixel-cnn
-     [2]: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt,
-         Alex Graves, and Koray Kavukcuoglu. Conditional Image Generation with
-         PixelCNN Decoders. In _30th Conference on Neural Information Processing
-         Systems_, 2016.
-         https://papers.nips.cc/paper/6527-conditional-image-generation-with-pixelcnn-decoders.pdf.
-     """
-
-     def __init__(
-         self,
-         dropout_p: float = 0.5,
-         num_resnet: int = 5,
-         num_hierarchies: int = 3,
-         num_filters: int = 160,
-         num_logistic_mix: int = 10,
-         receptive_field_dims: tuple[int, int] = (3, 3),
-         resnet_activation: str = "concat_elu",
-         l2_weight: float = 0.0,
-         use_weight_norm: bool = True,
-         use_data_init: bool = True,
-         dtype: tf.DType = tf.float32,
-     ) -> None:
-         """Initialize the :term:`neural network<Neural Network>` for the Pixel CNN++ distribution.
-
-         Parameters
-         ----------
-         dropout_p
-             `float`, the dropout probability. Should be between 0 and 1.
-         num_resnet
-             `int`, the number of layers (shown in Figure 2 of [2]) within
-             each highest-level block of Figure 2 of [1].
-         num_hierarchies
-             `int`, the number of highest-level blocks (separated by
-             expansions/contractions of dimensions in Figure 2 of [1].)
-         num_filters
-             `int`, the number of convolutional filters.
-         num_logistic_mix
-             `int`, number of components in the logistic mixture
-             distribution.
-         receptive_field_dims
-             `tuple`, height and width in pixels of the receptive
-             field of the convolutional layers above and to the left of a given
-             pixel. The width (second element of the tuple) should be odd. Figure 1
-             (middle) of [2] shows a receptive field of (3, 5) (the row containing
-             the current pixel is included in the height). The default of (3, 3) was
-             used to produce the results in [1].
-         resnet_activation
-             `string`, the type of activation to use in the resnet
-             blocks. May be 'concat_elu', 'elu', or 'relu'.
-         l2_weight
-             `float`, the L2 regularization weight.
-         use_weight_norm
-             `bool`, if `True` then use weight normalization.
-         use_data_init
-             `bool`, if `True` then use data-dependent initialization
-             (has no effect if `use_weight_norm` is `False`).
-         dtype
-             Data type of the layer.
-         """
-         super().__init__(dtype=dtype)
-         self._dropout_p = dropout_p
-         self._num_resnet = num_resnet
-         self._num_hierarchies = num_hierarchies
-         self._num_filters = num_filters
-         self._num_logistic_mix = num_logistic_mix
-         self._receptive_field_dims = receptive_field_dims  # first set desired receptive field, then infer kernel
-         self._resnet_activation = resnet_activation
-         self._l2_weight = l2_weight
-
-         if use_weight_norm:
-
-             def layer_wrapper(layer):
-                 def wrapped_layer(*args, **kwargs):
-                     return WeightNorm(layer(*args, **kwargs), data_init=use_data_init)
-
-                 return wrapped_layer
-
-             self._layer_wrapper = layer_wrapper
-         else:
-             self._layer_wrapper = lambda layer: layer
-
-     def build(self, input_shape: tuple[int, ...]) -> None:
-         dtype = self.dtype
-         if len(input_shape) == 2:
-             batch_image_shape, batch_conditional_shape = input_shape
-             conditional_input = keras.layers.Input(shape=batch_conditional_shape[1:], dtype=dtype)
-         else:
-             batch_image_shape = input_shape
-             conditional_input = None
-
-         image_shape = batch_image_shape[1:]
-         image_input = keras.layers.Input(shape=image_shape, dtype=dtype)
-
-         if self._resnet_activation == "concat_elu":
-             activation = keras.layers.Lambda(lambda x: tf.nn.elu(tf.concat([x, -x], axis=-1)), dtype=dtype)
-         else:
-             activation = keras.activations.get(self._resnet_activation)
-
-         # Define layers with default inputs and layer wrapper applied
-         Conv2D = functools.partial(  # pylint:disable=invalid-name
-             self._layer_wrapper(keras.layers.Convolution2D),
-             filters=self._num_filters,
-             padding="same",
-             kernel_regularizer=keras.regularizers.l2(self._l2_weight),
-             dtype=dtype,
-         )
-
-         Dense = functools.partial(  # pylint:disable=invalid-name
-             self._layer_wrapper(keras.layers.Dense),
-             kernel_regularizer=keras.regularizers.l2(self._l2_weight),
-             dtype=dtype,
-         )
-
-         Conv2DTranspose = functools.partial(  # pylint:disable=invalid-name
-             self._layer_wrapper(keras.layers.Conv2DTranspose),
-             filters=self._num_filters,
-             padding="same",
-             strides=(2, 2),
-             kernel_regularizer=keras.regularizers.l2(self._l2_weight),
-             dtype=dtype,
-         )
-
-         rows, cols = self._receptive_field_dims
-
-         # Define the dimensions of the valid (unmasked) areas of the layer kernels
-         # for stride 1 convolutions in the internal layers.
-         kernel_valid_dims = {
-             "vertical": (rows - 1, cols),  # vertical stack
-             "horizontal": (2, cols // 2 + 1),
-         }  # horizontal stack
-
-         # Define the size of the kernel necessary to center the current pixel
-         # correctly for stride 1 convolutions in the internal layers.
-         kernel_sizes = {"vertical": (2 * rows - 3, cols), "horizontal": (3, cols)}
-
-         # Make the kernel constraint functions for stride 1 convolutions in internal
-         # layers.
-         kernel_constraints = {
-             k: _make_kernel_constraint(kernel_sizes[k], (0, v[0]), (0, v[1])) for k, v in kernel_valid_dims.items()
-         }
-
-         # Build the initial vertical stack/horizontal stack convolutional layers,
-         # as shown in Figure 1 of [2]. The receptive field of the initial vertical
-         # stack layer is a rectangular area centered above the current pixel.
-         vertical_stack_init = Conv2D(
-             kernel_size=(2 * rows - 1, cols),
-             kernel_constraint=_make_kernel_constraint((2 * rows - 1, cols), (0, rows - 1), (0, cols)),
-         )(image_input)
-
-         # In Figure 1 [2], the receptive field of the horizontal stack is
-         # illustrated as the pixels in the same row and to the left of the current
-         # pixel. [1] increases the height of this receptive field from one pixel to
-         # two (`horizontal_stack_left`) and additionally includes a subset of the
-         # row of pixels centered above the current pixel (`horizontal_stack_up`).
-         horizontal_stack_up = Conv2D(
-             kernel_size=(3, cols),
-             kernel_constraint=_make_kernel_constraint((3, cols), (0, 1), (0, cols)),
-         )(image_input)
-
-         horizontal_stack_left = Conv2D(
-             kernel_size=(3, cols),
-             kernel_constraint=_make_kernel_constraint((3, cols), (0, 2), (0, cols // 2)),
-         )(image_input)
-
-         horizontal_stack_init = keras.layers.add([horizontal_stack_up, horizontal_stack_left], dtype=dtype)
-
-         layer_stacks = {
-             "vertical": [vertical_stack_init],
-             "horizontal": [horizontal_stack_init],
-         }
-
-         # Build the downward pass of the U-net (left-hand half of Figure 2 of [1]).
-         # Each `i` iteration builds one of the highest-level blocks (identified as
-         # 'Sequence of 6 layers' in the figure, consisting of `num_resnet=5` stride-
-         # 1 layers, and one stride-2 layer that contracts the height/width
-         # dimensions). The `_` iterations build the stride 1 layers. The layers of
-         # the downward pass are stored in lists, since we'll later need them to make
-         # skip-connections to layers in the upward pass of the U-net (the skip-
-         # connections are represented by curved lines in Figure 2 [1]).
-         for i in range(self._num_hierarchies):
-             for _ in range(self._num_resnet):
-                 # Build a layer shown in Figure 2 of [2]. The 'vertical' iteration
-                 # builds the layers in the left half of the figure, and the 'horizontal'
-                 # iteration builds the layers in the right half.
-                 for stack in ["vertical", "horizontal"]:
-                     input_x = layer_stacks[stack][-1]
-                     x = activation(input_x)
-                     x = Conv2D(
-                         kernel_size=kernel_sizes[stack],
-                         kernel_constraint=kernel_constraints[stack],
-                     )(x)
-
-                     # Add the vertical-stack layer to the horizontal-stack layer
-                     if stack == "horizontal":
-                         h = activation(layer_stacks["vertical"][-1])
-                         h = Dense(self._num_filters)(h)
-                         x = keras.layers.add([h, x], dtype=dtype)
-
-                     x = activation(x)
-                     x = keras.layers.Dropout(self._dropout_p, dtype=dtype)(x)
-                     x = Conv2D(
-                         filters=2 * self._num_filters,
-                         kernel_size=kernel_sizes[stack],
-                         kernel_constraint=kernel_constraints[stack],
-                     )(x)
-
-                     if conditional_input is not None:
-                         h_projection = _build_and_apply_h_projection(conditional_input, self._num_filters, dtype=dtype)
-                         x = keras.layers.add([x, h_projection], dtype=dtype)
-
-                     x = _apply_sigmoid_gating(x)
-
-                     # Add a residual connection from the layer's input.
-                     out = keras.layers.add([input_x, x], dtype=dtype)
-                     layer_stacks[stack].append(out)
-
-             if i < self._num_hierarchies - 1:
-                 # Build convolutional layers that contract the height/width dimensions
-                 # on the downward pass between each set of layers (e.g. contracting from
-                 # 32x32 to 16x16 in Figure 2 of [1]).
-                 for stack in ["vertical", "horizontal"]:
-                     # Define kernel dimensions/masking to maintain the autoregressive property.
-                     x = layer_stacks[stack][-1]
-                     h, w = kernel_valid_dims[stack]
-                     kernel_height = 2 * h
-                     kernel_width = w + 1 if stack == "vertical" else 2 * w
-                     kernel_size = (kernel_height, kernel_width)
-                     kernel_constraint = _make_kernel_constraint(kernel_size, (0, h), (0, w))
-                     x = Conv2D(
-                         strides=(2, 2),
-                         kernel_size=kernel_size,
-                         kernel_constraint=kernel_constraint,
-                     )(x)
-                     layer_stacks[stack].append(x)
-
-         # Upward pass of the U-net (right-hand half of Figure 2 of [1]). We stored
-         # the layers of the downward pass in a list, in order to access them to make
-         # skip-connections to the upward pass. For the upward pass, we need to keep
-         # track of only the current layer, so we maintain a reference to the
-         # current layer of the horizontal/vertical stack in the `upward_pass` dict.
-         # The upward pass begins with the last layer of the downward pass.
-         upward_pass = {key: stack.pop() for key, stack in layer_stacks.items()}
-
-         # As with the downward pass, each `i` iteration builds a highest level block
-         # in Figure 2 [1], and the `_` iterations build individual layers within the
-         # block.
-         for i in range(self._num_hierarchies):
-             num_resnet = self._num_resnet if i == 0 else self._num_resnet + 1
-
-             for _ in range(num_resnet):
-                 # Build a layer as shown in Figure 2 of [2], with a skip-connection
-                 # from the symmetric layer in the downward pass.
-                 for stack in ["vertical", "horizontal"]:
-                     input_x = upward_pass[stack]
-                     x_symmetric = layer_stacks[stack].pop()
-
-                     x = activation(input_x)
-                     x = Conv2D(
-                         kernel_size=kernel_sizes[stack],
-                         kernel_constraint=kernel_constraints[stack],
-                     )(x)
-
-                     # Include the vertical-stack layer of the upward pass in the layers
-                     # to be added to the horizontal layer.
-                     if stack == "horizontal":
-                         x_symmetric = keras.layers.Concatenate(axis=-1, dtype=dtype)(
-                             [upward_pass["vertical"], x_symmetric]
-                         )
-
-                     # Add a skip-connection from the symmetric layer in the downward
-                     # pass to the layer `x` in the upward pass.
-                     h = activation(x_symmetric)
-                     h = Dense(self._num_filters)(h)
-                     x = keras.layers.add([h, x], dtype=dtype)
-
-                     x = activation(x)
-                     x = keras.layers.Dropout(self._dropout_p, dtype=dtype)(x)
-                     x = Conv2D(
-                         filters=2 * self._num_filters,
-                         kernel_size=kernel_sizes[stack],
-                         kernel_constraint=kernel_constraints[stack],
-                     )(x)
-
-                     if conditional_input is not None:
-                         h_projection = _build_and_apply_h_projection(conditional_input, self._num_filters, dtype=dtype)
-                         x = keras.layers.add([x, h_projection], dtype=dtype)
-
-                     x = _apply_sigmoid_gating(x)
-                     upward_pass[stack] = keras.layers.add([input_x, x], dtype=dtype)
-
-             # Define deconvolutional layers that expand height/width dimensions on the
-             # upward pass (e.g. expanding from 8x8 to 16x16 in Figure 2 of [1]), with
-             # the correct kernel dimensions/masking to maintain the autoregressive
-             # property.
-             if i < self._num_hierarchies - 1:
-                 for stack in ["vertical", "horizontal"]:
-                     h, w = kernel_valid_dims[stack]
-                     kernel_height = 2 * h - 2
-                     if stack == "vertical":
-                         kernel_width = w + 1
-                         kernel_constraint = _make_kernel_constraint(
-                             (kernel_height, kernel_width),
-                             (h - 2, kernel_height),
-                             (0, w),
-                         )
-                     else:
-                         kernel_width = 2 * w - 2
-                         kernel_constraint = _make_kernel_constraint(
-                             (kernel_height, kernel_width),
-                             (h - 2, kernel_height),
-                             (w - 2, kernel_width),
-                         )
-
-                     x = upward_pass[stack]
-                     x = Conv2DTranspose(
-                         kernel_size=(kernel_height, kernel_width),
-                         kernel_constraint=kernel_constraint,
-                     )(x)
-                     upward_pass[stack] = x
-
-         x_out = keras.layers.ELU(dtype=dtype)(upward_pass["horizontal"])
-
-         # Build final Dense/Reshape layers to output the correct number of
-         # parameters per pixel.
-         num_channels = tfp_internal.tensorshape_util.as_list(image_shape)[-1]
-         num_coeffs = num_channels * (num_channels - 1) // 2  # alpha, beta, gamma in eq.3 of paper
-         num_out = num_channels * 2 + num_coeffs + 1  # mu, s + alpha, beta, gamma + 1 (mixture weight)
-         num_out_total = num_out * self._num_logistic_mix
-         params = Dense(num_out_total)(x_out)
-         params = tf.reshape(
-             params,
-             tfp_internal.prefer_static.concat(  # [-1,H,W,nb mixtures, params per mixture]
-                 [[-1], image_shape[:-1], [self._num_logistic_mix, num_out]], axis=0
-             ),
-         )
-
-         # If there is one color channel, split the parameters into a list of three
-         # output `Tensor`s: (1) component logits for the Quantized Logistic mixture
-         # distribution, (2) location parameters for each component, and (3) scale
-         # parameters for each component. If there is more than one color channel,
-         # return a fourth `Tensor` for the coefficients for the linear dependence
-         # among color channels (e.g. alpha, beta, gamma).
-         # [logits, mu, s, linear dependence]
-         splits = 3 if num_channels == 1 else [1, num_channels, num_channels, num_coeffs]
-         outputs = tf.split(params, splits, axis=-1)
-
-         # Squeeze singleton dimension from component logits
-         outputs[0] = tf.squeeze(outputs[0], axis=-1)
-
-         # Ensure scales are positive and do not collapse to near-zero
-         outputs[2] = tf.nn.softplus(outputs[2]) + tf.cast(tf.exp(-7.0), self.dtype)
-
-         inputs = image_input if conditional_input is None else [image_input, conditional_input]
-         self._network = keras.Model(inputs=inputs, outputs=outputs)
-         super().build(input_shape)
-
-     def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
-         """Call the Pixel CNN network model.
-
-         Parameters
-         ----------
-         inputs
-             4D `Tensor` of image data with dimensions [batch size, height,
-             width, channels] or a 2-element `list`. If `list`, the first element is
-             the 4D image `Tensor` and the second element is a `Tensor` with
-             conditional input data (e.g. VAE encodings or class labels) with the
-             same leading batch dimension as the image `Tensor`.
-         training
-             `bool` or `None`. If `bool`, it controls the dropout layer,
-             where `True` implies dropout is active. If `None`, it defaults to
-             `keras.backend.learning_phase()`
-
-         Returns
-         -------
-         outputs
-             a 3- or 4-element `list` of `Tensor`s in the following order:
-             component_logits: 4D `Tensor` of logits for the Categorical distribution
-             over Quantized Logistic mixture components. Dimensions are
-             `[batch_size, height, width, num_logistic_mix]`.
-         locs
-             4D `Tensor` of location parameters for the Quantized Logistic
-             mixture components. Dimensions are `[batch_size, height, width,
-             num_logistic_mix, num_channels]`.
-         scales
-             4D `Tensor` of scale parameters for the Quantized Logistic
-             mixture components. Dimensions are `[batch_size, height, width,
-             num_logistic_mix, num_channels]`.
-         coeffs
-             4D `Tensor` of coefficients for the linear dependence among
-             color channels, included only if the image has more than one channel.
-             Dimensions are `[batch_size, height, width, num_logistic_mix,
-             num_coeffs]`, where `num_coeffs = num_channels * (num_channels - 1) // 2`.
-         """
-         return self._network(inputs, training=training)
-
-
- def _make_kernel_constraint(kernel_size, valid_rows, valid_columns):
-     """Make the masking function for layer kernels."""
-     mask = np.zeros(kernel_size)
-     lower, upper = valid_rows
-     left, right = valid_columns
-     mask[lower:upper, left:right] = 1.0
-     mask = mask[:, :, np.newaxis, np.newaxis]
-     return lambda x: x * mask
-
-
- def _build_and_apply_h_projection(h, num_filters, dtype):
-     """Project the conditional input."""
-     h = keras.layers.Flatten(dtype=dtype)(h)
-     h_projection = keras.layers.Dense(2 * num_filters, kernel_initializer="random_normal", dtype=dtype)(h)
-     return h_projection[..., tf.newaxis, tf.newaxis, :]
-
-
- def _apply_sigmoid_gating(x):
-     """Apply the sigmoid gating in Figure 2 of [2]."""
-     activation_tensor, gate_tensor = tf.split(x, 2, axis=-1)
-     sigmoid_gate = tf.sigmoid(gate_tensor)
-     return keras.layers.multiply([sigmoid_gate, activation_tensor], dtype=x.dtype)