keras-hub-nightly 0.16.1.dev202410080341__py3-none-any.whl → 0.16.1.dev202410100339__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +11 -0
- keras_hub/src/layers/preprocessing/image_converter.py +2 -1
- keras_hub/src/models/image_to_image.py +411 -0
- keras_hub/src/models/inpaint.py +513 -0
- keras_hub/src/models/mix_transformer/__init__.py +12 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +4 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mix_transformer/mix_transformer_image_converter.py +8 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +9 -5
- keras_hub/src/models/mix_transformer/mix_transformer_presets.py +151 -0
- keras_hub/src/models/preprocessor.py +4 -4
- keras_hub/src/models/stable_diffusion_3/mmdit.py +308 -177
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +87 -55
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +171 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +194 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +13 -8
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/models/text_to_image.py +89 -36
- keras_hub/src/tests/test_case.py +3 -1
- keras_hub/src/tokenizers/tokenizer.py +7 -7
- keras_hub/src/utils/preset_utils.py +7 -7
- keras_hub/src/utils/timm/preset_loader.py +1 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/RECORD +29 -22
- {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/top_level.txt +0 -0
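The headline change in this window is a new family of diffusion editing tasks: generic `image_to_image.py` and `inpaint.py` task bases plus Stable Diffusion 3 implementations of both, alongside new MiT (mix_transformer) presets and preprocessing. Below is a hypothetical usage sketch of the new image-to-image task; it assumes the standard KerasHub `from_preset()` / `generate()` workflow, and the preset name and input keys are illustrative rather than taken from this diff.

    # Hypothetical usage of the new image-to-image task (illustrative only;
    # preset name and input keys are assumptions, not part of this diff).
    import numpy as np
    import keras_hub

    task = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
        "stable_diffusion_3_medium"
    )
    reference = np.random.uniform(0.0, 255.0, size=(512, 512, 3))
    images = task.generate(
        {"images": reference, "prompts": "a photograph of an astronaut"}
    )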
keras_hub/src/models/stable_diffusion_3/mmdit.py
@@ -2,7 +2,6 @@ import math
 
 import keras
 from keras import layers
-from keras import models
 from keras import ops
 
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
@@ -11,7 +10,167 @@ from keras_hub.src.utils.keras_utils import gelu_approximate
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
+class AdaptiveLayerNormalization(layers.Layer):
+    """Adaptive layer normalization.
+
+    Args:
+        embedding_dim: int. The size of each embedding vector.
+        residual_modulation: bool. Whether to output the modulation parameters
+            of the residual connection within the block of the diffusion
+            transformers. Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    References:
+    - [FiLM: Visual Reasoning with a General Conditioning Layer](
+    https://arxiv.org/abs/1709.07871).
+    - [Scalable Diffusion Models with Transformers](
+    https://arxiv.org/abs/2212.09748).
+    """
+
+    def __init__(self, hidden_dim, residual_modulation=False, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.residual_modulation = bool(residual_modulation)
+        num_modulations = 6 if self.residual_modulation else 2
+
+        self.silu = layers.Activation("silu", dtype=self.dtype_policy)
+        self.dense = layers.Dense(
+            num_modulations * hidden_dim, dtype=self.dtype_policy, name="dense"
+        )
+        self.norm = layers.LayerNormalization(
+            epsilon=1e-6,
+            center=False,
+            scale=False,
+            dtype="float32",
+            name="norm",
+        )
+
+    def build(self, inputs_shape, embeddings_shape):
+        self.silu.build(embeddings_shape)
+        self.dense.build(embeddings_shape)
+        self.norm.build(inputs_shape)
+
+    def call(self, inputs, embeddings, training=None):
+        x = inputs
+        emb = self.dense(self.silu(embeddings), training=training)
+        if self.residual_modulation:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                ops.split(emb, 6, axis=1)
+            )
+        else:
+            shift_msa, scale_msa = ops.split(emb, 2, axis=1)
+        scale_msa = ops.expand_dims(scale_msa, axis=1)
+        shift_msa = ops.expand_dims(shift_msa, axis=1)
+        x = ops.add(
+            ops.multiply(
+                self.norm(x, training=training),
+                ops.add(1.0, scale_msa),
+            ),
+            shift_msa,
+        )
+        if self.residual_modulation:
+            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+        else:
+            return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "residual_modulation": self.residual_modulation,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape, embeddings_shape):
+        if self.residual_modulation:
+            return (
+                inputs_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+            )
+        else:
+            return inputs_shape
+
+
+class MLP(layers.Layer):
+    """A MLP block with architecture.
+
+    Args:
+        hidden_dim: int. The number of units in the hidden layers.
+        output_dim: int. The number of units in the output layer.
+        activation: str of callable. Activation to use in the hidden layers.
+            Default to `None`.
+    """
+
+    def __init__(self, hidden_dim, output_dim, activation=None, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.output_dim = int(output_dim)
+        self.activation = keras.activations.get(activation)
+
+        self.dense1 = layers.Dense(
+            hidden_dim,
+            activation=self.activation,
+            dtype=self.dtype_policy,
+            name="dense1",
+        )
+        self.dense2 = layers.Dense(
+            output_dim,
+            activation=None,
+            dtype=self.dtype_policy,
+            name="dense2",
+        )
+
+    def build(self, inputs_shape):
+        self.dense1.build(inputs_shape)
+        inputs_shape = self.dense1.compute_output_shape(inputs_shape)
+        self.dense2.build(inputs_shape)
+
+    def call(self, inputs, training=None):
+        x = self.dense1(inputs, training=training)
+        return self.dense2(x, training=training)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "output_dim": self.output_dim,
+                "activation": keras.activations.serialize(self.activation),
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        outputs_shape = list(inputs_shape)
+        outputs_shape[-1] = self.output_dim
+        return outputs_shape
+
+
 class PatchEmbedding(layers.Layer):
+    """A layer that converts images into patches.
+
+    Args:
+        patch_size: int. The size of one side of each patch.
+        hidden_dim: int. The number of units in the hidden layers.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
     def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs):
         super().__init__(**kwargs)
         self.patch_size = int(patch_size)
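The new `AdaptiveLayerNormalization` above is adaLN-style (FiLM) conditioning: the input is layer-normalized without learned scale or center, then shifted and scaled by parameters predicted from the conditioning embedding. A minimal standalone sketch of that core computation follows; it is illustrative only and reuses plain Keras layers rather than the class added in this hunk.

    # Minimal sketch of adaLN-style modulation, as performed by the new
    # `AdaptiveLayerNormalization` when `residual_modulation=False`.
    from keras import layers, ops

    hidden_dim = 4
    x = ops.ones((2, 3, hidden_dim))  # (batch, tokens, hidden_dim)
    emb = ops.ones((2, hidden_dim))   # conditioning embedding, e.g. timestep

    # A Dense layer predicts a shift and a scale from the embedding.
    dense = layers.Dense(2 * hidden_dim)
    shift, scale = ops.split(dense(ops.silu(emb)), 2, axis=1)

    # Normalize without learned affine parameters, then modulate.
    norm = layers.LayerNormalization(epsilon=1e-6, center=False, scale=False)
    out = norm(x) * (1.0 + ops.expand_dims(scale, 1)) + ops.expand_dims(shift, 1)
    print(out.shape)  # (2, 3, 4)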
@@ -48,6 +207,15 @@ class PatchEmbedding(layers.Layer):
 
 
 class AdjustablePositionEmbedding(PositionEmbedding):
+    """A position embedding layer with adjustable height and width.
+
+    The embedding will be cropped to match the input dimensions.
+
+    Args:
+        height: int. The maximum height of the embedding.
+        width: int. The maximum width of the embedding.
+    """
+
     def __init__(
         self,
         height,
@@ -84,11 +252,36 @@ class AdjustablePositionEmbedding(PositionEmbedding):
         position_embedding = ops.expand_dims(position_embedding, axis=0)
         return position_embedding
 
+    def get_config(self):
+        config = super().get_config()
+        del config["sequence_length"]
+        config.update(
+            {
+                "height": self.height,
+                "width": self.width,
+            }
+        )
+        return config
+
     def compute_output_shape(self, input_shape):
         return input_shape
 
 
 class TimestepEmbedding(layers.Layer):
+    """A layer which learns embedding for input timesteps.
+
+    Args:
+        embedding_dim: int. The size of the embedding.
+        frequency_dim: int. The size of the frequency.
+        max_period: int. Controls the maximum frequency of the embeddings.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    Reference:
+    - [Denoising Diffusion Probabilistic Models](
+    https://arxiv.org/abs/2006.11239).
+    """
+
     def __init__(
         self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs
     ):
@@ -96,17 +289,23 @@ class TimestepEmbedding(layers.Layer):
         self.embedding_dim = int(embedding_dim)
         self.frequency_dim = int(frequency_dim)
         self.max_period = float(max_period)
-
-
-        self.
-
-
-
-
-                layers.Dense(
-                    embedding_dim, activation=None, dtype=self.dtype_policy
+        # Precomputed `freq`.
+        half_frequency_dim = frequency_dim // 2
+        self.freq = ops.exp(
+            ops.divide(
+                ops.multiply(
+                    -math.log(max_period),
+                    ops.arange(0, half_frequency_dim, dtype="float32"),
                 ),
-
+                half_frequency_dim,
+            )
+        )
+
+        self.mlp = MLP(
+            embedding_dim,
+            embedding_dim,
+            "silu",
+            dtype=self.dtype_policy,
             name="mlp",
         )
 
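With this change `TimestepEmbedding` precomputes the sinusoidal frequency vector once in `__init__` (as `self.freq`) instead of rebuilding it on every call, and its two-layer projection is expressed with the new `MLP` layer. Below is a standalone sketch of the DDPM-style sinusoidal embedding that `self.freq` encodes; it is illustrative and not taken from the diff.

    # freq_i = exp(-log(max_period) * i / half), for i = 0 .. half - 1
    import math
    from keras import ops

    frequency_dim, max_period = 256, 10000.0
    half = frequency_dim // 2
    freq = ops.exp(
        -math.log(max_period) * ops.arange(0, half, dtype="float32") / half
    )

    timesteps = ops.convert_to_tensor([0.0, 10.0, 500.0])  # (batch,)
    args = ops.expand_dims(timesteps, -1) * ops.expand_dims(freq, 0)
    embedding = ops.concatenate([ops.cos(args), ops.sin(args)], axis=-1)
    print(embedding.shape)  # (3, 256)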
@@ -118,16 +317,7 @@ class TimestepEmbedding(layers.Layer):
     def _create_timestep_embedding(self, inputs):
         compute_dtype = keras.backend.result_type(self.compute_dtype, "float32")
         x = ops.cast(inputs, compute_dtype)
-        freqs = ops.
-            ops.divide(
-                ops.multiply(
-                    -math.log(self.max_period),
-                    ops.arange(0, self.half_frequency_dim, dtype="float32"),
-                ),
-                self.half_frequency_dim,
-            )
-        )
-        freqs = ops.cast(freqs, compute_dtype)
+        freqs = ops.cast(self.freq, compute_dtype)
         x = ops.multiply(x, ops.expand_dims(freqs, axis=0))
         embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1)
         if self.frequency_dim % 2 != 0:
@@ -143,6 +333,7 @@ class TimestepEmbedding(layers.Layer):
         config.update(
             {
                 "embedding_dim": self.embedding_dim,
+                "frequency_dim": self.frequency_dim,
                 "max_period": self.max_period,
             }
         )
@@ -155,6 +346,18 @@ class TimestepEmbedding(layers.Layer):
 
 
 class DismantledBlock(layers.Layer):
+    """A dismantled block used to compute pre- and post-attention.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. The number of units in the hidden layers.
+        mlp_ratio: float. The expansion ratio of `MLP`.
+        use_projection: bool. Whether to use an attention projection layer at
+            the end of the block.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
     def __init__(
         self,
         num_heads,
@@ -173,25 +376,18 @@ class DismantledBlock(layers.Layer):
         self.head_dim = head_dim
         mlp_hidden_dim = int(hidden_dim * mlp_ratio)
         self.mlp_hidden_dim = mlp_hidden_dim
-
-
-
-
-
-
-
-
-
-
-
-
-        self.norm1 = layers.LayerNormalization(
-            epsilon=1e-6,
-            center=False,
-            scale=False,
-            dtype="float32",
-            name="norm1",
-        )
+
+        if use_projection:
+            self.ada_layer_norm = AdaptiveLayerNormalization(
+                hidden_dim,
+                residual_modulation=True,
+                dtype=self.dtype_policy,
+                name="ada_layer_norm",
+            )
+        else:
+            self.ada_layer_norm = AdaptiveLayerNormalization(
+                hidden_dim, dtype=self.dtype_policy, name="ada_layer_norm"
+            )
         self.attention_qkv = layers.Dense(
             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
         )
@@ -206,73 +402,45 @@ class DismantledBlock(layers.Layer):
             dtype="float32",
             name="norm2",
         )
-        self.mlp =
-
-
-
-
-                    dtype=self.dtype_policy,
-                ),
-                layers.Dense(
-                    hidden_dim,
-                    dtype=self.dtype_policy,
-                ),
-            ],
+        self.mlp = MLP(
+            mlp_hidden_dim,
+            hidden_dim,
+            gelu_approximate,
+            dtype=self.dtype_policy,
             name="mlp",
         )
 
     def build(self, inputs_shape, timestep_embedding):
-        self.
+        self.ada_layer_norm.build(inputs_shape, timestep_embedding)
         self.attention_qkv.build(inputs_shape)
-        self.norm1.build(inputs_shape)
         if self.use_projection:
             self.attention_proj.build(inputs_shape)
         self.norm2.build(inputs_shape)
         self.mlp.build(inputs_shape)
 
     def _modulate(self, inputs, shift, scale):
-
-
+        inputs = ops.cast(inputs, self.compute_dtype)
+        shift = ops.cast(shift, self.compute_dtype)
+        scale = ops.cast(scale, self.compute_dtype)
         return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
 
     def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
         batch_size = ops.shape(inputs)[0]
         if self.use_projection:
-
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 6, self.hidden_dim)
-            )
-            (
-                shift_msa,
-                scale_msa,
-                gate_msa,
-                shift_mlp,
-                scale_mlp,
-                gate_mlp,
-            ) = ops.unstack(modulation, 6, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
+            x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.ada_layer_norm(
+                inputs, timestep_embedding, training=training
             )
+            qkv = self.attention_qkv(x, training=training)
             qkv = ops.reshape(
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
             return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         else:
-
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 2, self.hidden_dim)
-            )
-            shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
+            x = self.ada_layer_norm(
+                inputs, timestep_embedding, training=training
             )
+            qkv = self.attention_qkv(x, training=training)
             qkv = ops.reshape(
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
@@ -283,12 +451,16 @@ class DismantledBlock(layers.Layer):
         self, inputs, inputs_intermediates, training=None
     ):
         x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
+        gate_msa = ops.expand_dims(gate_msa, axis=1)
+        shift_mlp = ops.expand_dims(shift_mlp, axis=1)
+        scale_mlp = ops.expand_dims(scale_mlp, axis=1)
+        gate_mlp = ops.expand_dims(gate_mlp, axis=1)
         attn = self.attention_proj(inputs, training=training)
-        x = ops.add(x, ops.multiply(
+        x = ops.add(x, ops.multiply(gate_msa, attn))
         x = ops.add(
             x,
             ops.multiply(
-
+                gate_mlp,
                 self.mlp(
                     self._modulate(self.norm2(x), shift_mlp, scale_mlp),
                     training=training,
@@ -328,6 +500,27 @@ class DismantledBlock(layers.Layer):
 
 
 class MMDiTBlock(layers.Layer):
+    """A MMDiT block consisting of two `DismantledBlock` layers.
+
+    One `DismantledBlock` processes the input latents, and the other processes
+    the context embedding. This block integrates two modalities within the
+    attention operation, allowing each representation to operate in its own
+    space while considering the other.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. The number of units in the hidden layers.
+        mlp_ratio: float. The expansion ratio of `MLP`.
+        use_context_projection: bool. Whether to use an attention projection
+            layer at the end of the context block.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    Reference:
+    - [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
+    https://arxiv.org/abs/2403.03206)
+    """
+
     def __init__(
         self,
         num_heads,
@@ -345,8 +538,6 @@ class MMDiTBlock(layers.Layer):
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
-        self._dot_product_equation = "aecd,abcd->acbe"
-        self._combine_equation = "acbe,aecd->abcd"
 
         self.x_block = DismantledBlock(
             num_heads=num_heads,
@@ -371,20 +562,18 @@ class MMDiTBlock(layers.Layer):
         self.context_block.build(context_shape, timestep_embedding_shape)
 
     def _compute_attention(self, query, key, value):
-
-
-        )
-
-
-
-
-
-
-
-
-
-        )
-        return attention_output
+        # Ref: jax.nn.dot_product_attention
+        # https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
+        batch_size = ops.shape(query)[0]
+        logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
+        logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
+        probs = self.softmax(logits)
+        probs = ops.cast(probs, self.compute_dtype)
+        encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value)
+        encoded = ops.reshape(
+            encoded, (batch_size, -1, self.num_heads * self.head_dim)
+        )
+        return encoded
 
     def call(self, inputs, context, timestep_embedding, training=None):
         # Compute pre-attention.
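The rewritten `_compute_attention` above drops the stored einsum equations in favor of an explicit scaled dot-product attention written with named axes (`B` batch, `T` query tokens, `S` key tokens, `N` heads, `H` head size), mirroring `jax.nn.dot_product_attention`. A standalone sketch of the same pattern follows; it is illustrative and uses `keras.ops.softmax` in place of the layer's `self.softmax`.

    import math
    from keras import ops

    batch, tokens, num_heads, head_dim = 2, 5, 4, 8
    query = ops.ones((batch, tokens, num_heads, head_dim))
    key = ops.ones((batch, tokens, num_heads, head_dim))
    value = ops.ones((batch, tokens, num_heads, head_dim))

    # logits[b, n, t, s] is the scaled similarity of query t and key s.
    logits = ops.einsum("BTNH,BSNH->BNTS", query, key) / math.sqrt(head_dim)
    probs = ops.softmax(logits, axis=-1)
    encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value)
    encoded = ops.reshape(encoded, (batch, -1, num_heads * head_dim))
    print(encoded.shape)  # (2, 5, 32)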
@@ -453,74 +642,16 @@ class MMDiTBlock(layers.Layer):
         return inputs_shape
 
 
-class
-
-        super().__init__(**kwargs)
-        self.hidden_dim = hidden_dim
-        self.output_dim = output_dim
-        num_modulation = 2
-
-        self.adaptive_norm_modulation = models.Sequential(
-            [
-                layers.Activation("silu", dtype=self.dtype_policy),
-                layers.Dense(
-                    num_modulation * hidden_dim, dtype=self.dtype_policy
-                ),
-            ],
-            name="adaptive_norm_modulation",
-        )
-        self.norm = layers.LayerNormalization(
-            epsilon=1e-6,
-            center=False,
-            scale=False,
-            dtype="float32",
-            name="norm",
-        )
-        self.output_dense = layers.Dense(
-            output_dim,
-            use_bias=True,
-            dtype=self.dtype_policy,
-            name="output_dense",
-        )
-
-    def build(self, inputs_shape, timestep_embedding_shape):
-        self.adaptive_norm_modulation.build(timestep_embedding_shape)
-        self.norm.build(inputs_shape)
-        self.output_dense.build(inputs_shape)
-
-    def _modulate(self, inputs, shift, scale):
-        shift = ops.expand_dims(shift, axis=1)
-        scale = ops.expand_dims(scale, axis=1)
-        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
-
-    def call(self, inputs, timestep_embedding, training=None):
-        x = inputs
-        modulation = self.adaptive_norm_modulation(
-            timestep_embedding, training=training
-        )
-        modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim))
-        shift, scale = ops.unstack(modulation, 2, axis=1)
-        x = self._modulate(self.norm(x), shift, scale)
-        x = self.output_dense(x, training=training)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "hidden_dim": self.hidden_dim,
-                "output_dim": self.output_dim,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, inputs_shape):
-        outputs_shape = list(inputs_shape)
-        outputs_shape[-1] = self.output_dim
-        return outputs_shape
+class Unpatch(layers.Layer):
+    """A layer that reconstructs the image from hidden patches.
 
+    Args:
+        patch_size: int. The size of each square patch in the input image.
+        output_dim: int. The number of units in the output layer.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
 
-class Unpatch(layers.Layer):
     def __init__(self, patch_size, output_dim, **kwargs):
         super().__init__(**kwargs)
         self.patch_size = int(patch_size)
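This hunk removes the old modulation-based output layer (its role is now covered by `AdaptiveLayerNormalization` plus a plain `Dense`, as the hunks below show) and documents `Unpatch`, which folds the sequence of patch vectors back into an image. Below is a standalone sketch of that patch-to-image reshape; it is illustrative, and the exact axis layout of the real layer may differ.

    from keras import ops

    patch_size, output_dim = 2, 3
    batch, height, width = 1, 4, 4                    # target image size
    h, w = height // patch_size, width // patch_size  # patch grid

    x = ops.ones((batch, h * w, patch_size * patch_size * output_dim))
    x = ops.reshape(x, (batch, h, w, patch_size, patch_size, output_dim))
    x = ops.transpose(x, (0, 1, 3, 2, 4, 5))  # interleave patch rows and columns
    image = ops.reshape(x, (batch, height, width, output_dim))
    print(image.shape)  # (1, 4, 4, 3)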
@@ -556,7 +687,7 @@ class Unpatch(layers.Layer):
 
 
 class MMDiT(Backbone):
-    """Multimodal Diffusion Transformer (MMDiT) model
+    """A Multimodal Diffusion Transformer (MMDiT) model.
 
     MMDiT is introduced in [
     Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
@@ -636,12 +767,8 @@ class MMDiT(Backbone):
             dtype=dtype,
             name="context_embedding",
         )
-        self.vector_embedding =
-
-                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
-                layers.Dense(hidden_dim, activation=None, dtype=dtype),
-            ],
-            name="vector_embedding",
+        self.vector_embedding = MLP(
+            hidden_dim, hidden_dim, "silu", dtype=dtype, name="vector_embedding"
         )
         self.vector_embedding_add = layers.Add(
             dtype=dtype, name="vector_embedding_add"
@@ -660,8 +787,11 @@ class MMDiT(Backbone):
             )
             for i in range(num_layers)
         ]
-        self.
-            hidden_dim,
+        self.output_ada_layer_norm = AdaptiveLayerNormalization(
+            hidden_dim, dtype=dtype, name="output_ada_layer_norm"
+        )
+        self.output_dense = layers.Dense(
+            output_dim_in_final, dtype=dtype, name="output_dense"
        )
         self.unpatch = Unpatch(
             patch_size, output_dim, dtype=dtype, name="unpatch"
@@ -696,7 +826,8 @@ class MMDiT(Backbone):
             x = block(x, context, timestep_embedding)
 
         # Output layer.
-        x = self.
+        x = self.output_ada_layer_norm(x, timestep_embedding)
+        x = self.output_dense(x)
         outputs = self.unpatch(x, height=image_height, width=image_width)
 
         super().__init__(