keras-hub-nightly 0.16.1.dev202410030339__py3-none-any.whl → 0.16.1.dev202410050339__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +9 -0
  3. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  4. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
  5. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  6. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  7. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  8. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
  9. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
  10. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  11. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  12. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
  13. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  14. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
  15. keras_hub/src/models/task.py +20 -15
  16. keras_hub/src/models/vae/__init__.py +1 -0
  17. keras_hub/src/models/vae/vae_backbone.py +172 -0
  18. keras_hub/src/models/vae/vae_layers.py +740 -0
  19. keras_hub/src/version_utils.py +1 -1
  20. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/METADATA +1 -1
  21. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/RECORD +23 -14
  22. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  23. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/WHEEL +0 -0
  24. {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,740 @@
1
+ import math
2
+
3
+ import keras
4
+ from keras import ops
5
+
6
+ from keras_hub.src.utils.keras_utils import standardize_data_format
7
+
8
+
9
class Conv2DMultiHeadAttention(keras.layers.Layer):
    """A MultiHeadAttention layer utilizing `Conv2D` and `GroupNormalization`.

    The inputs are group-normalized, projected to query, key and value with
    1x1 convolutions, attended over all flattened spatial positions using the
    full `filters` dimension (no split into heads), projected with a final
    1x1 convolution and added back to the inputs as a residual connection.

    Because of the residual connection, `filters` must match the number of
    channels of the inputs.

    Args:
        filters: int. The number of the filters for the convolutional layers.
        groups: int. The number of the groups for the group normalization
            layers. Defaults to `32`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, filters, groups=32, data_format=None, **kwargs):
        super().__init__(**kwargs)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.filters = int(filters)
        self.groups = int(groups)
        # Scaling factor applied to the query before the dot product.
        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
        self.data_format = data_format

        self.group_norm = keras.layers.GroupNormalization(
            groups=groups,
            axis=channel_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="group_norm",
        )
        self.query_conv2d = keras.layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="query_conv2d",
        )
        self.key_conv2d = keras.layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="key_conv2d",
        )
        self.value_conv2d = keras.layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="value_conv2d",
        )
        # The softmax is always computed in float32.
        self.softmax = keras.layers.Softmax(dtype="float32")
        self.output_conv2d = keras.layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="output_conv2d",
        )

    def build(self, input_shape):
        # Every sub-layer is built against the input shape: the 1x1
        # projections keep the spatial size and the channel count is expected
        # to equal `filters`.
        self.group_norm.build(input_shape)
        self.query_conv2d.build(input_shape)
        self.key_conv2d.build(input_shape)
        self.value_conv2d.build(input_shape)
        self.output_conv2d.build(input_shape)

    def call(self, inputs, training=None):
        x = self.group_norm(inputs, training=training)
        query = self.query_conv2d(x, training=training)
        key = self.key_conv2d(x, training=training)
        value = self.value_conv2d(x, training=training)

        # The attention itself is computed on a channels-last layout.
        if self.data_format == "channels_first":
            query = ops.transpose(query, (0, 2, 3, 1))
            key = ops.transpose(key, (0, 2, 3, 1))
            value = ops.transpose(value, (0, 2, 3, 1))
        # Channels-last shape `(B, H, W, C)` used to restore the spatial
        # dimensions after the attention.
        shape = ops.shape(query)
        b = shape[0]
        query = ops.reshape(query, (b, -1, self.filters))
        key = ops.reshape(key, (b, -1, self.filters))
        value = ops.reshape(value, (b, -1, self.filters))

        # Compute attention.
        query = ops.multiply(
            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
        )
        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
        attention_scores = ops.einsum("abc,adc->abd", query, key)
        attention_scores = ops.cast(
            self.softmax(attention_scores), self.compute_dtype
        )
        # [B, H1 * W1, C], [B, H0 * W0, H1 * W1] -> [B, H0 * W0, C]
        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
        x = ops.reshape(attention_output, shape)

        # Restore the layout expected by `output_conv2d` *before* applying
        # it, so the convolution and the residual add see `(B, C, H, W)`.
        if self.data_format == "channels_first":
            x = ops.transpose(x, (0, 3, 1, 2))
        x = self.output_conv2d(x, training=training)
        x = ops.add(x, inputs)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "groups": self.groups,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        # The residual connection keeps the input shape unchanged.
        return input_shape
133
+
134
+
135
class ResNetBlock(keras.layers.Layer):
    """A ResNet block utilizing `GroupNormalization` and SiLU activation.

    The block applies two `GroupNormalization -> SiLU -> 3x3 Conv2D` stages
    and adds the result to the (optionally 1x1-projected) inputs.

    Args:
        filters: The number of filters in the block.
        has_residual_projection: Whether to add a projection layer for the
            residual connection. Defaults to `False`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        filters,
        has_residual_projection=False,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.filters = int(filters)
        self.has_residual_projection = bool(has_residual_projection)

        # Arguments shared by both normalization / convolution layers.
        norm_kwargs = {
            "groups": 32,
            "axis": channel_axis,
            "epsilon": 1e-6,
            "dtype": self.dtype_policy,
        }
        conv_kwargs = {
            "kernel_size": 3,
            "strides": 1,
            "padding": "same",
            "data_format": data_format,
            "dtype": self.dtype_policy,
        }

        # === Layers ===
        # First stage.
        self.norm1 = keras.layers.GroupNormalization(
            name="norm1", **norm_kwargs
        )
        self.act1 = keras.layers.Activation("silu", dtype=self.dtype_policy)
        self.conv1 = keras.layers.Conv2D(filters, name="conv1", **conv_kwargs)
        # Second stage.
        self.norm2 = keras.layers.GroupNormalization(
            name="norm2", **norm_kwargs
        )
        self.act2 = keras.layers.Activation("silu", dtype=self.dtype_policy)
        self.conv2 = keras.layers.Conv2D(filters, name="conv2", **conv_kwargs)
        # Optional 1x1 projection of the shortcut branch.
        if self.has_residual_projection:
            self.residual_projection = keras.layers.Conv2D(
                filters,
                1,
                1,
                data_format=data_format,
                dtype=self.dtype_policy,
                name="residual_projection",
            )
        self.add = keras.layers.Add(dtype=self.dtype_policy)

    def _stages(self):
        """Return the `(norm, activation, conv)` triples in call order."""
        return (
            (self.norm1, self.act1, self.conv1),
            (self.norm2, self.act2, self.conv2),
        )

    def build(self, input_shape):
        shortcut_shape = list(input_shape)
        shape = input_shape
        for norm, act, conv in self._stages():
            norm.build(shape)
            act.build(shape)
            conv.build(shape)
            shape = conv.compute_output_shape(shape)
        if self.has_residual_projection:
            self.residual_projection.build(shortcut_shape)
        self.add.build([shape, shape])

    def call(self, inputs, training=None):
        shortcut = inputs
        hidden = inputs
        for norm, act, conv in self._stages():
            hidden = norm(hidden, training=training)
            hidden = act(hidden, training=training)
            hidden = conv(hidden, training=training)
        if self.has_residual_projection:
            shortcut = self.residual_projection(shortcut, training=training)
        return self.add([shortcut, hidden])

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "has_residual_projection": self.has_residual_projection,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        shape = list(input_shape)
        if not self.has_residual_projection:
            # Without a projection the block keeps the input shape.
            return shape
        return self.residual_projection.compute_output_shape(shape)
260
+
261
+
262
class VAEEncoder(keras.layers.Layer):
    """The encoder layer of VAE.

    The encoder applies an input 3x3 convolution, then one stack of
    `ResNetBlock`s per entry of `stackwise_num_filters` (every stack except
    the last is followed by a zero-padding + stride-2 convolution
    downsample), a `ResNetBlock -> Conv2DMultiHeadAttention -> ResNetBlock`
    mid block and finally a `GroupNormalization -> swish -> Conv2D` output
    head.

    Args:
        stackwise_num_filters: list of ints. The number of filters for each
            stack.
        stackwise_num_blocks: list of ints. The number of blocks for each stack.
        output_channels: int. The number of channels in the output. Defaults to
            `32`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=32,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = int(output_channels)
        self.data_format = data_format

        # === Layers ===
        self.input_projection = keras.layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="input_projection",
        )

        # Blocks.
        input_filters = stackwise_num_filters[0]
        self.blocks = []
        # Flat list of (pad, conv) pairs, one pair per downsampling stage.
        self.downsamples = []
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                self.blocks.append(
                    ResNetBlock(
                        filters,
                        # Project the shortcut only when the width changes.
                        has_residual_projection=input_filters != filters,
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"block_{i}_{j}",
                    )
                )
                input_filters = filters
            # No downsample in the last block.
            if i != len(stackwise_num_filters) - 1:
                self.downsamples.append(
                    keras.layers.ZeroPadding2D(
                        padding=((0, 1), (0, 1)),
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"downsample_{i}_pad",
                    )
                )
                self.downsamples.append(
                    keras.layers.Conv2D(
                        filters,
                        3,
                        2,
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"downsample_{i}_conv",
                    )
                )

        # Mid block.
        self.mid_block_0 = ResNetBlock(
            stackwise_num_filters[-1],
            has_residual_projection=False,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_block_0",
        )
        self.mid_attention = Conv2DMultiHeadAttention(
            stackwise_num_filters[-1],
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_attention",
        )
        self.mid_block_1 = ResNetBlock(
            stackwise_num_filters[-1],
            has_residual_projection=False,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_block_1",
        )

        # Output layers.
        self.output_norm = keras.layers.GroupNormalization(
            groups=32,
            axis=channel_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="output_norm",
        )
        self.output_act = keras.layers.Activation(
            "swish", dtype=self.dtype_policy
        )
        self.output_projection = keras.layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="output_projection",
        )

    def build(self, input_shape):
        # Build every sub-layer in call order, threading the shape through.
        self.input_projection.build(input_shape)
        input_shape = self.input_projection.compute_output_shape(input_shape)
        blocks_idx = 0
        downsamples_idx = 0
        num_stacks = len(self.stackwise_num_filters)
        for i in range(num_stacks):
            for _ in range(self.stackwise_num_blocks[i]):
                block = self.blocks[blocks_idx]
                block.build(input_shape)
                input_shape = block.compute_output_shape(input_shape)
                blocks_idx += 1
            if i != num_stacks - 1:
                # Each downsample stage is a (pad, conv) pair.
                stage = self.downsamples[downsamples_idx : downsamples_idx + 2]
                for layer in stage:
                    layer.build(input_shape)
                    input_shape = layer.compute_output_shape(input_shape)
                downsamples_idx += 2
        self.mid_block_0.build(input_shape)
        input_shape = self.mid_block_0.compute_output_shape(input_shape)
        self.mid_attention.build(input_shape)
        input_shape = self.mid_attention.compute_output_shape(input_shape)
        self.mid_block_1.build(input_shape)
        input_shape = self.mid_block_1.compute_output_shape(input_shape)
        self.output_norm.build(input_shape)
        self.output_act.build(input_shape)
        self.output_projection.build(input_shape)

    def call(self, inputs, training=None):
        x = inputs
        x = self.input_projection(x, training=training)
        blocks_idx = 0
        downsamples_idx = 0
        num_stacks = len(self.stackwise_num_filters)
        for i in range(num_stacks):
            for _ in range(self.stackwise_num_blocks[i]):
                x = self.blocks[blocks_idx](x, training=training)
                blocks_idx += 1
            if i != num_stacks - 1:
                # Zero-pad, then apply the stride-2 convolution.
                x = self.downsamples[downsamples_idx](x, training=training)
                x = self.downsamples[downsamples_idx + 1](
                    x, training=training
                )
                downsamples_idx += 2
        x = self.mid_block_0(x, training=training)
        x = self.mid_attention(x, training=training)
        x = self.mid_block_1(x, training=training)
        x = self.output_norm(x, training=training)
        x = self.output_act(x, training=training)
        x = self.output_projection(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        if self.data_format == "channels_last":
            h_axis, w_axis, c_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 1, 2, 3
        # One 2x downsample after every stack except the last.
        scale_factor = 2 ** (len(self.stackwise_num_filters) - 1)
        outputs_shape = list(input_shape)
        # Scale each known spatial dim independently so that one unknown
        # (`None`) dim does not leave the other one unscaled.
        for axis in (h_axis, w_axis):
            if outputs_shape[axis] is not None:
                outputs_shape[axis] = outputs_shape[axis] // scale_factor
        outputs_shape[c_axis] = self.output_channels
        return outputs_shape
471
+
472
+
473
class VAEDecoder(keras.layers.Layer):
    """The decoder layer of VAE.

    The decoder applies an input 3x3 convolution, a
    `ResNetBlock -> Conv2DMultiHeadAttention -> ResNetBlock` mid block, then
    one stack of `ResNetBlock`s per entry of `stackwise_num_filters` (every
    stack except the last is followed by a 2x `UpSampling2D` + 3x3
    convolution) and finally a `GroupNormalization -> swish -> Conv2D`
    output head.

    Args:
        stackwise_num_filters: list of ints. The number of filters for each
            stack.
        stackwise_num_blocks: list of ints. The number of blocks for each stack.
        output_channels: int. The number of channels in the output. Defaults to
            `3`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=3,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = int(output_channels)
        self.data_format = data_format

        # === Layers ===
        self.input_projection = keras.layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="input_projection",
        )

        # Mid block.
        self.mid_block_0 = ResNetBlock(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_block_0",
        )
        self.mid_attention = Conv2DMultiHeadAttention(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_attention",
        )
        self.mid_block_1 = ResNetBlock(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_block_1",
        )

        # Blocks.
        input_filters = stackwise_num_filters[0]
        self.blocks = []
        # Flat list of (upsample, conv) pairs, one pair per upsampling stage.
        self.upsamples = []
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                self.blocks.append(
                    ResNetBlock(
                        filters,
                        # Project the shortcut only when the width changes.
                        has_residual_projection=input_filters != filters,
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"block_{i}_{j}",
                    )
                )
                input_filters = filters
            # No upsample in the last block.
            if i != len(stackwise_num_filters) - 1:
                self.upsamples.append(
                    keras.layers.UpSampling2D(
                        2,
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"upsample_{i}",
                    )
                )
                self.upsamples.append(
                    keras.layers.Conv2D(
                        filters,
                        3,
                        1,
                        padding="same",
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"upsample_{i}_conv",
                    )
                )

        # Output layers.
        self.output_norm = keras.layers.GroupNormalization(
            groups=32,
            axis=channel_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="output_norm",
        )
        self.output_act = keras.layers.Activation(
            "swish", dtype=self.dtype_policy
        )
        self.output_projection = keras.layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="output_projection",
        )

    def build(self, input_shape):
        # Build every sub-layer in call order, threading the shape through.
        self.input_projection.build(input_shape)
        input_shape = self.input_projection.compute_output_shape(input_shape)
        self.mid_block_0.build(input_shape)
        input_shape = self.mid_block_0.compute_output_shape(input_shape)
        self.mid_attention.build(input_shape)
        input_shape = self.mid_attention.compute_output_shape(input_shape)
        self.mid_block_1.build(input_shape)
        input_shape = self.mid_block_1.compute_output_shape(input_shape)
        blocks_idx = 0
        upsamples_idx = 0
        num_stacks = len(self.stackwise_num_filters)
        for i in range(num_stacks):
            for _ in range(self.stackwise_num_blocks[i]):
                block = self.blocks[blocks_idx]
                block.build(input_shape)
                input_shape = block.compute_output_shape(input_shape)
                blocks_idx += 1
            if i != num_stacks - 1:
                # Each upsample stage is an (upsample, conv) pair.
                for layer in self.upsamples[upsamples_idx : upsamples_idx + 2]:
                    layer.build(input_shape)
                    input_shape = layer.compute_output_shape(input_shape)
                upsamples_idx += 2
        self.output_norm.build(input_shape)
        self.output_act.build(input_shape)
        self.output_projection.build(input_shape)

    def call(self, inputs, training=None):
        x = inputs
        x = self.input_projection(x, training=training)
        x = self.mid_block_0(x, training=training)
        x = self.mid_attention(x, training=training)
        x = self.mid_block_1(x, training=training)
        blocks_idx = 0
        upsamples_idx = 0
        num_stacks = len(self.stackwise_num_filters)
        for i in range(num_stacks):
            for _ in range(self.stackwise_num_blocks[i]):
                x = self.blocks[blocks_idx](x, training=training)
                blocks_idx += 1
            if i != num_stacks - 1:
                # 2x upsampling followed by its 3x3 convolution.
                x = self.upsamples[upsamples_idx](x, training=training)
                x = self.upsamples[upsamples_idx + 1](x, training=training)
                upsamples_idx += 2
        x = self.output_norm(x, training=training)
        x = self.output_act(x, training=training)
        x = self.output_projection(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        if self.data_format == "channels_last":
            h_axis, w_axis, c_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 1, 2, 3
        # One 2x upsample after every stack except the last.
        scale_factor = 2 ** (len(self.stackwise_num_filters) - 1)
        outputs_shape = list(input_shape)
        # Scale each known spatial dim independently so that one unknown
        # (`None`) dim does not leave the other one unscaled.
        for axis in (h_axis, w_axis):
            if outputs_shape[axis] is not None:
                outputs_shape[axis] = outputs_shape[axis] * scale_factor
        outputs_shape[c_axis] = self.output_channels
        return outputs_shape
680
+
681
+
682
class DiagonalGaussianDistributionSampler(keras.layers.Layer):
    """A sampler for a diagonal Gaussian distribution.

    This layer samples latent variables from a diagonal Gaussian distribution.
    The inputs are split in two equal halves along `axis`: the first half is
    the mean and the second half is the log variance.

    Args:
        method: str. The method used to sample from the distribution. Available
            methods are `"sample"` and `"mode"`. `"sample"` draws from the
            distribution using both the mean and log variance. `"mode"` draws
            from the distribution using the mean only.
        axis: int. The axis along which to split the mean and log variance.
            Defaults to `-1`.
        seed: optional int. Used as a random seed.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.

    Raises:
        ValueError: if `method` is neither `"sample"` nor `"mode"`.
    """

    def __init__(self, method, axis=-1, seed=None, **kwargs):
        super().__init__(**kwargs)
        # TODO: Support `kl` and `nll` modes.
        valid_methods = ("sample", "mode")
        if method not in valid_methods:
            raise ValueError(
                f"Invalid method {method}. Valid methods are "
                f"{list(valid_methods)}."
            )
        self.method = method
        self.axis = axis
        self.seed = seed
        self.seed_generator = keras.random.SeedGenerator(seed)

    def call(self, inputs):
        x = inputs
        if self.method == "sample":
            x_mean, x_logvar = ops.split(x, 2, axis=self.axis)
            # Clip the log variance before exponentiating it.
            x_logvar = ops.clip(x_logvar, -30.0, 20.0)
            # std = exp(0.5 * logvar)
            x_std = ops.exp(ops.multiply(0.5, x_logvar))
            sample = keras.random.normal(
                ops.shape(x_mean), dtype=x_mean.dtype, seed=self.seed_generator
            )
            # Reparameterization: mean + std * noise.
            x = ops.add(x_mean, ops.multiply(x_std, sample))
        else:
            # "mode": return the mean and ignore the log variance.
            x, _ = ops.split(x, 2, axis=self.axis)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                # `method` is a required constructor argument, so it must be
                # serialized for `from_config` to be able to rebuild the layer.
                "method": self.method,
                "axis": self.axis,
                "seed": self.seed,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        # An unknown (`None`) dimension stays unknown after the split.
        if output_shape[self.axis] is not None:
            output_shape[self.axis] = output_shape[self.axis] // 2
        return output_shape