keras-hub-nightly 0.16.1.dev202410030339__py3-none-any.whl → 0.16.1.dev202410050339__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +9 -0
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
- keras_hub/src/models/task.py +20 -15
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +172 -0
- keras_hub/src/models/vae/vae_layers.py +740 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/RECORD +23 -14
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410050339.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,740 @@
|
|
1
|
+
import math
|
2
|
+
|
3
|
+
import keras
|
4
|
+
from keras import ops
|
5
|
+
|
6
|
+
from keras_hub.src.utils.keras_utils import standardize_data_format
|
7
|
+
|
8
|
+
|
9
|
+
class Conv2DMultiHeadAttention(keras.layers.Layer):
|
10
|
+
"""A MultiHeadAttention layer utilizing `Conv2D` and `GroupNormalization`.
|
11
|
+
|
12
|
+
Args:
|
13
|
+
filters: int. The number of the filters for the convolutional layers.
|
14
|
+
groups: int. The number of the groups for the group normalization
|
15
|
+
layers. Defaults to `32`.
|
16
|
+
data_format: `None` or str. If specified, either `"channels_last"` or
|
17
|
+
`"channels_first"`. The ordering of the dimensions in the
|
18
|
+
inputs. `"channels_last"` corresponds to inputs with shape
|
19
|
+
`(batch_size, height, width, channels)`
|
20
|
+
while `"channels_first"` corresponds to inputs with shape
|
21
|
+
`(batch_size, channels, height, width)`. It defaults to the
|
22
|
+
`image_data_format` value found in your Keras config file at
|
23
|
+
`~/.keras/keras.json`. If you never set it, then it will be
|
24
|
+
`"channels_last"`.
|
25
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
26
|
+
including `name`, `dtype` etc.
|
27
|
+
"""
|
28
|
+
|
29
|
+
def __init__(self, filters, groups=32, data_format=None, **kwargs):
|
30
|
+
super().__init__(**kwargs)
|
31
|
+
data_format = standardize_data_format(data_format)
|
32
|
+
channel_axis = -1 if data_format == "channels_last" else 1
|
33
|
+
self.filters = int(filters)
|
34
|
+
self.groups = int(groups)
|
35
|
+
self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
|
36
|
+
self.data_format = data_format
|
37
|
+
|
38
|
+
self.group_norm = keras.layers.GroupNormalization(
|
39
|
+
groups=groups,
|
40
|
+
axis=channel_axis,
|
41
|
+
epsilon=1e-6,
|
42
|
+
dtype=self.dtype_policy,
|
43
|
+
name="group_norm",
|
44
|
+
)
|
45
|
+
self.query_conv2d = keras.layers.Conv2D(
|
46
|
+
filters,
|
47
|
+
1,
|
48
|
+
1,
|
49
|
+
data_format=data_format,
|
50
|
+
dtype=self.dtype_policy,
|
51
|
+
name="query_conv2d",
|
52
|
+
)
|
53
|
+
self.key_conv2d = keras.layers.Conv2D(
|
54
|
+
filters,
|
55
|
+
1,
|
56
|
+
1,
|
57
|
+
data_format=data_format,
|
58
|
+
dtype=self.dtype_policy,
|
59
|
+
name="key_conv2d",
|
60
|
+
)
|
61
|
+
self.value_conv2d = keras.layers.Conv2D(
|
62
|
+
filters,
|
63
|
+
1,
|
64
|
+
1,
|
65
|
+
data_format=data_format,
|
66
|
+
dtype=self.dtype_policy,
|
67
|
+
name="value_conv2d",
|
68
|
+
)
|
69
|
+
self.softmax = keras.layers.Softmax(dtype="float32")
|
70
|
+
self.output_conv2d = keras.layers.Conv2D(
|
71
|
+
filters,
|
72
|
+
1,
|
73
|
+
1,
|
74
|
+
data_format=data_format,
|
75
|
+
dtype=self.dtype_policy,
|
76
|
+
name="output_conv2d",
|
77
|
+
)
|
78
|
+
|
79
|
+
def build(self, input_shape):
|
80
|
+
self.group_norm.build(input_shape)
|
81
|
+
self.query_conv2d.build(input_shape)
|
82
|
+
self.key_conv2d.build(input_shape)
|
83
|
+
self.value_conv2d.build(input_shape)
|
84
|
+
self.output_conv2d.build(input_shape)
|
85
|
+
|
86
|
+
def call(self, inputs, training=None):
|
87
|
+
x = self.group_norm(inputs, training=training)
|
88
|
+
query = self.query_conv2d(x, training=training)
|
89
|
+
key = self.key_conv2d(x, training=training)
|
90
|
+
value = self.value_conv2d(x, training=training)
|
91
|
+
|
92
|
+
if self.data_format == "channels_first":
|
93
|
+
query = ops.transpose(query, (0, 2, 3, 1))
|
94
|
+
key = ops.transpose(key, (0, 2, 3, 1))
|
95
|
+
value = ops.transpose(value, (0, 2, 3, 1))
|
96
|
+
shape = ops.shape(inputs)
|
97
|
+
b = shape[0]
|
98
|
+
query = ops.reshape(query, (b, -1, self.filters))
|
99
|
+
key = ops.reshape(key, (b, -1, self.filters))
|
100
|
+
value = ops.reshape(value, (b, -1, self.filters))
|
101
|
+
|
102
|
+
# Compute attention.
|
103
|
+
query = ops.multiply(
|
104
|
+
query, ops.cast(self._inverse_sqrt_filters, query.dtype)
|
105
|
+
)
|
106
|
+
# [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
|
107
|
+
attention_scores = ops.einsum("abc,adc->abd", query, key)
|
108
|
+
attention_scores = ops.cast(
|
109
|
+
self.softmax(attention_scores), self.compute_dtype
|
110
|
+
)
|
111
|
+
# [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C]
|
112
|
+
attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
|
113
|
+
x = ops.reshape(attention_output, shape)
|
114
|
+
|
115
|
+
x = self.output_conv2d(x, training=training)
|
116
|
+
if self.data_format == "channels_first":
|
117
|
+
x = ops.transpose(x, (0, 3, 1, 2))
|
118
|
+
x = ops.add(x, inputs)
|
119
|
+
return x
|
120
|
+
|
121
|
+
def get_config(self):
|
122
|
+
config = super().get_config()
|
123
|
+
config.update(
|
124
|
+
{
|
125
|
+
"filters": self.filters,
|
126
|
+
"groups": self.groups,
|
127
|
+
}
|
128
|
+
)
|
129
|
+
return config
|
130
|
+
|
131
|
+
def compute_output_shape(self, input_shape):
|
132
|
+
return input_shape
|
133
|
+
|
134
|
+
|
135
|
+
class ResNetBlock(keras.layers.Layer):
|
136
|
+
"""A ResNet block utilizing `GroupNormalization` and SiLU activation.
|
137
|
+
|
138
|
+
Args:
|
139
|
+
filters: The number of filters in the block.
|
140
|
+
has_residual_projection: Whether to add a projection layer for the
|
141
|
+
residual connection. Defaults to `False`.
|
142
|
+
data_format: `None` or str. If specified, either `"channels_last"` or
|
143
|
+
`"channels_first"`. The ordering of the dimensions in the
|
144
|
+
inputs. `"channels_last"` corresponds to inputs with shape
|
145
|
+
`(batch_size, height, width, channels)`
|
146
|
+
while `"channels_first"` corresponds to inputs with shape
|
147
|
+
`(batch_size, channels, height, width)`. It defaults to the
|
148
|
+
`image_data_format` value found in your Keras config file at
|
149
|
+
`~/.keras/keras.json`. If you never set it, then it will be
|
150
|
+
`"channels_last"`.
|
151
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
152
|
+
including `name`, `dtype` etc.
|
153
|
+
"""
|
154
|
+
|
155
|
+
def __init__(
|
156
|
+
self,
|
157
|
+
filters,
|
158
|
+
has_residual_projection=False,
|
159
|
+
data_format=None,
|
160
|
+
**kwargs,
|
161
|
+
):
|
162
|
+
|
163
|
+
super().__init__(**kwargs)
|
164
|
+
data_format = standardize_data_format(data_format)
|
165
|
+
channel_axis = -1 if data_format == "channels_last" else 1
|
166
|
+
self.filters = int(filters)
|
167
|
+
self.has_residual_projection = bool(has_residual_projection)
|
168
|
+
|
169
|
+
# === Layers ===
|
170
|
+
self.norm1 = keras.layers.GroupNormalization(
|
171
|
+
groups=32,
|
172
|
+
axis=channel_axis,
|
173
|
+
epsilon=1e-6,
|
174
|
+
dtype=self.dtype_policy,
|
175
|
+
name="norm1",
|
176
|
+
)
|
177
|
+
self.act1 = keras.layers.Activation("silu", dtype=self.dtype_policy)
|
178
|
+
self.conv1 = keras.layers.Conv2D(
|
179
|
+
filters,
|
180
|
+
3,
|
181
|
+
1,
|
182
|
+
padding="same",
|
183
|
+
data_format=data_format,
|
184
|
+
dtype=self.dtype_policy,
|
185
|
+
name="conv1",
|
186
|
+
)
|
187
|
+
self.norm2 = keras.layers.GroupNormalization(
|
188
|
+
groups=32,
|
189
|
+
axis=channel_axis,
|
190
|
+
epsilon=1e-6,
|
191
|
+
dtype=self.dtype_policy,
|
192
|
+
name="norm2",
|
193
|
+
)
|
194
|
+
self.act2 = keras.layers.Activation("silu", dtype=self.dtype_policy)
|
195
|
+
self.conv2 = keras.layers.Conv2D(
|
196
|
+
filters,
|
197
|
+
3,
|
198
|
+
1,
|
199
|
+
padding="same",
|
200
|
+
data_format=data_format,
|
201
|
+
dtype=self.dtype_policy,
|
202
|
+
name="conv2",
|
203
|
+
)
|
204
|
+
if self.has_residual_projection:
|
205
|
+
self.residual_projection = keras.layers.Conv2D(
|
206
|
+
filters,
|
207
|
+
1,
|
208
|
+
1,
|
209
|
+
data_format=data_format,
|
210
|
+
dtype=self.dtype_policy,
|
211
|
+
name="residual_projection",
|
212
|
+
)
|
213
|
+
self.add = keras.layers.Add(dtype=self.dtype_policy)
|
214
|
+
|
215
|
+
def build(self, input_shape):
|
216
|
+
residual_shape = list(input_shape)
|
217
|
+
self.norm1.build(input_shape)
|
218
|
+
self.act1.build(input_shape)
|
219
|
+
self.conv1.build(input_shape)
|
220
|
+
input_shape = self.conv1.compute_output_shape(input_shape)
|
221
|
+
self.norm2.build(input_shape)
|
222
|
+
self.act2.build(input_shape)
|
223
|
+
self.conv2.build(input_shape)
|
224
|
+
input_shape = self.conv2.compute_output_shape(input_shape)
|
225
|
+
if self.has_residual_projection:
|
226
|
+
self.residual_projection.build(residual_shape)
|
227
|
+
self.add.build([input_shape, input_shape])
|
228
|
+
|
229
|
+
def call(self, inputs, training=None):
|
230
|
+
x = inputs
|
231
|
+
residual = x
|
232
|
+
x = self.norm1(x, training=training)
|
233
|
+
x = self.act1(x, training=training)
|
234
|
+
x = self.conv1(x, training=training)
|
235
|
+
x = self.norm2(x, training=training)
|
236
|
+
x = self.act2(x, training=training)
|
237
|
+
x = self.conv2(x, training=training)
|
238
|
+
if self.has_residual_projection:
|
239
|
+
residual = self.residual_projection(residual, training=training)
|
240
|
+
x = self.add([residual, x])
|
241
|
+
return x
|
242
|
+
|
243
|
+
def get_config(self):
|
244
|
+
config = super().get_config()
|
245
|
+
config.update(
|
246
|
+
{
|
247
|
+
"filters": self.filters,
|
248
|
+
"has_residual_projection": self.has_residual_projection,
|
249
|
+
}
|
250
|
+
)
|
251
|
+
return config
|
252
|
+
|
253
|
+
def compute_output_shape(self, input_shape):
|
254
|
+
outputs_shape = list(input_shape)
|
255
|
+
if self.has_residual_projection:
|
256
|
+
outputs_shape = self.residual_projection.compute_output_shape(
|
257
|
+
outputs_shape
|
258
|
+
)
|
259
|
+
return outputs_shape
|
260
|
+
|
261
|
+
|
262
|
+
class VAEEncoder(keras.layers.Layer):
|
263
|
+
"""The encoder layer of VAE.
|
264
|
+
|
265
|
+
Args:
|
266
|
+
stackwise_num_filters: list of ints. The number of filters for each
|
267
|
+
stack.
|
268
|
+
stackwise_num_blocks: list of ints. The number of blocks for each stack.
|
269
|
+
output_channels: int. The number of channels in the output. Defaults to
|
270
|
+
`32`.
|
271
|
+
data_format: `None` or str. If specified, either `"channels_last"` or
|
272
|
+
`"channels_first"`. The ordering of the dimensions in the
|
273
|
+
inputs. `"channels_last"` corresponds to inputs with shape
|
274
|
+
`(batch_size, height, width, channels)`
|
275
|
+
while `"channels_first"` corresponds to inputs with shape
|
276
|
+
`(batch_size, channels, height, width)`. It defaults to the
|
277
|
+
`image_data_format` value found in your Keras config file at
|
278
|
+
`~/.keras/keras.json`. If you never set it, then it will be
|
279
|
+
`"channels_last"`.
|
280
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
281
|
+
including `name`, `dtype` etc.
|
282
|
+
"""
|
283
|
+
|
284
|
+
def __init__(
|
285
|
+
self,
|
286
|
+
stackwise_num_filters,
|
287
|
+
stackwise_num_blocks,
|
288
|
+
output_channels=32,
|
289
|
+
data_format=None,
|
290
|
+
**kwargs,
|
291
|
+
):
|
292
|
+
super().__init__(**kwargs)
|
293
|
+
data_format = standardize_data_format(data_format)
|
294
|
+
channel_axis = -1 if data_format == "channels_last" else 1
|
295
|
+
self.stackwise_num_filters = stackwise_num_filters
|
296
|
+
self.stackwise_num_blocks = stackwise_num_blocks
|
297
|
+
self.output_channels = int(output_channels)
|
298
|
+
self.data_format = data_format
|
299
|
+
|
300
|
+
# === Layers ===
|
301
|
+
self.input_projection = keras.layers.Conv2D(
|
302
|
+
stackwise_num_filters[0],
|
303
|
+
3,
|
304
|
+
1,
|
305
|
+
padding="same",
|
306
|
+
data_format=data_format,
|
307
|
+
dtype=self.dtype_policy,
|
308
|
+
name="input_projection",
|
309
|
+
)
|
310
|
+
|
311
|
+
# Blocks.
|
312
|
+
input_filters = stackwise_num_filters[0]
|
313
|
+
self.blocks = []
|
314
|
+
self.downsamples = []
|
315
|
+
for i, filters in enumerate(stackwise_num_filters):
|
316
|
+
for j in range(stackwise_num_blocks[i]):
|
317
|
+
self.blocks.append(
|
318
|
+
ResNetBlock(
|
319
|
+
filters,
|
320
|
+
has_residual_projection=input_filters != filters,
|
321
|
+
data_format=data_format,
|
322
|
+
dtype=self.dtype_policy,
|
323
|
+
name=f"block_{i}_{j}",
|
324
|
+
)
|
325
|
+
)
|
326
|
+
input_filters = filters
|
327
|
+
# No downsample in the last block.
|
328
|
+
if i != len(stackwise_num_filters) - 1:
|
329
|
+
self.downsamples.append(
|
330
|
+
keras.layers.ZeroPadding2D(
|
331
|
+
padding=((0, 1), (0, 1)),
|
332
|
+
data_format=data_format,
|
333
|
+
dtype=self.dtype_policy,
|
334
|
+
name=f"downsample_{i}_pad",
|
335
|
+
)
|
336
|
+
)
|
337
|
+
self.downsamples.append(
|
338
|
+
keras.layers.Conv2D(
|
339
|
+
filters,
|
340
|
+
3,
|
341
|
+
2,
|
342
|
+
data_format=data_format,
|
343
|
+
dtype=self.dtype_policy,
|
344
|
+
name=f"downsample_{i}_conv",
|
345
|
+
)
|
346
|
+
)
|
347
|
+
|
348
|
+
# Mid block.
|
349
|
+
self.mid_block_0 = ResNetBlock(
|
350
|
+
stackwise_num_filters[-1],
|
351
|
+
has_residual_projection=False,
|
352
|
+
data_format=data_format,
|
353
|
+
dtype=self.dtype_policy,
|
354
|
+
name="mid_block_0",
|
355
|
+
)
|
356
|
+
self.mid_attention = Conv2DMultiHeadAttention(
|
357
|
+
stackwise_num_filters[-1],
|
358
|
+
data_format=data_format,
|
359
|
+
dtype=self.dtype_policy,
|
360
|
+
name="mid_attention",
|
361
|
+
)
|
362
|
+
self.mid_block_1 = ResNetBlock(
|
363
|
+
stackwise_num_filters[-1],
|
364
|
+
has_residual_projection=False,
|
365
|
+
data_format=data_format,
|
366
|
+
dtype=self.dtype_policy,
|
367
|
+
name="mid_block_1",
|
368
|
+
)
|
369
|
+
|
370
|
+
# Output layers.
|
371
|
+
self.output_norm = keras.layers.GroupNormalization(
|
372
|
+
groups=32,
|
373
|
+
axis=channel_axis,
|
374
|
+
epsilon=1e-6,
|
375
|
+
dtype=self.dtype_policy,
|
376
|
+
name="output_norm",
|
377
|
+
)
|
378
|
+
self.output_act = keras.layers.Activation(
|
379
|
+
"swish", dtype=self.dtype_policy
|
380
|
+
)
|
381
|
+
self.output_projection = keras.layers.Conv2D(
|
382
|
+
output_channels,
|
383
|
+
3,
|
384
|
+
1,
|
385
|
+
padding="same",
|
386
|
+
data_format=data_format,
|
387
|
+
dtype=self.dtype_policy,
|
388
|
+
name="output_projection",
|
389
|
+
)
|
390
|
+
|
391
|
+
def build(self, input_shape):
|
392
|
+
self.input_projection.build(input_shape)
|
393
|
+
input_shape = self.input_projection.compute_output_shape(input_shape)
|
394
|
+
blocks_idx = 0
|
395
|
+
downsamples_idx = 0
|
396
|
+
for i, _ in enumerate(self.stackwise_num_filters):
|
397
|
+
for _ in range(self.stackwise_num_blocks[i]):
|
398
|
+
self.blocks[blocks_idx].build(input_shape)
|
399
|
+
input_shape = self.blocks[blocks_idx].compute_output_shape(
|
400
|
+
input_shape
|
401
|
+
)
|
402
|
+
blocks_idx += 1
|
403
|
+
if i != len(self.stackwise_num_filters) - 1:
|
404
|
+
self.downsamples[downsamples_idx].build(input_shape)
|
405
|
+
input_shape = self.downsamples[
|
406
|
+
downsamples_idx
|
407
|
+
].compute_output_shape(input_shape)
|
408
|
+
downsamples_idx += 1
|
409
|
+
self.downsamples[downsamples_idx].build(input_shape)
|
410
|
+
input_shape = self.downsamples[
|
411
|
+
downsamples_idx
|
412
|
+
].compute_output_shape(input_shape)
|
413
|
+
downsamples_idx += 1
|
414
|
+
self.mid_block_0.build(input_shape)
|
415
|
+
input_shape = self.mid_block_0.compute_output_shape(input_shape)
|
416
|
+
self.mid_attention.build(input_shape)
|
417
|
+
input_shape = self.mid_attention.compute_output_shape(input_shape)
|
418
|
+
self.mid_block_1.build(input_shape)
|
419
|
+
input_shape = self.mid_block_1.compute_output_shape(input_shape)
|
420
|
+
self.output_norm.build(input_shape)
|
421
|
+
self.output_act.build(input_shape)
|
422
|
+
self.output_projection.build(input_shape)
|
423
|
+
|
424
|
+
def call(self, inputs, training=None):
|
425
|
+
x = inputs
|
426
|
+
x = self.input_projection(x, training=training)
|
427
|
+
blocks_idx = 0
|
428
|
+
upsamples_idx = 0
|
429
|
+
for i, _ in enumerate(self.stackwise_num_filters):
|
430
|
+
for _ in range(self.stackwise_num_blocks[i]):
|
431
|
+
x = self.blocks[blocks_idx](x, training=training)
|
432
|
+
blocks_idx += 1
|
433
|
+
if i != len(self.stackwise_num_filters) - 1:
|
434
|
+
x = self.downsamples[upsamples_idx](x, training=training)
|
435
|
+
x = self.downsamples[upsamples_idx + 1](x, training=training)
|
436
|
+
upsamples_idx += 2
|
437
|
+
x = self.mid_block_0(x, training=training)
|
438
|
+
x = self.mid_attention(x, training=training)
|
439
|
+
x = self.mid_block_1(x, training=training)
|
440
|
+
x = self.output_norm(x, training=training)
|
441
|
+
x = self.output_act(x, training=training)
|
442
|
+
x = self.output_projection(x, training=training)
|
443
|
+
return x
|
444
|
+
|
445
|
+
def get_config(self):
|
446
|
+
config = super().get_config()
|
447
|
+
config.update(
|
448
|
+
{
|
449
|
+
"stackwise_num_filters": self.stackwise_num_filters,
|
450
|
+
"stackwise_num_blocks": self.stackwise_num_blocks,
|
451
|
+
"output_channels": self.output_channels,
|
452
|
+
}
|
453
|
+
)
|
454
|
+
return config
|
455
|
+
|
456
|
+
def compute_output_shape(self, input_shape):
|
457
|
+
if self.data_format == "channels_last":
|
458
|
+
h_axis, w_axis, c_axis = 1, 2, 3
|
459
|
+
else:
|
460
|
+
c_axis, h_axis, w_axis = 1, 2, 3
|
461
|
+
scale_factor = 2 ** (len(self.stackwise_num_filters) - 1)
|
462
|
+
outputs_shape = list(input_shape)
|
463
|
+
if (
|
464
|
+
outputs_shape[h_axis] is not None
|
465
|
+
and outputs_shape[w_axis] is not None
|
466
|
+
):
|
467
|
+
outputs_shape[h_axis] = outputs_shape[h_axis] // scale_factor
|
468
|
+
outputs_shape[w_axis] = outputs_shape[w_axis] // scale_factor
|
469
|
+
outputs_shape[c_axis] = self.output_channels
|
470
|
+
return outputs_shape
|
471
|
+
|
472
|
+
|
473
|
+
class VAEDecoder(keras.layers.Layer):
|
474
|
+
"""The decoder layer of VAE.
|
475
|
+
|
476
|
+
Args:
|
477
|
+
stackwise_num_filters: list of ints. The number of filters for each
|
478
|
+
stack.
|
479
|
+
stackwise_num_blocks: list of ints. The number of blocks for each stack.
|
480
|
+
output_channels: int. The number of channels in the output. Defaults to
|
481
|
+
`3`.
|
482
|
+
data_format: `None` or str. If specified, either `"channels_last"` or
|
483
|
+
`"channels_first"`. The ordering of the dimensions in the
|
484
|
+
inputs. `"channels_last"` corresponds to inputs with shape
|
485
|
+
`(batch_size, height, width, channels)`
|
486
|
+
while `"channels_first"` corresponds to inputs with shape
|
487
|
+
`(batch_size, channels, height, width)`. It defaults to the
|
488
|
+
`image_data_format` value found in your Keras config file at
|
489
|
+
`~/.keras/keras.json`. If you never set it, then it will be
|
490
|
+
`"channels_last"`.
|
491
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
492
|
+
including `name`, `dtype` etc.
|
493
|
+
"""
|
494
|
+
|
495
|
+
def __init__(
|
496
|
+
self,
|
497
|
+
stackwise_num_filters,
|
498
|
+
stackwise_num_blocks,
|
499
|
+
output_channels=3,
|
500
|
+
data_format=None,
|
501
|
+
**kwargs,
|
502
|
+
):
|
503
|
+
super().__init__(**kwargs)
|
504
|
+
data_format = standardize_data_format(data_format)
|
505
|
+
channel_axis = -1 if data_format == "channels_last" else 1
|
506
|
+
self.stackwise_num_filters = stackwise_num_filters
|
507
|
+
self.stackwise_num_blocks = stackwise_num_blocks
|
508
|
+
self.output_channels = int(output_channels)
|
509
|
+
self.data_format = data_format
|
510
|
+
|
511
|
+
# === Layers ===
|
512
|
+
self.input_projection = keras.layers.Conv2D(
|
513
|
+
stackwise_num_filters[0],
|
514
|
+
3,
|
515
|
+
1,
|
516
|
+
padding="same",
|
517
|
+
data_format=data_format,
|
518
|
+
dtype=self.dtype_policy,
|
519
|
+
name="input_projection",
|
520
|
+
)
|
521
|
+
|
522
|
+
# Mid block.
|
523
|
+
self.mid_block_0 = ResNetBlock(
|
524
|
+
stackwise_num_filters[0],
|
525
|
+
data_format=data_format,
|
526
|
+
dtype=self.dtype_policy,
|
527
|
+
name="mid_block_0",
|
528
|
+
)
|
529
|
+
self.mid_attention = Conv2DMultiHeadAttention(
|
530
|
+
stackwise_num_filters[0],
|
531
|
+
data_format=data_format,
|
532
|
+
dtype=self.dtype_policy,
|
533
|
+
name="mid_attention",
|
534
|
+
)
|
535
|
+
self.mid_block_1 = ResNetBlock(
|
536
|
+
stackwise_num_filters[0],
|
537
|
+
data_format=data_format,
|
538
|
+
dtype=self.dtype_policy,
|
539
|
+
name="mid_block_1",
|
540
|
+
)
|
541
|
+
|
542
|
+
# Blocks.
|
543
|
+
input_filters = stackwise_num_filters[0]
|
544
|
+
self.blocks = []
|
545
|
+
self.upsamples = []
|
546
|
+
for i, filters in enumerate(stackwise_num_filters):
|
547
|
+
for j in range(stackwise_num_blocks[i]):
|
548
|
+
self.blocks.append(
|
549
|
+
ResNetBlock(
|
550
|
+
filters,
|
551
|
+
has_residual_projection=input_filters != filters,
|
552
|
+
data_format=data_format,
|
553
|
+
dtype=self.dtype_policy,
|
554
|
+
name=f"block_{i}_{j}",
|
555
|
+
)
|
556
|
+
)
|
557
|
+
input_filters = filters
|
558
|
+
# No upsample in the last block.
|
559
|
+
if i != len(stackwise_num_filters) - 1:
|
560
|
+
self.upsamples.append(
|
561
|
+
keras.layers.UpSampling2D(
|
562
|
+
2,
|
563
|
+
data_format=data_format,
|
564
|
+
dtype=self.dtype_policy,
|
565
|
+
name=f"upsample_{i}",
|
566
|
+
)
|
567
|
+
)
|
568
|
+
self.upsamples.append(
|
569
|
+
keras.layers.Conv2D(
|
570
|
+
filters,
|
571
|
+
3,
|
572
|
+
1,
|
573
|
+
padding="same",
|
574
|
+
data_format=data_format,
|
575
|
+
dtype=self.dtype_policy,
|
576
|
+
name=f"upsample_{i}_conv",
|
577
|
+
)
|
578
|
+
)
|
579
|
+
|
580
|
+
# Output layers.
|
581
|
+
self.output_norm = keras.layers.GroupNormalization(
|
582
|
+
groups=32,
|
583
|
+
axis=channel_axis,
|
584
|
+
epsilon=1e-6,
|
585
|
+
dtype=self.dtype_policy,
|
586
|
+
name="output_norm",
|
587
|
+
)
|
588
|
+
self.output_act = keras.layers.Activation(
|
589
|
+
"swish", dtype=self.dtype_policy
|
590
|
+
)
|
591
|
+
self.output_projection = keras.layers.Conv2D(
|
592
|
+
output_channels,
|
593
|
+
3,
|
594
|
+
1,
|
595
|
+
padding="same",
|
596
|
+
data_format=data_format,
|
597
|
+
dtype=self.dtype_policy,
|
598
|
+
name="output_projection",
|
599
|
+
)
|
600
|
+
|
601
|
+
def build(self, input_shape):
|
602
|
+
self.input_projection.build(input_shape)
|
603
|
+
input_shape = self.input_projection.compute_output_shape(input_shape)
|
604
|
+
self.mid_block_0.build(input_shape)
|
605
|
+
input_shape = self.mid_block_0.compute_output_shape(input_shape)
|
606
|
+
self.mid_attention.build(input_shape)
|
607
|
+
input_shape = self.mid_attention.compute_output_shape(input_shape)
|
608
|
+
self.mid_block_1.build(input_shape)
|
609
|
+
input_shape = self.mid_block_1.compute_output_shape(input_shape)
|
610
|
+
blocks_idx = 0
|
611
|
+
upsamples_idx = 0
|
612
|
+
for i, _ in enumerate(self.stackwise_num_filters):
|
613
|
+
for _ in range(self.stackwise_num_blocks[i]):
|
614
|
+
self.blocks[blocks_idx].build(input_shape)
|
615
|
+
input_shape = self.blocks[blocks_idx].compute_output_shape(
|
616
|
+
input_shape
|
617
|
+
)
|
618
|
+
blocks_idx += 1
|
619
|
+
if i != len(self.stackwise_num_filters) - 1:
|
620
|
+
self.upsamples[upsamples_idx].build(input_shape)
|
621
|
+
input_shape = self.upsamples[
|
622
|
+
upsamples_idx
|
623
|
+
].compute_output_shape(input_shape)
|
624
|
+
self.upsamples[upsamples_idx + 1].build(input_shape)
|
625
|
+
input_shape = self.upsamples[
|
626
|
+
upsamples_idx + 1
|
627
|
+
].compute_output_shape(input_shape)
|
628
|
+
upsamples_idx += 2
|
629
|
+
self.output_norm.build(input_shape)
|
630
|
+
self.output_act.build(input_shape)
|
631
|
+
self.output_projection.build(input_shape)
|
632
|
+
|
633
|
+
def call(self, inputs, training=None):
|
634
|
+
x = inputs
|
635
|
+
x = self.input_projection(x, training=training)
|
636
|
+
x = self.mid_block_0(x, training=training)
|
637
|
+
x = self.mid_attention(x, training=training)
|
638
|
+
x = self.mid_block_1(x, training=training)
|
639
|
+
blocks_idx = 0
|
640
|
+
upsamples_idx = 0
|
641
|
+
for i, _ in enumerate(self.stackwise_num_filters):
|
642
|
+
for _ in range(self.stackwise_num_blocks[i]):
|
643
|
+
x = self.blocks[blocks_idx](x, training=training)
|
644
|
+
blocks_idx += 1
|
645
|
+
if i != len(self.stackwise_num_filters) - 1:
|
646
|
+
x = self.upsamples[upsamples_idx](x, training=training)
|
647
|
+
x = self.upsamples[upsamples_idx + 1](x, training=training)
|
648
|
+
upsamples_idx += 2
|
649
|
+
x = self.output_norm(x, training=training)
|
650
|
+
x = self.output_act(x, training=training)
|
651
|
+
x = self.output_projection(x, training=training)
|
652
|
+
return x
|
653
|
+
|
654
|
+
def get_config(self):
|
655
|
+
config = super().get_config()
|
656
|
+
config.update(
|
657
|
+
{
|
658
|
+
"stackwise_num_filters": self.stackwise_num_filters,
|
659
|
+
"stackwise_num_blocks": self.stackwise_num_blocks,
|
660
|
+
"output_channels": self.output_channels,
|
661
|
+
}
|
662
|
+
)
|
663
|
+
return config
|
664
|
+
|
665
|
+
def compute_output_shape(self, input_shape):
|
666
|
+
if self.data_format == "channels_last":
|
667
|
+
h_axis, w_axis, c_axis = 1, 2, 3
|
668
|
+
else:
|
669
|
+
c_axis, h_axis, w_axis = 1, 2, 3
|
670
|
+
scale_factor = 2 ** (len(self.stackwise_num_filters) - 1)
|
671
|
+
outputs_shape = list(input_shape)
|
672
|
+
if (
|
673
|
+
outputs_shape[h_axis] is not None
|
674
|
+
and outputs_shape[w_axis] is not None
|
675
|
+
):
|
676
|
+
outputs_shape[h_axis] = outputs_shape[h_axis] * scale_factor
|
677
|
+
outputs_shape[w_axis] = outputs_shape[w_axis] * scale_factor
|
678
|
+
outputs_shape[c_axis] = self.output_channels
|
679
|
+
return outputs_shape
|
680
|
+
|
681
|
+
|
682
|
+
class DiagonalGaussianDistributionSampler(keras.layers.Layer):
|
683
|
+
"""A sampler for a diagonal Gaussian distribution.
|
684
|
+
|
685
|
+
This layer samples latent variables from a diagonal Gaussian distribution.
|
686
|
+
|
687
|
+
Args:
|
688
|
+
method: str. The method used to sample from the distribution. Available
|
689
|
+
methods are `"sample"` and `"mode"`. `"sample"` draws from the
|
690
|
+
distribution using both the mean and log variance. `"mode"` draws
|
691
|
+
from the distribution using the mean only.
|
692
|
+
axis: int. The axis along which to split the mean and log variance.
|
693
|
+
Defaults to `-1`.
|
694
|
+
seed: optional int. Used as a random seed.
|
695
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
696
|
+
including `name`, `dtype` etc.
|
697
|
+
"""
|
698
|
+
|
699
|
+
def __init__(self, method, axis=-1, seed=None, **kwargs):
|
700
|
+
super().__init__(**kwargs)
|
701
|
+
# TODO: Support `kl` and `nll` modes.
|
702
|
+
valid_methods = ("sample", "mode")
|
703
|
+
if method not in valid_methods:
|
704
|
+
raise ValueError(
|
705
|
+
f"Invalid method {method}. Valid methods are "
|
706
|
+
f"{list(valid_methods)}."
|
707
|
+
)
|
708
|
+
self.method = method
|
709
|
+
self.axis = axis
|
710
|
+
self.seed = seed
|
711
|
+
self.seed_generator = keras.random.SeedGenerator(seed)
|
712
|
+
|
713
|
+
def call(self, inputs):
|
714
|
+
x = inputs
|
715
|
+
if self.method == "sample":
|
716
|
+
x_mean, x_logvar = ops.split(x, 2, axis=self.axis)
|
717
|
+
x_logvar = ops.clip(x_logvar, -30.0, 20.0)
|
718
|
+
x_std = ops.exp(ops.multiply(0.5, x_logvar))
|
719
|
+
sample = keras.random.normal(
|
720
|
+
ops.shape(x_mean), dtype=x_mean.dtype, seed=self.seed_generator
|
721
|
+
)
|
722
|
+
x = ops.add(x_mean, ops.multiply(x_std, sample))
|
723
|
+
else:
|
724
|
+
x, _ = ops.split(x, 2, axis=self.axis)
|
725
|
+
return x
|
726
|
+
|
727
|
+
def get_config(self):
|
728
|
+
config = super().get_config()
|
729
|
+
config.update(
|
730
|
+
{
|
731
|
+
"axis": self.axis,
|
732
|
+
"seed": self.seed,
|
733
|
+
}
|
734
|
+
)
|
735
|
+
return config
|
736
|
+
|
737
|
+
def compute_output_shape(self, input_shape):
|
738
|
+
output_shape = list(input_shape)
|
739
|
+
output_shape[self.axis] = output_shape[self.axis] // 2
|
740
|
+
return output_shape
|