keras-hub-nightly 0.16.1.dev202410080341__py3-none-any.whl → 0.16.1.dev202410100339__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +11 -0
  3. keras_hub/src/layers/preprocessing/image_converter.py +2 -1
  4. keras_hub/src/models/image_to_image.py +411 -0
  5. keras_hub/src/models/inpaint.py +513 -0
  6. keras_hub/src/models/mix_transformer/__init__.py +12 -0
  7. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +4 -0
  8. keras_hub/src/models/mix_transformer/mix_transformer_classifier_preprocessor.py +16 -0
  9. keras_hub/src/models/mix_transformer/mix_transformer_image_converter.py +8 -0
  10. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +9 -5
  11. keras_hub/src/models/mix_transformer/mix_transformer_presets.py +151 -0
  12. keras_hub/src/models/preprocessor.py +4 -4
  13. keras_hub/src/models/stable_diffusion_3/mmdit.py +308 -177
  14. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +87 -55
  15. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +171 -0
  16. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +194 -0
  17. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
  18. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +13 -8
  19. keras_hub/src/models/task.py +1 -1
  20. keras_hub/src/models/text_to_image.py +89 -36
  21. keras_hub/src/tests/test_case.py +3 -1
  22. keras_hub/src/tokenizers/tokenizer.py +7 -7
  23. keras_hub/src/utils/preset_utils.py +7 -7
  24. keras_hub/src/utils/timm/preset_loader.py +1 -3
  25. keras_hub/src/version_utils.py +1 -1
  26. {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/METADATA +1 -1
  27. {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/RECORD +29 -22
  28. {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/WHEEL +0 -0
  29. {keras_hub_nightly-0.16.1.dev202410080341.dist-info → keras_hub_nightly-0.16.1.dev202410100339.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,6 @@ import math
 
 import keras
 from keras import layers
-from keras import models
 from keras import ops
 
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
@@ -11,7 +10,167 @@ from keras_hub.src.utils.keras_utils import gelu_approximate
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
+class AdaptiveLayerNormalization(layers.Layer):
+    """Adaptive layer normalization.
+
+    Args:
+        hidden_dim: int. The size of each embedding vector.
+        residual_modulation: bool. Whether to output the modulation parameters
+            of the residual connection within the block of the diffusion
+            transformers. Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    References:
+    - [FiLM: Visual Reasoning with a General Conditioning Layer](
+    https://arxiv.org/abs/1709.07871).
+    - [Scalable Diffusion Models with Transformers](
+    https://arxiv.org/abs/2212.09748).
+    """
+
+    def __init__(self, hidden_dim, residual_modulation=False, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.residual_modulation = bool(residual_modulation)
+        num_modulations = 6 if self.residual_modulation else 2
+
+        self.silu = layers.Activation("silu", dtype=self.dtype_policy)
+        self.dense = layers.Dense(
+            num_modulations * hidden_dim, dtype=self.dtype_policy, name="dense"
+        )
+        self.norm = layers.LayerNormalization(
+            epsilon=1e-6,
+            center=False,
+            scale=False,
+            dtype="float32",
+            name="norm",
+        )
+
+    def build(self, inputs_shape, embeddings_shape):
+        self.silu.build(embeddings_shape)
+        self.dense.build(embeddings_shape)
+        self.norm.build(inputs_shape)
+
+    def call(self, inputs, embeddings, training=None):
+        x = inputs
+        emb = self.dense(self.silu(embeddings), training=training)
+        if self.residual_modulation:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                ops.split(emb, 6, axis=1)
+            )
+        else:
+            shift_msa, scale_msa = ops.split(emb, 2, axis=1)
+        scale_msa = ops.expand_dims(scale_msa, axis=1)
+        shift_msa = ops.expand_dims(shift_msa, axis=1)
+        x = ops.add(
+            ops.multiply(
+                self.norm(x, training=training),
+                ops.add(1.0, scale_msa),
+            ),
+            shift_msa,
+        )
+        if self.residual_modulation:
+            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+        else:
+            return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "residual_modulation": self.residual_modulation,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape, embeddings_shape):
+        if self.residual_modulation:
+            return (
+                inputs_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+            )
+        else:
+            return inputs_shape
+
+
+class MLP(layers.Layer):
+    """An MLP block with two dense layers.
+
+    Args:
+        hidden_dim: int. The number of units in the hidden layer.
+        output_dim: int. The number of units in the output layer.
+        activation: str or callable. Activation to use in the hidden layer.
+            Defaults to `None`.
+    """
+
+    def __init__(self, hidden_dim, output_dim, activation=None, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.output_dim = int(output_dim)
+        self.activation = keras.activations.get(activation)
+
+        self.dense1 = layers.Dense(
+            hidden_dim,
+            activation=self.activation,
+            dtype=self.dtype_policy,
+            name="dense1",
+        )
+        self.dense2 = layers.Dense(
+            output_dim,
+            activation=None,
+            dtype=self.dtype_policy,
+            name="dense2",
+        )
+
+    def build(self, inputs_shape):
+        self.dense1.build(inputs_shape)
+        inputs_shape = self.dense1.compute_output_shape(inputs_shape)
+        self.dense2.build(inputs_shape)
+
+    def call(self, inputs, training=None):
+        x = self.dense1(inputs, training=training)
+        return self.dense2(x, training=training)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "output_dim": self.output_dim,
+                "activation": keras.activations.serialize(self.activation),
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        outputs_shape = list(inputs_shape)
+        outputs_shape[-1] = self.output_dim
+        return outputs_shape
+
+
 class PatchEmbedding(layers.Layer):
+    """A layer that converts images into patches.
+
+    Args:
+        patch_size: int. The size of one side of each patch.
+        hidden_dim: int. The number of units in the hidden layers.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
     def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs):
         super().__init__(**kwargs)
         self.patch_size = int(patch_size)
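Note: the following standalone sketch is not part of the diff. It shows, with illustrative shapes and layer names and assuming Keras 3 with any backend, the FiLM-style modulation that the new `AdaptiveLayerNormalization` layer factors out: normalize, then scale and shift with parameters predicted from a conditioning embedding.

    # Standalone sketch: adaLN modulation (shapes and names are illustrative).
    import keras
    from keras import layers, ops

    batch, seq_len, hidden_dim = 2, 4, 8
    x = keras.random.normal((batch, seq_len, hidden_dim))  # token features
    emb = keras.random.normal((batch, hidden_dim))  # conditioning embedding

    dense = layers.Dense(2 * hidden_dim)  # predicts shift and scale
    norm = layers.LayerNormalization(epsilon=1e-6, center=False, scale=False)

    shift, scale = ops.split(dense(ops.silu(emb)), 2, axis=1)  # (batch, hidden_dim) each
    shift = ops.expand_dims(shift, axis=1)  # broadcast over the sequence axis
    scale = ops.expand_dims(scale, axis=1)
    out = norm(x) * (1.0 + scale) + shift  # FiLM-style modulation
    print(out.shape)  # (2, 4, 8)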
@@ -48,6 +207,15 @@ class PatchEmbedding(layers.Layer):
 
 
 class AdjustablePositionEmbedding(PositionEmbedding):
+    """A position embedding layer with adjustable height and width.
+
+    The embedding will be cropped to match the input dimensions.
+
+    Args:
+        height: int. The maximum height of the embedding.
+        width: int. The maximum width of the embedding.
+    """
+
     def __init__(
         self,
         height,
@@ -84,11 +252,36 @@ class AdjustablePositionEmbedding(PositionEmbedding):
         position_embedding = ops.expand_dims(position_embedding, axis=0)
         return position_embedding
 
+    def get_config(self):
+        config = super().get_config()
+        del config["sequence_length"]
+        config.update(
+            {
+                "height": self.height,
+                "width": self.width,
+            }
+        )
+        return config
+
     def compute_output_shape(self, input_shape):
         return input_shape
 
 
 class TimestepEmbedding(layers.Layer):
+    """A layer which learns embedding for input timesteps.
+
+    Args:
+        embedding_dim: int. The size of the embedding.
+        frequency_dim: int. The size of the frequency embedding.
+        max_period: int. Controls the maximum frequency of the embeddings.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    Reference:
+    - [Denoising Diffusion Probabilistic Models](
+    https://arxiv.org/abs/2006.11239).
+    """
+
     def __init__(
         self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs
     ):
@@ -96,17 +289,23 @@ class TimestepEmbedding(layers.Layer):
         self.embedding_dim = int(embedding_dim)
         self.frequency_dim = int(frequency_dim)
         self.max_period = float(max_period)
-        self.half_frequency_dim = self.frequency_dim // 2
-
-        self.mlp = models.Sequential(
-            [
-                layers.Dense(
-                    embedding_dim, activation="silu", dtype=self.dtype_policy
-                ),
-                layers.Dense(
-                    embedding_dim, activation=None, dtype=self.dtype_policy
+        # Precomputed `freq`.
+        half_frequency_dim = frequency_dim // 2
+        self.freq = ops.exp(
+            ops.divide(
+                ops.multiply(
+                    -math.log(max_period),
+                    ops.arange(0, half_frequency_dim, dtype="float32"),
                 ),
-            ],
+                half_frequency_dim,
+            )
+        )
+
+        self.mlp = MLP(
+            embedding_dim,
+            embedding_dim,
+            "silu",
+            dtype=self.dtype_policy,
             name="mlp",
         )
 
@@ -118,16 +317,7 @@ class TimestepEmbedding(layers.Layer):
     def _create_timestep_embedding(self, inputs):
         compute_dtype = keras.backend.result_type(self.compute_dtype, "float32")
         x = ops.cast(inputs, compute_dtype)
-        freqs = ops.exp(
-            ops.divide(
-                ops.multiply(
-                    -math.log(self.max_period),
-                    ops.arange(0, self.half_frequency_dim, dtype="float32"),
-                ),
-                self.half_frequency_dim,
-            )
-        )
-        freqs = ops.cast(freqs, compute_dtype)
+        freqs = ops.cast(self.freq, compute_dtype)
         x = ops.multiply(x, ops.expand_dims(freqs, axis=0))
         embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1)
         if self.frequency_dim % 2 != 0:
@@ -143,6 +333,7 @@ class TimestepEmbedding(layers.Layer):
         config.update(
             {
                 "embedding_dim": self.embedding_dim,
+                "frequency_dim": self.frequency_dim,
                 "max_period": self.max_period,
             }
         )
@@ -155,6 +346,18 @@ class TimestepEmbedding(layers.Layer):
 
 
 class DismantledBlock(layers.Layer):
+    """A dismantled block used to compute pre- and post-attention.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. The number of units in the hidden layers.
+        mlp_ratio: float. The expansion ratio of `MLP`.
+        use_projection: bool. Whether to use an attention projection layer at
+            the end of the block.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
     def __init__(
         self,
         num_heads,
@@ -173,25 +376,18 @@ class DismantledBlock(layers.Layer):
         self.head_dim = head_dim
         mlp_hidden_dim = int(hidden_dim * mlp_ratio)
         self.mlp_hidden_dim = mlp_hidden_dim
-        num_modulations = 6 if use_projection else 2
-        self.num_modulations = num_modulations
-
-        self.adaptive_norm_modulation = models.Sequential(
-            [
-                layers.Activation("silu", dtype=self.dtype_policy),
-                layers.Dense(
-                    num_modulations * hidden_dim, dtype=self.dtype_policy
-                ),
-            ],
-            name="adaptive_norm_modulation",
-        )
-        self.norm1 = layers.LayerNormalization(
-            epsilon=1e-6,
-            center=False,
-            scale=False,
-            dtype="float32",
-            name="norm1",
-        )
+
+        if use_projection:
+            self.ada_layer_norm = AdaptiveLayerNormalization(
+                hidden_dim,
+                residual_modulation=True,
+                dtype=self.dtype_policy,
+                name="ada_layer_norm",
+            )
+        else:
+            self.ada_layer_norm = AdaptiveLayerNormalization(
+                hidden_dim, dtype=self.dtype_policy, name="ada_layer_norm"
+            )
         self.attention_qkv = layers.Dense(
             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
         )
@@ -206,73 +402,45 @@ class DismantledBlock(layers.Layer):
             dtype="float32",
             name="norm2",
         )
-        self.mlp = models.Sequential(
-            [
-                layers.Dense(
-                    mlp_hidden_dim,
-                    activation=gelu_approximate,
-                    dtype=self.dtype_policy,
-                ),
-                layers.Dense(
-                    hidden_dim,
-                    dtype=self.dtype_policy,
-                ),
-            ],
+        self.mlp = MLP(
+            mlp_hidden_dim,
+            hidden_dim,
+            gelu_approximate,
+            dtype=self.dtype_policy,
             name="mlp",
         )
 
     def build(self, inputs_shape, timestep_embedding):
-        self.adaptive_norm_modulation.build(timestep_embedding)
+        self.ada_layer_norm.build(inputs_shape, timestep_embedding)
         self.attention_qkv.build(inputs_shape)
-        self.norm1.build(inputs_shape)
         if self.use_projection:
             self.attention_proj.build(inputs_shape)
         self.norm2.build(inputs_shape)
         self.mlp.build(inputs_shape)
 
     def _modulate(self, inputs, shift, scale):
-        shift = ops.expand_dims(shift, axis=1)
-        scale = ops.expand_dims(scale, axis=1)
+        inputs = ops.cast(inputs, self.compute_dtype)
+        shift = ops.cast(shift, self.compute_dtype)
+        scale = ops.cast(scale, self.compute_dtype)
         return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
 
     def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
         batch_size = ops.shape(inputs)[0]
         if self.use_projection:
-            modulation = self.adaptive_norm_modulation(
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 6, self.hidden_dim)
-            )
-            (
-                shift_msa,
-                scale_msa,
-                gate_msa,
-                shift_mlp,
-                scale_mlp,
-                gate_mlp,
-            ) = ops.unstack(modulation, 6, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
+            x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.ada_layer_norm(
+                inputs, timestep_embedding, training=training
             )
+            qkv = self.attention_qkv(x, training=training)
             qkv = ops.reshape(
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
             return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         else:
-            modulation = self.adaptive_norm_modulation(
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 2, self.hidden_dim)
-            )
-            shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
+            x = self.ada_layer_norm(
+                inputs, timestep_embedding, training=training
             )
+            qkv = self.attention_qkv(x, training=training)
             qkv = ops.reshape(
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
@@ -283,12 +451,16 @@ class DismantledBlock(layers.Layer):
         self, inputs, inputs_intermediates, training=None
     ):
         x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
+        gate_msa = ops.expand_dims(gate_msa, axis=1)
+        shift_mlp = ops.expand_dims(shift_mlp, axis=1)
+        scale_mlp = ops.expand_dims(scale_mlp, axis=1)
+        gate_mlp = ops.expand_dims(gate_mlp, axis=1)
         attn = self.attention_proj(inputs, training=training)
-        x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
+        x = ops.add(x, ops.multiply(gate_msa, attn))
         x = ops.add(
             x,
             ops.multiply(
-                ops.expand_dims(gate_mlp, axis=1),
+                gate_mlp,
                 self.mlp(
                     self._modulate(self.norm2(x), shift_mlp, scale_mlp),
                     training=training,
@@ -328,6 +500,27 @@ class DismantledBlock(layers.Layer):
 
 
 class MMDiTBlock(layers.Layer):
+    """An MMDiT block consisting of two `DismantledBlock` layers.
+
+    One `DismantledBlock` processes the input latents, and the other processes
+    the context embedding. This block integrates two modalities within the
+    attention operation, allowing each representation to operate in its own
+    space while considering the other.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. The number of units in the hidden layers.
+        mlp_ratio: float. The expansion ratio of `MLP`.
+        use_context_projection: bool. Whether to use an attention projection
+            layer at the end of the context block.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    Reference:
+    - [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
+    https://arxiv.org/abs/2403.03206)
+    """
+
     def __init__(
         self,
         num_heads,
@@ -345,8 +538,6 @@ class MMDiTBlock(layers.Layer):
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
-        self._dot_product_equation = "aecd,abcd->acbe"
-        self._combine_equation = "acbe,aecd->abcd"
 
         self.x_block = DismantledBlock(
             num_heads=num_heads,
@@ -371,20 +562,18 @@ class MMDiTBlock(layers.Layer):
         self.context_block.build(context_shape, timestep_embedding_shape)
 
     def _compute_attention(self, query, key, value):
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
-        )
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
-        attention_scores = self.softmax(attention_scores)
-        attention_scores = ops.cast(attention_scores, self.compute_dtype)
-        attention_output = ops.einsum(
-            self._combine_equation, attention_scores, value
-        )
-        batch_size = ops.shape(attention_output)[0]
-        attention_output = ops.reshape(
-            attention_output, (batch_size, -1, self.num_heads * self.head_dim)
-        )
-        return attention_output
+        # Ref: jax.nn.dot_product_attention
+        # https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
+        batch_size = ops.shape(query)[0]
+        logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
+        logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
+        probs = self.softmax(logits)
+        probs = ops.cast(probs, self.compute_dtype)
+        encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value)
+        encoded = ops.reshape(
+            encoded, (batch_size, -1, self.num_heads * self.head_dim)
+        )
+        return encoded
 
     def call(self, inputs, context, timestep_embedding, training=None):
         # Compute pre-attention.
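Note: the following standalone sketch is not part of the diff. It illustrates, with made-up shapes and assuming Keras 3, the einsum-based scaled dot-product attention that replaces the previous `aecd,abcd->acbe` formulation in `MMDiTBlock._compute_attention`.

    # Standalone sketch: per-head scaled dot-product attention via einsum.
    import keras
    from keras import ops

    batch, seq, num_heads, head_dim = 2, 5, 4, 16
    q = keras.random.normal((batch, seq, num_heads, head_dim))
    k = keras.random.normal((batch, seq, num_heads, head_dim))
    v = keras.random.normal((batch, seq, num_heads, head_dim))

    logits = ops.einsum("BTNH,BSNH->BNTS", q, k) / (head_dim**0.5)  # per-head scores
    probs = ops.softmax(logits, axis=-1)  # attention weights over the key axis
    encoded = ops.einsum("BNTS,BSNH->BTNH", probs, v)  # weighted sum of values
    encoded = ops.reshape(encoded, (batch, -1, num_heads * head_dim))
    print(encoded.shape)  # (2, 5, 64)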
@@ -453,74 +642,16 @@ class MMDiTBlock(layers.Layer):
         return inputs_shape
 
 
-class OutputLayer(layers.Layer):
-    def __init__(self, hidden_dim, output_dim, **kwargs):
-        super().__init__(**kwargs)
-        self.hidden_dim = hidden_dim
-        self.output_dim = output_dim
-        num_modulation = 2
-
-        self.adaptive_norm_modulation = models.Sequential(
-            [
-                layers.Activation("silu", dtype=self.dtype_policy),
-                layers.Dense(
-                    num_modulation * hidden_dim, dtype=self.dtype_policy
-                ),
-            ],
-            name="adaptive_norm_modulation",
-        )
-        self.norm = layers.LayerNormalization(
-            epsilon=1e-6,
-            center=False,
-            scale=False,
-            dtype="float32",
-            name="norm",
-        )
-        self.output_dense = layers.Dense(
-            output_dim,
-            use_bias=True,
-            dtype=self.dtype_policy,
-            name="output_dense",
-        )
-
-    def build(self, inputs_shape, timestep_embedding_shape):
-        self.adaptive_norm_modulation.build(timestep_embedding_shape)
-        self.norm.build(inputs_shape)
-        self.output_dense.build(inputs_shape)
-
-    def _modulate(self, inputs, shift, scale):
-        shift = ops.expand_dims(shift, axis=1)
-        scale = ops.expand_dims(scale, axis=1)
-        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
-
-    def call(self, inputs, timestep_embedding, training=None):
-        x = inputs
-        modulation = self.adaptive_norm_modulation(
-            timestep_embedding, training=training
-        )
-        modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim))
-        shift, scale = ops.unstack(modulation, 2, axis=1)
-        x = self._modulate(self.norm(x), shift, scale)
-        x = self.output_dense(x, training=training)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "hidden_dim": self.hidden_dim,
-                "output_dim": self.output_dim,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, inputs_shape):
-        outputs_shape = list(inputs_shape)
-        outputs_shape[-1] = self.output_dim
-        return outputs_shape
+class Unpatch(layers.Layer):
+    """A layer that reconstructs the image from hidden patches.
 
+    Args:
+        patch_size: int. The size of each square patch in the input image.
+        output_dim: int. The number of units in the output layer.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
 
-class Unpatch(layers.Layer):
     def __init__(self, patch_size, output_dim, **kwargs):
         super().__init__(**kwargs)
         self.patch_size = int(patch_size)
@@ -556,7 +687,7 @@ class Unpatch(layers.Layer):
 
 
 class MMDiT(Backbone):
-    """Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3.
+    """A Multimodal Diffusion Transformer (MMDiT) model.
 
     MMDiT is introduced in [
     Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
@@ -636,12 +767,8 @@ class MMDiT(Backbone):
             dtype=dtype,
             name="context_embedding",
         )
-        self.vector_embedding = models.Sequential(
-            [
-                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
-                layers.Dense(hidden_dim, activation=None, dtype=dtype),
-            ],
-            name="vector_embedding",
+        self.vector_embedding = MLP(
+            hidden_dim, hidden_dim, "silu", dtype=dtype, name="vector_embedding"
         )
         self.vector_embedding_add = layers.Add(
             dtype=dtype, name="vector_embedding_add"
@@ -660,8 +787,11 @@ class MMDiT(Backbone):
             )
             for i in range(num_layers)
         ]
-        self.output_layer = OutputLayer(
-            hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
+        self.output_ada_layer_norm = AdaptiveLayerNormalization(
+            hidden_dim, dtype=dtype, name="output_ada_layer_norm"
+        )
+        self.output_dense = layers.Dense(
+            output_dim_in_final, dtype=dtype, name="output_dense"
        )
         self.unpatch = Unpatch(
             patch_size, output_dim, dtype=dtype, name="unpatch"
@@ -696,7 +826,8 @@ class MMDiT(Backbone):
             x = block(x, context, timestep_embedding)
 
         # Output layer.
-        x = self.output_layer(x, timestep_embedding)
+        x = self.output_ada_layer_norm(x, timestep_embedding)
+        x = self.output_dense(x)
         outputs = self.unpatch(x, height=image_height, width=image_width)
 
         super().__init__(
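Note: the following standalone sketch is not part of the diff. It shows, with illustrative timestep values and assuming Keras 3, the sinusoidal frequency table that `TimestepEmbedding` now precomputes once in `__init__` (as `self.freq`) instead of rebuilding on every call.

    # Standalone sketch: precomputed sinusoidal frequencies and the resulting
    # timestep embedding (cosine half concatenated with sine half).
    import math
    from keras import ops

    frequency_dim, max_period = 256, 10000.0
    half = frequency_dim // 2
    freqs = ops.exp(
        -math.log(max_period) * ops.arange(0, half, dtype="float32") / half
    )

    timesteps = ops.convert_to_tensor([[0.0], [250.0], [999.0]])  # illustrative
    angles = ops.multiply(timesteps, ops.expand_dims(freqs, axis=0))
    embedding = ops.concatenate([ops.cos(angles), ops.sin(angles)], axis=-1)
    print(ops.shape(embedding))  # (3, 256)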