keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. keras_hub/api/layers/__init__.py +3 -0
  2. keras_hub/api/models/__init__.py +16 -0
  3. keras_hub/api/tokenizers/__init__.py +1 -0
  4. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
  5. keras_hub/src/models/clip/clip_preprocessor.py +147 -0
  6. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
  7. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
  8. keras_hub/src/models/densenet/__init__.py +6 -0
  9. keras_hub/src/models/densenet/densenet_backbone.py +11 -8
  10. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
  11. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  12. keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
  13. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  14. keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
  15. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
  16. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
  17. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
  18. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
  19. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
  20. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
  21. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
  22. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
  23. keras_hub/src/models/text_to_image.py +295 -0
  24. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  25. keras_hub/src/utils/timm/preset_loader.py +3 -0
  26. keras_hub/src/version_utils.py +1 -1
  27. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
  28. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +31 -23
  29. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  30. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  31. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  32. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  33. /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
  34. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
  35. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
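Note on the rename: the `stable_diffusion_v3` package becomes `stable_diffusion_3`, and the former `mmdit_block.py` module is folded into `mmdit.py` (see the hunks below). As a quick orientation, a minimal import sketch; the module paths follow directly from the file list and the hunks below, and no additional re-exports are assumed:

    # Old nightly (0.16.1.dev202409250340):
    # from keras_hub.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock

    # New nightly (0.16.1.dev202409260340): MMDiTBlock now lives next to MMDiT
    # in the renamed package's mmdit.py module.
    from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT, MMDiTBlock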
@@ -19,7 +19,8 @@ from keras import models
  from keras import ops

  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
- from keras_hub.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.utils.keras_utils import gelu_approximate
  from keras_hub.src.utils.keras_utils import standardize_data_format


@@ -79,8 +80,8 @@ class AdjustablePositionEmbedding(PositionEmbedding):
          width = width or self.width
          shape = ops.shape(inputs)
          feature_length = shape[-1]
-         top = ops.floor_divide(self.height - height, 2)
-         left = ops.floor_divide(self.width - width, 2)
+         top = ops.cast(ops.floor_divide(self.height - height, 2), "int32")
+         left = ops.cast(ops.floor_divide(self.width - width, 2), "int32")
          position_embedding = ops.convert_to_tensor(self.position_embeddings)
          position_embedding = ops.reshape(
              position_embedding, (self.height, self.width, feature_length)
@@ -166,6 +167,305 @@ class TimestepEmbedding(layers.Layer):
          return output_shape


+ class DismantledBlock(layers.Layer):
+     def __init__(
+         self,
+         num_heads,
+         hidden_dim,
+         mlp_ratio=4.0,
+         use_projection=True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.mlp_ratio = mlp_ratio
+         self.use_projection = use_projection
+
+         head_dim = hidden_dim // num_heads
+         self.head_dim = head_dim
+         mlp_hidden_dim = int(hidden_dim * mlp_ratio)
+         self.mlp_hidden_dim = mlp_hidden_dim
+         num_modulations = 6 if use_projection else 2
+         self.num_modulations = num_modulations
+
+         self.adaptive_norm_modulation = models.Sequential(
+             [
+                 layers.Activation("silu", dtype=self.dtype_policy),
+                 layers.Dense(
+                     num_modulations * hidden_dim, dtype=self.dtype_policy
+                 ),
+             ],
+             name="adaptive_norm_modulation",
+         )
+         self.norm1 = layers.LayerNormalization(
+             epsilon=1e-6,
+             center=False,
+             scale=False,
+             dtype="float32",
+             name="norm1",
+         )
+         self.attention_qkv = layers.Dense(
+             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
+         )
+         if use_projection:
+             self.attention_proj = layers.Dense(
+                 hidden_dim, dtype=self.dtype_policy, name="attention_proj"
+             )
+             self.norm2 = layers.LayerNormalization(
+                 epsilon=1e-6,
+                 center=False,
+                 scale=False,
+                 dtype="float32",
+                 name="norm2",
+             )
+             self.mlp = models.Sequential(
+                 [
+                     layers.Dense(
+                         mlp_hidden_dim,
+                         activation=gelu_approximate,
+                         dtype=self.dtype_policy,
+                     ),
+                     layers.Dense(
+                         hidden_dim,
+                         dtype=self.dtype_policy,
+                     ),
+                 ],
+                 name="mlp",
+             )
+
+     def build(self, inputs_shape, timestep_embedding):
+         self.adaptive_norm_modulation.build(timestep_embedding)
+         self.attention_qkv.build(inputs_shape)
+         self.norm1.build(inputs_shape)
+         if self.use_projection:
+             self.attention_proj.build(inputs_shape)
+             self.norm2.build(inputs_shape)
+             self.mlp.build(inputs_shape)
+
+     def _modulate(self, inputs, shift, scale):
+         shift = ops.expand_dims(shift, axis=1)
+         scale = ops.expand_dims(scale, axis=1)
+         return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
+
+     def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
+         batch_size = ops.shape(inputs)[0]
+         if self.use_projection:
+             modulation = self.adaptive_norm_modulation(
+                 timestep_embedding, training=training
+             )
+             modulation = ops.reshape(
+                 modulation, (batch_size, 6, self.hidden_dim)
+             )
+             (
+                 shift_msa,
+                 scale_msa,
+                 gate_msa,
+                 shift_mlp,
+                 scale_mlp,
+                 gate_mlp,
+             ) = ops.unstack(modulation, 6, axis=1)
+             qkv = self.attention_qkv(
+                 self._modulate(self.norm1(inputs), shift_msa, scale_msa),
+                 training=training,
+             )
+             qkv = ops.reshape(
+                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
+             )
+             q, k, v = ops.unstack(qkv, 3, axis=2)
+             return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
+         else:
+             modulation = self.adaptive_norm_modulation(
+                 timestep_embedding, training=training
+             )
+             modulation = ops.reshape(
+                 modulation, (batch_size, 2, self.hidden_dim)
+             )
+             shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1)
+             qkv = self.attention_qkv(
+                 self._modulate(self.norm1(inputs), shift_msa, scale_msa),
+                 training=training,
+             )
+             qkv = ops.reshape(
+                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
+             )
+             q, k, v = ops.unstack(qkv, 3, axis=2)
+             return (q, k, v)
+
+     def _compute_post_attention(
+         self, inputs, inputs_intermediates, training=None
+     ):
+         x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
+         attn = self.attention_proj(inputs, training=training)
+         x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
+         x = ops.add(
+             x,
+             ops.multiply(
+                 ops.expand_dims(gate_mlp, axis=1),
+                 self.mlp(
+                     self._modulate(self.norm2(x), shift_mlp, scale_mlp),
+                     training=training,
+                 ),
+             ),
+         )
+         return x
+
+     def call(
+         self,
+         inputs,
+         timestep_embedding=None,
+         inputs_intermediates=None,
+         pre_attention=True,
+         training=None,
+     ):
+         if pre_attention:
+             return self._compute_pre_attention(
+                 inputs, timestep_embedding, training=training
+             )
+         else:
+             return self._compute_post_attention(
+                 inputs, inputs_intermediates, training=training
+             )
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "mlp_ratio": self.mlp_ratio,
+                 "use_projection": self.use_projection,
+             }
+         )
+         return config
+
+
+ class MMDiTBlock(layers.Layer):
+     def __init__(
+         self,
+         num_heads,
+         hidden_dim,
+         mlp_ratio=4.0,
+         use_context_projection=True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.mlp_ratio = mlp_ratio
+         self.use_context_projection = use_context_projection
+
+         head_dim = hidden_dim // num_heads
+         self.head_dim = head_dim
+         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
+         self._dot_product_equation = "aecd,abcd->acbe"
+         self._combine_equation = "acbe,aecd->abcd"
+
+         self.x_block = DismantledBlock(
+             num_heads=num_heads,
+             hidden_dim=hidden_dim,
+             mlp_ratio=mlp_ratio,
+             use_projection=True,
+             dtype=self.dtype_policy,
+             name="x_block",
+         )
+         self.context_block = DismantledBlock(
+             num_heads=num_heads,
+             hidden_dim=hidden_dim,
+             mlp_ratio=mlp_ratio,
+             use_projection=use_context_projection,
+             dtype=self.dtype_policy,
+             name="context_block",
+         )
+         self.softmax = layers.Softmax(dtype="float32")
+
+     def build(self, inputs_shape, context_shape, timestep_embedding_shape):
+         self.x_block.build(inputs_shape, timestep_embedding_shape)
+         self.context_block.build(context_shape, timestep_embedding_shape)
+
+     def _compute_attention(self, query, key, value):
+         query = ops.multiply(
+             query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
+         )
+         attention_scores = ops.einsum(self._dot_product_equation, key, query)
+         attention_scores = self.softmax(attention_scores)
+         attention_scores = ops.cast(attention_scores, self.compute_dtype)
+         attention_output = ops.einsum(
+             self._combine_equation, attention_scores, value
+         )
+         batch_size = ops.shape(attention_output)[0]
+         attention_output = ops.reshape(
+             attention_output, (batch_size, -1, self.num_heads * self.head_dim)
+         )
+         return attention_output
+
+     def call(self, inputs, context, timestep_embedding, training=None):
+         # Compute pre-attention.
+         x = inputs
+         if self.use_context_projection:
+             context_qkv, context_intermediates = self.context_block(
+                 context,
+                 timestep_embedding=timestep_embedding,
+                 training=training,
+             )
+         else:
+             context_qkv = self.context_block(
+                 context,
+                 timestep_embedding=timestep_embedding,
+                 training=training,
+             )
+         context_len = ops.shape(context_qkv[0])[1]
+         x_qkv, x_intermediates = self.x_block(
+             x, timestep_embedding=timestep_embedding, training=training
+         )
+         q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
+         k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
+         v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
+
+         # Compute attention.
+         attention = self._compute_attention(q, k, v)
+         context_attention = attention[:, :context_len]
+         x_attention = attention[:, context_len:]
+
+         # Compute post-attention.
+         x = self.x_block(
+             x_attention,
+             inputs_intermediates=x_intermediates,
+             pre_attention=False,
+             training=training,
+         )
+         if self.use_context_projection:
+             context = self.context_block(
+                 context_attention,
+                 inputs_intermediates=context_intermediates,
+                 pre_attention=False,
+                 training=training,
+             )
+             return x, context
+         else:
+             return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "mlp_ratio": self.mlp_ratio,
+                 "use_context_projection": self.use_context_projection,
+             }
+         )
+         return config
+
+     def compute_output_shape(
+         self, inputs_shape, context_shape, timestep_embedding_shape
+     ):
+         if self.use_context_projection:
+             return inputs_shape, context_shape
+         else:
+             return inputs_shape
+
+
  class OutputLayer(layers.Layer):
      def __init__(self, hidden_dim, output_dim, **kwargs):
          super().__init__(**kwargs)
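The hunk above adds the joint attention machinery: each stream (latent `x` and text `context`) gets its own `DismantledBlock` for the modulated pre-attention projections, the per-stream queries/keys/values are concatenated, one attention pass is computed, and the result is split back so each stream applies its own gated MLP. Below is a minimal sketch of exercising `MMDiTBlock` in isolation; the class and its `build`/`call` signatures come from the hunk above, while the shapes and values are illustrative assumptions:

    import keras

    from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiTBlock

    # Toy sizes: 16 latent tokens, 8 context tokens, hidden_dim=64 over 4 heads,
    # and one 64-wide timestep/pooled embedding per sample.
    block = MMDiTBlock(num_heads=4, hidden_dim=64, use_context_projection=True)
    block.build(
        inputs_shape=(2, 16, 64),
        context_shape=(2, 8, 64),
        timestep_embedding_shape=(2, 64),
    )

    x = keras.random.normal((2, 16, 64))       # patched latent tokens
    context = keras.random.normal((2, 8, 64))  # text/context tokens
    t_emb = keras.random.normal((2, 64))       # timestep + pooled projection

    # With use_context_projection=True both streams are returned; the final
    # block in the stack (use_context_projection=False) returns only x.
    x_out, context_out = block(x, context, t_emb)
    print(x_out.shape, context_out.shape)  # (2, 16, 64) (2, 8, 64)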
@@ -186,11 +486,11 @@ class OutputLayer(layers.Layer):
              epsilon=1e-6,
              center=False,
              scale=False,
-             dtype=self.dtype_policy,
+             dtype="float32",
              name="norm",
          )
          self.output_dense = layers.Dense(
-             output_dim,  # patch_size ** 2 * input_channels
+             output_dim,
              use_bias=True,
              dtype=self.dtype_policy,
              name="output_dense",
@@ -227,6 +527,11 @@ class OutputLayer(layers.Layer):
          )
          return config

+     def compute_output_shape(self, inputs_shape):
+         outputs_shape = list(inputs_shape)
+         outputs_shape[-1] = self.output_dim
+         return outputs_shape
+

  class Unpatch(layers.Layer):
      def __init__(self, patch_size, output_dim, **kwargs):
@@ -263,18 +568,48 @@ class Unpatch(layers.Layer):
          return [inputs_shape[0], None, None, self.output_dim]


- class MMDiT(keras.Model):
+ class MMDiT(Backbone):
+     """Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3.
+
+     MMDiT is introduced in [
+     Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
+     https://arxiv.org/abs/2403.03206).
+
+     Args:
+         patch_size: int. The size of each square patch in the input image.
+         hidden_dim: int. The size of the transformer hidden state at the end
+             of each transformer layer.
+         num_layers: int. The number of transformer layers.
+         num_heads: int. The number of attention heads for each transformer.
+         position_size: int. The size of the height and width for the position
+             embedding.
+         mlp_ratio: float. The ratio of the mlp hidden dim to the transformer
+         latent_shape: tuple. The shape of the latent image.
+         context_shape: tuple. The shape of the context.
+         pooled_projection_shape: tuple. The shape of the pooled projection.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+             to use for the model's computations and weights.
+     """
+
      def __init__(
          self,
          patch_size,
-         num_heads,
          hidden_dim,
-         depth,
+         num_layers,
+         num_heads,
          position_size,
-         output_dim,
          mlp_ratio=4.0,
          latent_shape=(64, 64, 16),
-         context_shape=(1024, 4096),
+         context_shape=(None, 4096),
          pooled_projection_shape=(2048,),
          data_format=None,
          dtype=None,
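With this hunk `MMDiT` becomes a `Backbone` and its constructor changes: `num_layers` replaces `depth`, `num_heads` moves after `hidden_dim`, the explicit `output_dim` argument is dropped (it is now derived from `latent_shape[-1]`, see the next hunk), and `context_shape` defaults to a variable-length sequence. A minimal construction sketch with deliberately tiny, illustrative sizes; real Stable Diffusion 3 checkpoints are much larger, and whether such a toy config builds end to end depends on parts of `__init__` not shown here:

    from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT

    # A 64x64x16 latent patched with patch_size=2 gives a 32x32 token grid, so
    # position_size should be at least 32 here.
    mmdit = MMDiT(
        patch_size=2,
        hidden_dim=64,
        num_layers=2,
        num_heads=4,
        position_size=32,
        mlp_ratio=4.0,
        latent_shape=(64, 64, 16),
        context_shape=(None, 4096),
        pooled_projection_shape=(2048,),
        dtype="float32",
    )
    mmdit.summary()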
@@ -287,6 +622,7 @@ class MMDiT(keras.Model):
          )
          image_height = latent_shape[0] // patch_size
          image_width = latent_shape[1] // patch_size
+         output_dim = latent_shape[-1]
          output_dim_in_final = patch_size**2 * output_dim
          data_format = standardize_data_format(data_format)
          if data_format != "channels_last":
@@ -331,11 +667,11 @@ class MMDiT(keras.Model):
                  num_heads,
                  hidden_dim,
                  mlp_ratio,
-                 use_context_projection=not (i == depth - 1),
+                 use_context_projection=not (i == num_layers - 1),
                  dtype=dtype,
                  name=f"joint_block_{i}",
              )
-             for i in range(depth)
+             for i in range(num_layers)
          ]
          self.output_layer = OutputLayer(
              hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
@@ -391,33 +727,22 @@ class MMDiT(keras.Model):
          self.patch_size = patch_size
          self.num_heads = num_heads
          self.hidden_dim = hidden_dim
-         self.depth = depth
+         self.num_layers = num_layers
          self.position_size = position_size
-         self.output_dim = output_dim
          self.mlp_ratio = mlp_ratio
          self.latent_shape = latent_shape
          self.context_shape = context_shape
          self.pooled_projection_shape = pooled_projection_shape

-         if dtype is not None:
-             try:
-                 self.dtype_policy = keras.dtype_policies.get(dtype)
-             # Before Keras 3.2, there is no `keras.dtype_policies.get`.
-             except AttributeError:
-                 if isinstance(dtype, keras.DTypePolicy):
-                     dtype = dtype.name
-                 self.dtype_policy = keras.DTypePolicy(dtype)
-
      def get_config(self):
          config = super().get_config()
          config.update(
              {
                  "patch_size": self.patch_size,
-                 "num_heads": self.num_heads,
                  "hidden_dim": self.hidden_dim,
-                 "depth": self.depth,
+                 "num_layers": self.num_layers,
+                 "num_heads": self.num_heads,
                  "position_size": self.position_size,
-                 "output_dim": self.output_dim,
                  "mlp_ratio": self.mlp_ratio,
                  "latent_shape": self.latent_shape,
                  "context_shape": self.context_shape,