dataeval 0.61.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. dataeval/__init__.py +18 -0
  2. dataeval/_internal/detectors/__init__.py +0 -0
  3. dataeval/_internal/detectors/clusterer.py +469 -0
  4. dataeval/_internal/detectors/drift/__init__.py +0 -0
  5. dataeval/_internal/detectors/drift/base.py +265 -0
  6. dataeval/_internal/detectors/drift/cvm.py +97 -0
  7. dataeval/_internal/detectors/drift/ks.py +100 -0
  8. dataeval/_internal/detectors/drift/mmd.py +166 -0
  9. dataeval/_internal/detectors/drift/torch.py +310 -0
  10. dataeval/_internal/detectors/drift/uncertainty.py +149 -0
  11. dataeval/_internal/detectors/duplicates.py +49 -0
  12. dataeval/_internal/detectors/linter.py +78 -0
  13. dataeval/_internal/detectors/ood/__init__.py +0 -0
  14. dataeval/_internal/detectors/ood/ae.py +77 -0
  15. dataeval/_internal/detectors/ood/aegmm.py +69 -0
  16. dataeval/_internal/detectors/ood/base.py +199 -0
  17. dataeval/_internal/detectors/ood/llr.py +284 -0
  18. dataeval/_internal/detectors/ood/vae.py +86 -0
  19. dataeval/_internal/detectors/ood/vaegmm.py +79 -0
  20. dataeval/_internal/flags.py +47 -0
  21. dataeval/_internal/metrics/__init__.py +0 -0
  22. dataeval/_internal/metrics/base.py +92 -0
  23. dataeval/_internal/metrics/ber.py +124 -0
  24. dataeval/_internal/metrics/coverage.py +80 -0
  25. dataeval/_internal/metrics/divergence.py +94 -0
  26. dataeval/_internal/metrics/hash.py +79 -0
  27. dataeval/_internal/metrics/parity.py +180 -0
  28. dataeval/_internal/metrics/stats.py +332 -0
  29. dataeval/_internal/metrics/uap.py +45 -0
  30. dataeval/_internal/metrics/utils.py +158 -0
  31. dataeval/_internal/models/__init__.py +0 -0
  32. dataeval/_internal/models/pytorch/__init__.py +0 -0
  33. dataeval/_internal/models/pytorch/autoencoder.py +202 -0
  34. dataeval/_internal/models/pytorch/blocks.py +46 -0
  35. dataeval/_internal/models/pytorch/utils.py +67 -0
  36. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  37. dataeval/_internal/models/tensorflow/autoencoder.py +317 -0
  38. dataeval/_internal/models/tensorflow/gmm.py +115 -0
  39. dataeval/_internal/models/tensorflow/losses.py +107 -0
  40. dataeval/_internal/models/tensorflow/pixelcnn.py +1106 -0
  41. dataeval/_internal/models/tensorflow/trainer.py +102 -0
  42. dataeval/_internal/models/tensorflow/utils.py +254 -0
  43. dataeval/_internal/workflows/sufficiency.py +555 -0
  44. dataeval/detectors/__init__.py +29 -0
  45. dataeval/flags/__init__.py +3 -0
  46. dataeval/metrics/__init__.py +7 -0
  47. dataeval/models/__init__.py +15 -0
  48. dataeval/models/tensorflow/__init__.py +6 -0
  49. dataeval/models/torch/__init__.py +8 -0
  50. dataeval/py.typed +0 -0
  51. dataeval/workflows/__init__.py +8 -0
  52. dataeval-0.61.0.dist-info/LICENSE.txt +21 -0
  53. dataeval-0.61.0.dist-info/METADATA +114 -0
  54. dataeval-0.61.0.dist-info/RECORD +55 -0
  55. dataeval-0.61.0.dist-info/WHEEL +4 -0
dataeval/_internal/models/tensorflow/pixelcnn.py
@@ -0,0 +1,1106 @@
1
+ # type: ignore
2
+
3
+ """
4
+ Source code derived from Alibi-Detect 0.11.4
5
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
6
+
7
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
8
+ Licensed under Apache Software License (Apache 2.0)
9
+ """
10
+
11
+ import functools
12
+ import warnings
13
+ from typing import Optional
14
+
15
+ import keras
16
+ import numpy as np
17
+ import tensorflow as tf
18
+ from tensorflow_probability.python.bijectors import bijector
19
+ from tensorflow_probability.python.distributions import (
20
+ categorical,
21
+ distribution,
22
+ independent,
23
+ logistic,
24
+ mixture_same_family,
25
+ quantized_distribution,
26
+ transformed_distribution,
27
+ )
28
+ from tensorflow_probability.python.internal import (
29
+ dtype_util,
30
+ prefer_static,
31
+ reparameterization,
32
+ tensor_util,
33
+ tensorshape_util,
34
+ )
35
+
36
+ __all__ = [
37
+ "Shift",
38
+ ]
39
+
40
+
41
+ class WeightNorm(keras.layers.Wrapper):
42
+ def __init__(self, layer, data_init: bool = True, **kwargs):
43
+ """Layer wrapper to decouple magnitude and direction of the layer's weights.
44
+
45
+ This wrapper reparameterizes a layer by decoupling the weight's
46
+ magnitude and direction. This speeds up convergence by improving the
47
+ conditioning of the optimization problem. It has an optional data-dependent
48
+ initialization scheme, in which initial values of weights are set as functions
49
+ of the first minibatch of data. Both the weight normalization and data-
50
+ dependent initialization are described in [Salimans and Kingma (2016)][1].
51
+
52
+ Parameters
53
+ ----------
54
+ layer
55
+ A `keras.layers.Layer` instance. Supported layer types are
56
+ `Dense`, `Conv2D`, and `Conv2DTranspose`. Layers with multiple inputs
57
+ are not supported.
58
+ data_init
59
+ If `True` use data dependent variable initialization.
60
+ **kwargs
61
+ Additional keyword args passed to `keras.layers.Wrapper`.
62
+
63
+ Raises
64
+ ------
65
+ ValueError
66
+ If `layer` is not a `keras.layers.Layer` instance.
67
+ """
68
+ if not isinstance(layer, keras.layers.Layer):
69
+ raise ValueError(
70
+ "Please initialize `WeightNorm` layer with a `keras.layers.Layer` " f"instance. You passed: {layer}"
71
+ )
72
+
73
+ layer_type = type(layer).__name__
74
+ if layer_type not in ["Dense", "Conv2D", "Conv2DTranspose"]:
75
+ warnings.warn(
76
+ "`WeightNorm` is tested only for `Dense`, `Conv2D`, and "
77
+ f"`Conv2DTranspose` layers. You passed a layer of type `{layer_type}`"
78
+ )
79
+
80
+ super().__init__(layer, **kwargs)
81
+
82
+ self.data_init = data_init
83
+ self._track_trackable(layer, name="layer")
84
+ self.filter_axis = -2 if layer_type == "Conv2DTranspose" else -1
85
+
86
+ def _compute_weights(self):
87
+ """Generate weights with normalization."""
88
+ # Determine the axis along which to expand `g` so that `g` broadcasts to
89
+ # the shape of `v`.
90
+ new_axis = -self.filter_axis - 3
91
+
92
+ self.layer.kernel = tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * tf.expand_dims(self.g, new_axis)
93
+
94
+ def _init_norm(self):
95
+ """Set the norm of the weight vector."""
96
+ kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.v), axis=self.kernel_norm_axes))
97
+ self.g.assign(kernel_norm)
98
+
99
+ def _data_dep_init(self, inputs):
100
+ """Data dependent initialization."""
101
+ # Normalize kernel first so that calling the layer calculates
102
+ # `tf.dot(v, x)/tf.norm(v)` as in (5) in ([Salimans and Kingma, 2016][1]).
103
+ self._compute_weights()
104
+
105
+ activation = self.layer.activation
106
+ self.layer.activation = None
107
+
108
+ use_bias = self.layer.bias is not None
109
+ if use_bias:
110
+ bias = self.layer.bias
111
+ self.layer.bias = tf.zeros_like(bias)
112
+
113
+ # Since the bias is initialized as zero, setting the activation to zero and
114
+ # calling the initialized layer (with normalized kernel) yields the correct
115
+ # computation ((5) in Salimans and Kingma (2016))
116
+ x_init = self.layer(inputs)
117
+ norm_axes_out = list(range(x_init.shape.rank - 1))
118
+ m_init, v_init = tf.nn.moments(x_init, norm_axes_out)
119
+ scale_init = 1.0 / tf.sqrt(v_init + 1e-10)
120
+
121
+ self.g.assign(self.g * scale_init)
122
+ if use_bias:
123
+ self.layer.bias = bias
124
+ self.layer.bias.assign(-m_init * scale_init)
125
+ self.layer.activation = activation
126
+
127
+ def build(self, input_shape=None):
128
+ """Build `Layer`.
129
+
130
+ Parameters
131
+ ----------
132
+ input_shape
133
+ The shape of the input to `self.layer`.
134
+
135
+ Raises
136
+ ------
137
+ ValueError
138
+ If `Layer` does not contain a `kernel` of weights.
139
+ """
140
+ input_shape = tf.TensorShape(input_shape).as_list()
141
+ input_shape[0] = None
142
+ self.input_spec = keras.layers.InputSpec(shape=input_shape)
143
+
144
+ if not self.layer.built:
145
+ self.layer.build(input_shape)
146
+
147
+ if not hasattr(self.layer, "kernel"):
148
+ raise ValueError("`WeightNorm` must wrap a layer that contains a `kernel` for weights")
149
+
150
+ self.kernel_norm_axes = list(range(self.layer.kernel.shape.ndims))
151
+ self.kernel_norm_axes.pop(self.filter_axis)
152
+
153
+ self.v = self.layer.kernel
154
+
155
+ # to avoid a duplicate `kernel` variable after `build` is called
156
+ self.layer.kernel = None
157
+ self.g = self.add_weight(
158
+ name="g",
159
+ shape=(int(self.v.shape[self.filter_axis]),),
160
+ initializer="ones",
161
+ dtype=self.v.dtype,
162
+ trainable=True,
163
+ )
164
+ self.initialized = self.add_weight(name="initialized", dtype=tf.bool, trainable=False)
165
+ self.initialized.assign(False)
166
+
167
+ super().build()
168
+
169
+ @tf.function
170
+ def call(self, inputs):
171
+ """Call `Layer`."""
172
+ if not self.initialized:
173
+ if self.data_init:
174
+ self._data_dep_init(inputs)
175
+ else: # initialize `g` as the norm of the initialized kernel
176
+ self._init_norm()
177
+
178
+ self.initialized.assign(True)
179
+
180
+ self._compute_weights()
181
+ output = self.layer(inputs)
182
+ return output
183
+
184
+ def compute_output_shape(self, input_shape):
185
+ return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())
186
+
187
+
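A minimal usage sketch of the wrapper above (an illustrative addition, not part of the packaged file), assuming TensorFlow 2.x with standalone Keras:

import keras
import numpy as np
import tensorflow as tf

# Wrap a Dense layer: its `kernel` is reparameterized as g * v / ||v||.
wn = WeightNorm(keras.layers.Dense(16), data_init=True)

x = tf.constant(np.random.rand(8, 32), dtype=tf.float32)
y = wn(x)  # the first call builds the layer and runs the data-dependent init
assert y.shape == (8, 16)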
188
+ class Shift(bijector.Bijector):
189
+ def __init__(self, shift, validate_args=False, name="shift"):
190
+ """Instantiates the `Shift` bijector which computes `Y = g(X; shift) = X + shift`
191
+ where `shift` is a numeric `Tensor`.
192
+
193
+ Parameters
194
+ ----------
195
+ shift
196
+ Floating-point `Tensor`.
197
+ validate_args
198
+ Python `bool` indicating whether arguments should be checked for correctness.
199
+ name
200
+ Python `str` name given to ops managed by this object.
201
+ """
202
+ with tf.name_scope(name) as name:
203
+ dtype = dtype_util.common_dtype([shift], dtype_hint=tf.float32)
204
+ self._shift = tensor_util.convert_nonref_to_tensor(shift, dtype=dtype, name="shift")
205
+ super().__init__(
206
+ forward_min_event_ndims=0,
207
+ is_constant_jacobian=True,
208
+ dtype=dtype,
209
+ validate_args=validate_args,
210
+ name=name,
211
+ )
212
+
213
+ @property
214
+ def shift(self):
215
+ """The `shift` `Tensor` in `Y = X + shift`."""
216
+ return self._shift
217
+
218
+ @classmethod
219
+ def _is_increasing(cls):
220
+ return True
221
+
222
+ def _forward(self, x):
223
+ return x + self.shift
224
+
225
+ def _inverse(self, y):
226
+ return y - self.shift
227
+
228
+ def _forward_log_det_jacobian(self, x):
229
+ # is_constant_jacobian = True for this bijector, hence the
230
+ # `log_det_jacobian` need only be specified for a single input, as this will
231
+ # be tiled to match `event_ndims`.
232
+ return tf.zeros([], dtype=dtype_util.base_dtype(x.dtype))
233
+
234
+
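A quick, illustrative check of the bijector's contract (assuming the module's imports are in scope):

b = Shift(shift=tf.constant(-0.5))
x = tf.constant([0.0, 1.0, 2.0])
b.forward(x)                                  # [-0.5, 0.5, 1.5]
b.inverse(b.forward(x))                       # [0.0, 1.0, 2.0]
b.forward_log_det_jacobian(x, event_ndims=0)  # 0.0: a pure translation preserves volume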
235
+ class PixelCNN(distribution.Distribution):
236
+ """
237
+ Construct Pixel CNN++ distribution.
238
+
239
+ Parameters
240
+ ----------
241
+ image_shape
242
+ 3D `TensorShape` or tuple for the `[height, width, channels]` dimensions of the image.
243
+ conditional_shape
244
+ `TensorShape` or tuple for the shape of the conditional input, or `None` if there is no conditional input.
245
+ num_resnet
246
+ The number of layers (shown in Figure 2 of [2]) within each highest-level block of Figure 2 of [1].
247
+ num_hierarchies
248
+ The number of highest-level blocks (separated by expansions/contractions of dimensions in Figure 2 of [1]).
249
+ num_filters
250
+ The number of convolutional filters.
251
+ num_logistic_mix
252
+ Number of components in the logistic mixture distribution.
253
+ receptive_field_dims
254
+ Height and width in pixels of the receptive field of the convolutional layers above and to the left
255
+ of a given pixel. The width (second element of the tuple) should be odd. Figure 1 (middle) of [2]
256
+ shows a receptive field of (3, 5) (the row containing the current pixel is included in the height).
257
+ The default of (3, 3) was used to produce the results in [1].
258
+ dropout_p
259
+ The dropout probability. Should be between 0 and 1.
260
+ resnet_activation
261
+ The type of activation to use in the resnet blocks. May be 'concat_elu', 'elu', or 'relu'.
262
+ l2_weight
263
+ The L2 regularization weight.
264
+ use_weight_norm
265
+ If `True` then use weight normalization (works only in Eager mode).
266
+ use_data_init
267
+ If `True` then use data-dependent initialization (has no effect if `use_weight_norm` is `False`).
268
+ high
269
+ The maximum value of the input data (255 for an 8-bit image).
270
+ low
271
+ The minimum value of the input data.
272
+ dtype
273
+ Data type of the `Distribution`.
274
+ name
275
+ The name of the `Distribution`.
276
+ """
277
+
278
+ def __init__(
279
+ self,
280
+ image_shape: tuple,
281
+ conditional_shape: Optional[tuple] = None,
282
+ num_resnet: int = 5,
283
+ num_hierarchies: int = 3,
284
+ num_filters: int = 160,
285
+ num_logistic_mix: int = 10,
286
+ receptive_field_dims: tuple = (3, 3),
287
+ dropout_p: float = 0.5,
288
+ resnet_activation: str = "concat_elu",
289
+ l2_weight: float = 0.0,
290
+ use_weight_norm: bool = True,
291
+ use_data_init: bool = True,
292
+ high: int = 255,
293
+ low: int = 0,
294
+ dtype=tf.float32,
295
+ name: str = "PixelCNN",
296
+ ) -> None:
297
+ parameters = dict(locals())
298
+ with tf.name_scope(name) as name:
299
+ super().__init__(
300
+ dtype=dtype,
301
+ reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
302
+ validate_args=False,
303
+ allow_nan_stats=True,
304
+ parameters=parameters,
305
+ name=name,
306
+ )
307
+
308
+ if not tensorshape_util.is_fully_defined(image_shape):
309
+ raise ValueError("`image_shape` must be fully defined.")
310
+
311
+ if conditional_shape is not None and not tensorshape_util.is_fully_defined(conditional_shape):
312
+ raise ValueError("`conditional_shape` must be fully defined.")
313
+
314
+ if tensorshape_util.rank(image_shape) != 3:
315
+ raise ValueError("`image_shape` must have length 3, representing [height, width, channels] dimensions.")
316
+
317
+ self._high = tf.cast(high, self.dtype)
318
+ self._low = tf.cast(low, self.dtype)
319
+ self._num_logistic_mix = num_logistic_mix
320
+ self.network = _PixelCNNNetwork(
321
+ dropout_p=dropout_p,
322
+ num_resnet=num_resnet,
323
+ num_hierarchies=num_hierarchies,
324
+ num_filters=num_filters,
325
+ num_logistic_mix=num_logistic_mix,
326
+ receptive_field_dims=receptive_field_dims,
327
+ resnet_activation=resnet_activation,
328
+ l2_weight=l2_weight,
329
+ use_weight_norm=use_weight_norm,
330
+ use_data_init=use_data_init,
331
+ dtype=dtype,
332
+ )
333
+
334
+ image_input_shape = tensorshape_util.concatenate([None], image_shape)
335
+ if conditional_shape is None:
336
+ input_shape = image_input_shape
337
+ else:
338
+ conditional_input_shape = tensorshape_util.concatenate([None], conditional_shape)
339
+ input_shape = [image_input_shape, conditional_input_shape]
340
+
341
+ self.image_shape = image_shape
342
+ self.conditional_shape = conditional_shape
343
+ self.network.build(input_shape)
344
+
345
+ def _make_mixture_dist(self, component_logits, locs, scales, return_per_feature: bool = False):
346
+ """Builds a mixture of quantized logistic distributions.
347
+
348
+ Parameters
349
+ ----------
350
+ component_logits
351
+ 4D `Tensor` of logits for the Categorical distribution
352
+ over Quantized Logistic mixture components. Dimensions are `[batch_size,
353
+ height, width, num_logistic_mix]`.
354
+ locs
355
+ 4D `Tensor` of location parameters for the Quantized Logistic
356
+ mixture components. Dimensions are `[batch_size, height, width,
357
+ num_logistic_mix, num_channels]`.
358
+ scales
359
+ 4D `Tensor` of scale parameters for the Quantized Logistic
360
+ mixture components. Dimensions are `[batch_size, height, width,
361
+ num_logistic_mix, num_channels]`.
362
+ return_per_feature
363
+ If True, return per pixel level log prob.
364
+
365
+ Returns
366
+ -------
367
+ dist
368
+ A quantized logistic mixture `tfp.distribution` over the input data.
369
+ """
370
+ mixture_distribution = categorical.Categorical(logits=component_logits)
371
+
372
+ # Convert distribution parameters for pixel values in
373
+ # `[self._low, self._high]` for use with `QuantizedDistribution`
374
+ locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.0)
375
+ scales *= 0.5 * (self._high - self._low)
376
+ logistic_dist = quantized_distribution.QuantizedDistribution(
377
+ distribution=transformed_distribution.TransformedDistribution(
378
+ distribution=logistic.Logistic(loc=locs, scale=scales),
379
+ bijector=Shift(shift=tf.cast(-0.5, self.dtype)),
380
+ ),
381
+ low=self._low,
382
+ high=self._high,
383
+ )
384
+
385
+ # mixture with logistics for the loc and scale on each pixel for each component
386
+ dist = mixture_same_family.MixtureSameFamily(
387
+ mixture_distribution=mixture_distribution,
388
+ components_distribution=independent.Independent(logistic_dist, reinterpreted_batch_ndims=1),
389
+ )
390
+ if return_per_feature:
391
+ return dist
392
+ else:
393
+ return independent.Independent(dist, reinterpreted_batch_ndims=2)
394
+
395
+ def _log_prob(self, value, conditional_input=None, training=None, return_per_feature=False):
396
+ """Log probability function with optional conditional input.
397
+
398
+ Calculates the log probability of a batch of data under the modeled
399
+ distribution (or conditional distribution, if conditional input is
400
+ provided).
401
+
402
+ Parameters
403
+ ----------
404
+ value
405
+ `Tensor` or Numpy array of image data. May have leading batch
406
+ dimension(s), which must broadcast to the leading batch dimensions of
407
+ `conditional_input`.
408
+ conditional_input
409
+ `Tensor` on which to condition the distribution (e.g.
410
+ class labels), or `None`. May have leading batch dimension(s), which
411
+ must broadcast to the leading batch dimensions of `value`.
412
+ training
413
+ `bool` or `None`. If `bool`, it controls the dropout layer,
414
+ where `True` implies dropout is active. If `None`, it defaults to
415
+ `keras.backend.learning_phase()`.
416
+ return_per_feature
417
+ `bool`. If True, return per pixel level log prob.
418
+
419
+ Returns
420
+ -------
421
+ log_prob_values: `Tensor`.
422
+ """
423
+ # Determine the batch shape of the input images
424
+ image_batch_shape = prefer_static.shape(value)[:-3]
425
+
426
+ # Broadcast `value` and `conditional_input` to the same batch_shape
427
+ if conditional_input is None:
428
+ image_batch_and_conditional_shape = image_batch_shape
429
+ else:
430
+ conditional_input = tf.convert_to_tensor(conditional_input)
431
+ conditional_input_shape = prefer_static.shape(conditional_input)
432
+ conditional_batch_rank = prefer_static.rank(conditional_input) - tensorshape_util.rank(
433
+ self.conditional_shape
434
+ )
435
+ conditional_batch_shape = conditional_input_shape[:conditional_batch_rank]
436
+
437
+ image_batch_and_conditional_shape = prefer_static.broadcast_shape(
438
+ image_batch_shape, conditional_batch_shape
439
+ )
440
+ conditional_input = tf.broadcast_to(
441
+ conditional_input,
442
+ prefer_static.concat([image_batch_and_conditional_shape, self.conditional_shape], axis=0),
443
+ )
444
+ value = tf.broadcast_to(
445
+ value,
446
+ prefer_static.concat([image_batch_and_conditional_shape, self.event_shape], axis=0),
447
+ )
448
+
449
+ # Flatten batch dimension for input to Keras model
450
+ conditional_input = tf.reshape(
451
+ conditional_input,
452
+ prefer_static.concat([(-1,), self.conditional_shape], axis=0),
453
+ )
454
+
455
+ value = tf.reshape(value, prefer_static.concat([(-1,), self.event_shape], axis=0))
456
+
457
+ transformed_value = (2.0 * (value - self._low) / (self._high - self._low)) - 1.0
458
+ inputs = transformed_value if conditional_input is None else [transformed_value, conditional_input]
459
+
460
+ params = self.network(inputs, training=training)
461
+
462
+ num_channels = self.event_shape[-1]
463
+ if num_channels == 1:
464
+ component_logits, locs, scales = params
465
+ else:
466
+ # If there is more than one channel, we create a linear autoregressive
467
+ # dependency among the location parameters of the channels of a single
468
+ # pixel (the scale parameters within a pixel are independent). For a pixel
469
+ # with R/G/B channels, the `r`, `g`, and `b` saturation values are
470
+ # distributed as:
471
+ #
472
+ # r ~ Logistic(loc_r, scale_r)
473
+ # g ~ Logistic(coef_rg * r + loc_g, scale_g)
474
+ # b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
475
+ # (upstream TFP leaves a TODO to use fill_triangular/matrix multiplication
+ # on the coefficients instead of the split/multiply/concat below)
476
+ component_logits, locs, scales, coeffs = params
477
+ num_coeffs = num_channels * (num_channels - 1) // 2
478
+ loc_tensors = tf.split(locs, num_channels, axis=-1)
479
+ coef_tensors = tf.split(coeffs, num_coeffs, axis=-1)
480
+ channel_tensors = tf.split(value, num_channels, axis=-1)
481
+
482
+ coef_count = 0
483
+ for i in range(num_channels):
484
+ channel_tensors[i] = channel_tensors[i][..., tf.newaxis, :]
485
+ for j in range(i):
486
+ loc_tensors[i] += channel_tensors[j] * coef_tensors[coef_count]
487
+ coef_count += 1
488
+ locs = tf.concat(loc_tensors, axis=-1)
489
+
490
+ dist = self._make_mixture_dist(component_logits, locs, scales, return_per_feature=return_per_feature)
491
+ log_px = dist.log_prob(value)
492
+ if return_per_feature:
493
+ return log_px
494
+ else:
495
+ return tf.reshape(log_px, image_batch_and_conditional_shape)
496
+
497
+ def _sample_n(self, n, seed=None, conditional_input=None, training=False):
498
+ """Samples from the distribution, with optional conditional input.
499
+
500
+ Parameters
501
+ ----------
502
+ n
503
+ `int`, number of samples desired.
504
+ seed
505
+ `int`, seed for RNG. Setting a random seed enforces reproducibility
506
+ of the samples between sessions (not within a single session).
507
+ conditional_input
508
+ `Tensor` on which to condition the distribution (e.g.
509
+ class labels), or `None`.
510
+ training
511
+ `bool` or `None`. If `bool`, it controls the dropout layer,
512
+ where `True` implies dropout is active. If `None`, it defers to Keras'
513
+ handling of train/eval status.
514
+
515
+ Returns
516
+ -------
517
+ samples
518
+ a `Tensor` of shape `[n, height, width, num_channels]`.
519
+ """
520
+ if conditional_input is not None:
521
+ conditional_input = tf.convert_to_tensor(conditional_input, dtype=self.dtype)
522
+ conditional_event_rank = tensorshape_util.rank(self.conditional_shape)
523
+ conditional_input_shape = prefer_static.shape(conditional_input)
524
+ conditional_sample_rank = prefer_static.rank(conditional_input) - conditional_event_rank
525
+
526
+ # If `conditional_input` has no sample dimensions, prepend a sample
527
+ # dimension
528
+ if conditional_sample_rank == 0:
529
+ conditional_input = conditional_input[tf.newaxis, ...]
530
+ conditional_sample_rank = 1
531
+
532
+ # Assert that the conditional event shape in the `_PixelCNNNetwork` is the
533
+ # same as that implied by `conditional_input`.
534
+ conditional_event_shape = conditional_input_shape[conditional_sample_rank:]
535
+ with tf.control_dependencies([tf.assert_equal(self.conditional_shape, conditional_event_shape)]):
536
+ conditional_sample_shape = conditional_input_shape[:conditional_sample_rank]
537
+ repeat = n // prefer_static.reduce_prod(conditional_sample_shape)
538
+ h = tf.reshape(
539
+ conditional_input,
540
+ prefer_static.concat([(-1,), self.conditional_shape], axis=0),
541
+ )
542
+ h = tf.tile(
543
+ h,
544
+ prefer_static.pad(
545
+ [repeat],
546
+ paddings=[[0, conditional_event_rank]],
547
+ constant_values=1,
548
+ ),
549
+ )
550
+
551
+ samples_0 = tf.random.uniform(
552
+ prefer_static.concat([(n,), self.event_shape], axis=0),
553
+ minval=-1.0,
554
+ maxval=1.0,
555
+ dtype=self.dtype,
556
+ seed=seed,
557
+ )
558
+ inputs = samples_0 if conditional_input is None else [samples_0, h]
559
+ params_0 = self.network(inputs, training=training)
560
+ samples_0 = self._sample_channels(*params_0, seed=seed)
561
+
562
+ image_height, image_width, _ = tensorshape_util.as_list(self.event_shape)
563
+
564
+ def loop_body(index, samples):
565
+ """Loop for iterative pixel sampling.
566
+
567
+ Parameters
568
+ ----------
569
+ index
570
+ 0D `Tensor` of type `int32`. Index of the current pixel.
571
+ samples
572
+ 4D `Tensor`. Images with pixels sampled in raster order, up to
573
+ pixel `[index]`, with dimensions `[batch_size, height, width,
574
+ num_channels]`.
575
+
576
+ Returns
577
+ -------
578
+ samples
579
+ 4D `Tensor`. Images with pixels sampled in raster order, up to \
580
+ and including pixel `[index]`, with dimensions `[batch_size, height, \
581
+ width, num_channels]`.
582
+ """
583
+ inputs = samples if conditional_input is None else [samples, h]
584
+ params = self.network(inputs, training=training)
585
+ samples_new = self._sample_channels(*params, seed=seed)
586
+
587
+ # Update the current pixel
588
+ samples = tf.transpose(samples, [1, 2, 3, 0])
589
+ samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
590
+ row, col = index // image_width, index % image_width
591
+ updates = samples_new[row, col, ...][tf.newaxis, ...]
592
+ samples = tf.tensor_scatter_nd_update(samples, [[row, col]], updates)
593
+ samples = tf.transpose(samples, [3, 0, 1, 2])
594
+
595
+ return index + 1, samples
596
+
597
+ index0 = tf.zeros([], dtype=tf.int32)
598
+
599
+ # Construct the while loop for sampling
600
+ total_pixels = image_height * image_width
601
+ loop_cond = lambda ind, _: tf.less(ind, total_pixels) # noqa: E731
602
+ init_vars = (index0, samples_0)
603
+ _, samples = tf.while_loop(loop_cond, loop_body, init_vars, parallel_iterations=1)
604
+
605
+ transformed_samples = self._low + 0.5 * (self._high - self._low) * (samples + 1.0)
606
+ return tf.round(transformed_samples)
607
+
608
+ def _sample_channels(self, component_logits, locs, scales, coeffs=None, seed=None):
609
+ """Sample a single pixel-iteration and apply channel conditioning.
610
+
611
+ Parameters
612
+ ----------
613
+ component_logits
614
+ 4D `Tensor` of logits for the Categorical distribution
615
+ over Quantized Logistic mixture components. Dimensions are `[batch_size,
616
+ height, width, num_logistic_mix]`.
617
+ locs
618
+ 4D `Tensor` of location parameters for the Quantized Logistic
619
+ mixture components. Dimensions are `[batch_size, height, width,
620
+ num_logistic_mix, num_channels]`.
621
+ scales
622
+ 4D `Tensor` of scale parameters for the Quantized Logistic
623
+ mixture components. Dimensions are `[batch_size, height, width,
624
+ num_logistic_mix, num_channels]`.
625
+ coeffs
626
+ 4D `Tensor` of coefficients for the linear dependence among color
627
+ channels, or `None` if there is only one channel. Dimensions are
628
+ `[batch_size, height, width, num_logistic_mix, num_coeffs]`, where
629
+ `num_coeffs = num_channels * (num_channels - 1) // 2`.
630
+ seed
631
+ `int`, random seed.
632
+
633
+ Returns
634
+ -------
635
+ samples
636
+ 4D `Tensor` of sampled image data with autoregression among \
637
+ channels. Dimensions are `[batch_size, height, width, num_channels]`.
638
+ """
639
+ num_channels = self.event_shape[-1]
640
+
641
+ # sample mixture components once for the entire pixel
642
+ component_dist = categorical.Categorical(logits=component_logits)
643
+ mask = tf.one_hot(indices=component_dist.sample(seed=seed), depth=self._num_logistic_mix)
644
+ mask = tf.cast(mask[..., tf.newaxis], self.dtype)
645
+
646
+ # apply mixture component mask and separate out RGB parameters
647
+ masked_locs = tf.reduce_sum(locs * mask, axis=-2)
648
+ loc_tensors = tf.split(masked_locs, num_channels, axis=-1)
649
+ masked_scales = tf.reduce_sum(scales * mask, axis=-2)
650
+ scale_tensors = tf.split(masked_scales, num_channels, axis=-1)
651
+
652
+ if coeffs is not None:
653
+ num_coeffs = num_channels * (num_channels - 1) // 2
654
+ masked_coeffs = tf.reduce_sum(coeffs * mask, axis=-2)
655
+ coef_tensors = tf.split(masked_coeffs, num_coeffs, axis=-1)
656
+
657
+ channel_samples = []
658
+ coef_count = 0
659
+ for i in range(num_channels):
660
+ loc = loc_tensors[i]
661
+ for c in channel_samples:
662
+ loc += c * coef_tensors[coef_count]
663
+ coef_count += 1
664
+
665
+ logistic_samp = logistic.Logistic(loc=loc, scale=scale_tensors[i]).sample(seed=seed)
666
+ logistic_samp = tf.clip_by_value(logistic_samp, -1.0, 1.0)
667
+ channel_samples.append(logistic_samp)
668
+
669
+ return tf.concat(channel_samples, axis=-1)
670
+
671
+ def _batch_shape(self):
672
+ return tf.TensorShape([])
673
+
674
+ def _event_shape(self):
675
+ return tf.TensorShape(self.image_shape)
676
+
677
+
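A minimal end-to-end sketch of the distribution above, assuming 28x28 grayscale inputs in [0, 255]; the constructor arguments are illustrative (smaller than the defaults) and not taken from the package:

import tensorflow as tf

dist = PixelCNN(
    image_shape=(28, 28, 1),  # [height, width, channels] must be fully defined
    num_resnet=1,
    num_hierarchies=1,
    num_filters=32,
    num_logistic_mix=5,
    dropout_p=0.3,
)

images = tf.cast(tf.random.uniform((4, 28, 28, 1), maxval=256, dtype=tf.int32), tf.float32)
log_probs = dist.log_prob(images)  # shape (4,): one log-likelihood per image
samples = dist.sample(2)           # raster-scan sampling; shape (2, 28, 28, 1)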
678
+ class _PixelCNNNetwork(keras.layers.Layer):
679
+ """Keras `Layer` to parameterize a Pixel CNN++ distribution.
680
+ This is a Keras implementation of the Pixel CNN++ network, as described in
681
+ Salimans et al. (2017)[1] and van den Oord et al. (2016)[2].
682
+ (reference implementation: https://github.com/openai/pixel-cnn).
683
+ #### References
684
+ [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
685
+ PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture
686
+ Likelihood and Other Modifications. In _International Conference on
687
+ Learning Representations_, 2017.
688
+ https://pdfs.semanticscholar.org/9e90/6792f67cbdda7b7777b69284a81044857656.pdf
689
+ Additional details at https://github.com/openai/pixel-cnn
690
+ [2]: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt,
691
+ Alex Graves, and Koray Kavukcuoglu. Conditional Image Generation with
692
+ PixelCNN Decoders. In _30th Conference on Neural Information Processing
693
+ Systems_, 2016.
694
+ https://papers.nips.cc/paper/6527-conditional-image-generation-with-pixelcnn-decoders.pdf.
695
+ """
696
+
697
+ def __init__(
698
+ self,
699
+ dropout_p: float = 0.5,
700
+ num_resnet: int = 5,
701
+ num_hierarchies: int = 3,
702
+ num_filters: int = 160,
703
+ num_logistic_mix: int = 10,
704
+ receptive_field_dims: tuple = (3, 3),
705
+ resnet_activation: str = "concat_elu",
706
+ l2_weight: float = 0.0,
707
+ use_weight_norm: bool = True,
708
+ use_data_init: bool = True,
709
+ dtype=tf.float32,
710
+ ) -> None:
711
+ """Initialize the neural network for the Pixel CNN++ distribution.
712
+
713
+ Parameters
714
+ ----------
715
+ dropout_p
716
+ `float`, the dropout probability. Should be between 0 and 1.
717
+ num_resnet
718
+ `int`, the number of layers (shown in Figure 2 of [2]) within
719
+ each highest-level block of Figure 2 of [1].
720
+ num_hierarchies
721
+ `int`, the number of highest-level blocks (separated by
722
+ expansions/contractions of dimensions in Figure 2 of [1]).
723
+ num_filters
724
+ `int`, the number of convolutional filters.
725
+ num_logistic_mix
726
+ `int`, number of components in the logistic mixture
727
+ distribution.
728
+ receptive_field_dims
729
+ `tuple`, height and width in pixels of the receptive
730
+ field of the convolutional layers above and to the left of a given
731
+ pixel. The width (second element of the tuple) should be odd. Figure 1
732
+ (middle) of [2] shows a receptive field of (3, 5) (the row containing
733
+ the current pixel is included in the height). The default of (3, 3) was
734
+ used to produce the results in [1].
735
+ resnet_activation
736
+ `string`, the type of activation to use in the resnet
737
+ blocks. May be 'concat_elu', 'elu', or 'relu'.
738
+ l2_weight
739
+ `float`, the L2 regularization weight.
740
+ use_weight_norm
741
+ `bool`, if `True` then use weight normalization.
742
+ use_data_init
743
+ `bool`, if `True` then use data-dependent initialization
744
+ (has no effect if `use_weight_norm` is `False`).
745
+ dtype
746
+ Data type of the layer.
747
+ """
748
+ super().__init__(dtype=dtype)
749
+ self._dropout_p = dropout_p
750
+ self._num_resnet = num_resnet
751
+ self._num_hierarchies = num_hierarchies
752
+ self._num_filters = num_filters
753
+ self._num_logistic_mix = num_logistic_mix
754
+ self._receptive_field_dims = receptive_field_dims # first set desired receptive field, then infer kernel
755
+ self._resnet_activation = resnet_activation
756
+ self._l2_weight = l2_weight
757
+
758
+ if use_weight_norm:
759
+
760
+ def layer_wrapper(layer):
761
+ def wrapped_layer(*args, **kwargs):
762
+ return WeightNorm(layer(*args, **kwargs), data_init=use_data_init)
763
+
764
+ return wrapped_layer
765
+
766
+ self._layer_wrapper = layer_wrapper
767
+ else:
768
+ self._layer_wrapper = lambda layer: layer
769
+
770
+ def build(self, input_shape):
771
+ dtype = self.dtype
772
+ if len(input_shape) == 2:
773
+ batch_image_shape, batch_conditional_shape = input_shape
774
+ conditional_input = keras.layers.Input(shape=batch_conditional_shape[1:], dtype=dtype)
775
+ else:
776
+ batch_image_shape = input_shape
777
+ conditional_input = None
778
+
779
+ image_shape = batch_image_shape[1:]
780
+ image_input = keras.layers.Input(shape=image_shape, dtype=dtype)
781
+
782
+ if self._resnet_activation == "concat_elu":
783
+ activation = keras.layers.Lambda(lambda x: tf.nn.elu(tf.concat([x, -x], axis=-1)), dtype=dtype)
784
+ else:
785
+ activation = keras.activations.get(self._resnet_activation)
786
+
787
+ # Define layers with default inputs and layer wrapper applied
788
+ Conv2D = functools.partial( # pylint:disable=invalid-name
789
+ self._layer_wrapper(keras.layers.Convolution2D),
790
+ filters=self._num_filters,
791
+ padding="same",
792
+ kernel_regularizer=keras.regularizers.l2(self._l2_weight),
793
+ dtype=dtype,
794
+ )
795
+
796
+ Dense = functools.partial( # pylint:disable=invalid-name
797
+ self._layer_wrapper(keras.layers.Dense),
798
+ kernel_regularizer=keras.regularizers.l2(self._l2_weight),
799
+ dtype=dtype,
800
+ )
801
+
802
+ Conv2DTranspose = functools.partial( # pylint:disable=invalid-name
803
+ self._layer_wrapper(keras.layers.Conv2DTranspose),
804
+ filters=self._num_filters,
805
+ padding="same",
806
+ strides=(2, 2),
807
+ kernel_regularizer=keras.regularizers.l2(self._l2_weight),
808
+ dtype=dtype,
809
+ )
810
+
811
+ rows, cols = self._receptive_field_dims
812
+
813
+ # Define the dimensions of the valid (unmasked) areas of the layer kernels
814
+ # for stride 1 convolutions in the internal layers.
815
+ kernel_valid_dims = {
816
+ "vertical": (rows - 1, cols), # vertical stack
817
+ "horizontal": (2, cols // 2 + 1),
818
+ } # horizontal stack
819
+
820
+ # Define the size of the kernel necessary to center the current pixel
821
+ # correctly for stride 1 convolutions in the internal layers.
822
+ kernel_sizes = {"vertical": (2 * rows - 3, cols), "horizontal": (3, cols)}
823
+
824
+ # Make the kernel constraint functions for stride 1 convolutions in internal
825
+ # layers.
826
+ kernel_constraints = {
827
+ k: _make_kernel_constraint(kernel_sizes[k], (0, v[0]), (0, v[1])) for k, v in kernel_valid_dims.items()
828
+ }
829
+
830
+ # Build the initial vertical stack/horizontal stack convolutional layers,
831
+ # as shown in Figure 1 of [2]. The receptive field of the initial vertical
832
+ # stack layer is a rectangular area centered above the current pixel.
833
+ vertical_stack_init = Conv2D(
834
+ kernel_size=(2 * rows - 1, cols),
835
+ kernel_constraint=_make_kernel_constraint((2 * rows - 1, cols), (0, rows - 1), (0, cols)),
836
+ )(image_input)
837
+
838
+ # In Figure 1 [2], the receptive field of the horizontal stack is
839
+ # illustrated as the pixels in the same row and to the left of the current
840
+ # pixel. [1] increases the height of this receptive field from one pixel to
841
+ # two (`horizontal_stack_left`) and additionally includes a subset of the
842
+ # row of pixels centered above the current pixel (`horizontal_stack_up`).
843
+ horizontal_stack_up = Conv2D(
844
+ kernel_size=(3, cols),
845
+ kernel_constraint=_make_kernel_constraint((3, cols), (0, 1), (0, cols)),
846
+ )(image_input)
847
+
848
+ horizontal_stack_left = Conv2D(
849
+ kernel_size=(3, cols),
850
+ kernel_constraint=_make_kernel_constraint((3, cols), (0, 2), (0, cols // 2)),
851
+ )(image_input)
852
+
853
+ horizontal_stack_init = keras.layers.add([horizontal_stack_up, horizontal_stack_left], dtype=dtype)
854
+
855
+ layer_stacks = {
856
+ "vertical": [vertical_stack_init],
857
+ "horizontal": [horizontal_stack_init],
858
+ }
859
+
860
+ # Build the downward pass of the U-net (left-hand half of Figure 2 of [1]).
861
+ # Each `i` iteration builds one of the highest-level blocks (identified as
862
+ # 'Sequence of 6 layers' in the figure, consisting of `num_resnet=5` stride-
863
+ # 1 layers, and one stride-2 layer that contracts the height/width
864
+ # dimensions). The `_` iterations build the stride 1 layers. The layers of
865
+ # the downward pass are stored in lists, since we'll later need them to make
866
+ # skip-connections to layers in the upward pass of the U-net (the skip-
867
+ # connections are represented by curved lines in Figure 2 [1]).
868
+ for i in range(self._num_hierarchies):
869
+ for _ in range(self._num_resnet):
870
+ # Build a layer shown in Figure 2 of [2]. The 'vertical' iteration
871
+ # builds the layers in the left half of the figure, and the 'horizontal'
872
+ # iteration builds the layers in the right half.
873
+ for stack in ["vertical", "horizontal"]:
874
+ input_x = layer_stacks[stack][-1]
875
+ x = activation(input_x)
876
+ x = Conv2D(
877
+ kernel_size=kernel_sizes[stack],
878
+ kernel_constraint=kernel_constraints[stack],
879
+ )(x)
880
+
881
+ # Add the vertical-stack layer to the horizontal-stack layer
882
+ if stack == "horizontal":
883
+ h = activation(layer_stacks["vertical"][-1])
884
+ h = Dense(self._num_filters)(h)
885
+ x = keras.layers.add([h, x], dtype=dtype)
886
+
887
+ x = activation(x)
888
+ x = keras.layers.Dropout(self._dropout_p, dtype=dtype)(x)
889
+ x = Conv2D(
890
+ filters=2 * self._num_filters,
891
+ kernel_size=kernel_sizes[stack],
892
+ kernel_constraint=kernel_constraints[stack],
893
+ )(x)
894
+
895
+ if conditional_input is not None:
896
+ h_projection = _build_and_apply_h_projection(conditional_input, self._num_filters, dtype=dtype)
897
+ x = keras.layers.add([x, h_projection], dtype=dtype)
898
+
899
+ x = _apply_sigmoid_gating(x)
900
+
901
+ # Add a residual connection from the layer's input.
902
+ out = keras.layers.add([input_x, x], dtype=dtype)
903
+ layer_stacks[stack].append(out)
904
+
905
+ if i < self._num_hierarchies - 1:
906
+ # Build convolutional layers that contract the height/width dimensions
907
+ # on the downward pass between each set of layers (e.g. contracting from
908
+ # 32x32 to 16x16 in Figure 2 of [1]).
909
+ for stack in ["vertical", "horizontal"]:
910
+ # Define kernel dimensions/masking to maintain the autoregressive property.
911
+ x = layer_stacks[stack][-1]
912
+ h, w = kernel_valid_dims[stack]
913
+ kernel_height = 2 * h
914
+ kernel_width = w + 1 if stack == "vertical" else 2 * w
915
+ kernel_size = (kernel_height, kernel_width)
916
+ kernel_constraint = _make_kernel_constraint(kernel_size, (0, h), (0, w))
917
+ x = Conv2D(
918
+ strides=(2, 2),
919
+ kernel_size=kernel_size,
920
+ kernel_constraint=kernel_constraint,
921
+ )(x)
922
+ layer_stacks[stack].append(x)
923
+
924
+ # Upward pass of the U-net (right-hand half of Figure 2 of [1]). We stored
925
+ # the layers of the downward pass in a list, in order to access them to make
926
+ # skip-connections to the upward pass. For the upward pass, we need to keep
927
+ # track of only the current layer, so we maintain a reference to the
928
+ # current layer of the horizontal/vertical stack in the `upward_pass` dict.
929
+ # The upward pass begins with the last layer of the downward pass.
930
+ upward_pass = {key: stack.pop() for key, stack in layer_stacks.items()}
931
+
932
+ # As with the downward pass, each `i` iteration builds a highest level block
933
+ # in Figure 2 [1], and the `_` iterations build individual layers within the
934
+ # block.
935
+ for i in range(self._num_hierarchies):
936
+ num_resnet = self._num_resnet if i == 0 else self._num_resnet + 1
937
+
938
+ for _ in range(num_resnet):
939
+ # Build a layer as shown in Figure 2 of [2], with a skip-connection
940
+ # from the symmetric layer in the downward pass.
941
+ for stack in ["vertical", "horizontal"]:
942
+ input_x = upward_pass[stack]
943
+ x_symmetric = layer_stacks[stack].pop()
944
+
945
+ x = activation(input_x)
946
+ x = Conv2D(
947
+ kernel_size=kernel_sizes[stack],
948
+ kernel_constraint=kernel_constraints[stack],
949
+ )(x)
950
+
951
+ # Include the vertical-stack layer of the upward pass in the layers
952
+ # to be added to the horizontal layer.
953
+ if stack == "horizontal":
954
+ x_symmetric = keras.layers.Concatenate(axis=-1, dtype=dtype)(
955
+ [upward_pass["vertical"], x_symmetric]
956
+ )
957
+
958
+ # Add a skip-connection from the symmetric layer in the downward
959
+ # pass to the layer `x` in the upward pass.
960
+ h = activation(x_symmetric)
961
+ h = Dense(self._num_filters)(h)
962
+ x = keras.layers.add([h, x], dtype=dtype)
963
+
964
+ x = activation(x)
965
+ x = keras.layers.Dropout(self._dropout_p, dtype=dtype)(x)
966
+ x = Conv2D(
967
+ filters=2 * self._num_filters,
968
+ kernel_size=kernel_sizes[stack],
969
+ kernel_constraint=kernel_constraints[stack],
970
+ )(x)
971
+
972
+ if conditional_input is not None:
973
+ h_projection = _build_and_apply_h_projection(conditional_input, self._num_filters, dtype=dtype)
974
+ x = keras.layers.add([x, h_projection], dtype=dtype)
975
+
976
+ x = _apply_sigmoid_gating(x)
977
+ upward_pass[stack] = keras.layers.add([input_x, x], dtype=dtype)
978
+
979
+ # Define deconvolutional layers that expand height/width dimensions on the
980
+ # upward pass (e.g. expanding from 8x8 to 16x16 in Figure 2 of [1]), with
981
+ # the correct kernel dimensions/masking to maintain the autoregressive
982
+ # property.
983
+ if i < self._num_hierarchies - 1:
984
+ for stack in ["vertical", "horizontal"]:
985
+ h, w = kernel_valid_dims[stack]
986
+ kernel_height = 2 * h - 2
987
+ if stack == "vertical":
988
+ kernel_width = w + 1
989
+ kernel_constraint = _make_kernel_constraint(
990
+ (kernel_height, kernel_width),
991
+ (h - 2, kernel_height),
992
+ (0, w),
993
+ )
994
+ else:
995
+ kernel_width = 2 * w - 2
996
+ kernel_constraint = _make_kernel_constraint(
997
+ (kernel_height, kernel_width),
998
+ (h - 2, kernel_height),
999
+ (w - 2, kernel_width),
1000
+ )
1001
+
1002
+ x = upward_pass[stack]
1003
+ x = Conv2DTranspose(
1004
+ kernel_size=(kernel_height, kernel_width),
1005
+ kernel_constraint=kernel_constraint,
1006
+ )(x)
1007
+ upward_pass[stack] = x
1008
+
1009
+ x_out = keras.layers.ELU(dtype=dtype)(upward_pass["horizontal"])
1010
+
1011
+ # Build final Dense/Reshape layers to output the correct number of
1012
+ # parameters per pixel.
1013
+ num_channels = tensorshape_util.as_list(image_shape)[-1]
1014
+ num_coeffs = num_channels * (num_channels - 1) // 2 # alpha, beta, gamma in eq.3 of paper
1015
+ num_out = num_channels * 2 + num_coeffs + 1 # mu, s + alpha, beta, gamma + 1 (mixture weight)
1016
+ num_out_total = num_out * self._num_logistic_mix
1017
+ params = Dense(num_out_total)(x_out)
1018
+ params = tf.reshape(
1019
+ params,
1020
+ prefer_static.concat( # [-1,H,W,nb mixtures, params per mixture]
1021
+ [[-1], image_shape[:-1], [self._num_logistic_mix, num_out]], axis=0
1022
+ ),
1023
+ )
1024
+
1025
+ # If there is one color channel, split the parameters into a list of three
1026
+ # output `Tensor`s: (1) component logits for the Quantized Logistic mixture
1027
+ # distribution, (2) location parameters for each component, and (3) scale
1028
+ # parameters for each component. If there is more than one color channel,
1029
+ # return a fourth `Tensor` for the coefficients for the linear dependence
1030
+ # among color channels (e.g. alpha, beta, gamma).
1031
+ # [logits, mu, s, linear dependence]
1032
+ splits = 3 if num_channels == 1 else [1, num_channels, num_channels, num_coeffs]
1033
+ outputs = tf.split(params, splits, axis=-1)
1034
+
1035
+ # Squeeze singleton dimension from component logits
1036
+ outputs[0] = tf.squeeze(outputs[0], axis=-1)
1037
+
1038
+ # Ensure scales are positive and do not collapse to near-zero
1039
+ outputs[2] = tf.nn.softplus(outputs[2]) + tf.cast(tf.exp(-7.0), self.dtype)
1040
+
1041
+ inputs = image_input if conditional_input is None else [image_input, conditional_input]
1042
+ self._network = keras.Model(inputs=inputs, outputs=outputs)
1043
+ super().build(input_shape)
1044
+
1045
+ def call(self, inputs, training=None):
1046
+ """Call the Pixel CNN network model.
1047
+
1048
+ Parameters
1049
+ ----------
1050
+ inputs
1051
+ 4D `Tensor` of image data with dimensions [batch size, height,
1052
+ width, channels] or a 2-element `list`. If `list`, the first element is
1053
+ the 4D image `Tensor` and the second element is a `Tensor` with
1054
+ conditional input data (e.g. VAE encodings or class labels) with the
1055
+ same leading batch dimension as the image `Tensor`.
1056
+ training
1057
+ `bool` or `None`. If `bool`, it controls the dropout layer,
1058
+ where `True` implies dropout is active. If `None`, it defaults to
1059
+ `keras.backend.learning_phase()`
1060
+
1061
+ Returns
1062
+ -------
1063
+ outputs
1064
+ a 3- or 4-element `list` of `Tensor`s in the following order: \
1065
+ component_logits: 4D `Tensor` of logits for the Categorical distribution \
1066
+ over Quantized Logistic mixture components. Dimensions are \
1067
+ `[batch_size, height, width, num_logistic_mix]`.
1068
+ locs
1069
+ 4D `Tensor` of location parameters for the Quantized Logistic \
1070
+ mixture components. Dimensions are `[batch_size, height, width, \
1071
+ num_logistic_mix, num_channels]`.
1072
+ scales
1073
+ 4D `Tensor` of scale parameters for the Quantized Logistic \
1074
+ mixture components. Dimensions are `[batch_size, height, width, \
1075
+ num_logistic_mix, num_channels]`.
1076
+ coeffs
1077
+ 4D `Tensor` of coefficients for the linear dependence among \
1078
+ color channels, included only if the image has more than one channel. \
1079
+ Dimensions are `[batch_size, height, width, num_logistic_mix, \
1080
+ num_coeffs]`, where `num_coeffs = num_channels * (num_channels - 1) // 2`.
1081
+ """
1082
+ return self._network(inputs, training=training)
1083
+
1084
+
1085
+ def _make_kernel_constraint(kernel_size, valid_rows, valid_columns):
1086
+ """Make the masking function for layer kernels."""
1087
+ mask = np.zeros(kernel_size)
1088
+ lower, upper = valid_rows
1089
+ left, right = valid_columns
1090
+ mask[lower:upper, left:right] = 1.0
1091
+ mask = mask[:, :, np.newaxis, np.newaxis]
1092
+ return lambda x: x * mask
1093
+
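To make the masking concrete, a small sketch with a (3, 3) kernel and the vertical stack's valid dims (2, 3): only the rows above the current pixel remain unmasked (assumes the module's numpy import):

constraint = _make_kernel_constraint((3, 3), valid_rows=(0, 2), valid_columns=(0, 3))
kernel = np.ones((3, 3, 1, 1), dtype="float32")
print(constraint(kernel)[..., 0, 0])
# [[1. 1. 1.]
#  [1. 1. 1.]
#  [0. 0. 0.]]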
1094
+
1095
+ def _build_and_apply_h_projection(h, num_filters, dtype):
1096
+ """Project the conditional input."""
1097
+ h = keras.layers.Flatten(dtype=dtype)(h)
1098
+ h_projection = keras.layers.Dense(2 * num_filters, kernel_initializer="random_normal", dtype=dtype)(h)
1099
+ return h_projection[..., tf.newaxis, tf.newaxis, :]
1100
+
1101
+
1102
+ def _apply_sigmoid_gating(x):
1103
+ """Apply the sigmoid gating in Figure 2 of [2]."""
1104
+ activation_tensor, gate_tensor = tf.split(x, 2, axis=-1)
1105
+ sigmoid_gate = tf.sigmoid(gate_tensor)
1106
+ return keras.layers.multiply([sigmoid_gate, activation_tensor], dtype=x.dtype)
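And an illustrative shape check for the gating helper: it splits the channel dimension in half and uses one half to gate the other:

x = tf.random.normal((1, 4, 4, 8))  # 2 * num_filters channels in
gated = _apply_sigmoid_gating(x)    # shape (1, 4, 4, 4): half the channels out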