keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/mmdit.py

@@ -354,6 +354,8 @@ class DismantledBlock(layers.Layer):
         mlp_ratio: float. The expansion ratio of `MLP`.
         use_projection: bool. Whether to use an attention projection layer at
             the end of the block.
+        qk_norm: Optional str. Whether to normalize the query and key tensors.
+            Available options are `None` and `"rms_norm"`. Defaults to `None`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.
     """
@@ -364,6 +366,7 @@ class DismantledBlock(layers.Layer):
         hidden_dim,
         mlp_ratio=4.0,
         use_projection=True,
+        qk_norm=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -371,6 +374,7 @@ class DismantledBlock(layers.Layer):
         self.hidden_dim = hidden_dim
         self.mlp_ratio = mlp_ratio
         self.use_projection = use_projection
+        self.qk_norm = qk_norm

         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
@@ -391,6 +395,18 @@ class DismantledBlock(layers.Layer):
         self.attention_qkv = layers.Dense(
             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
         )
+        if qk_norm is not None and qk_norm == "rms_norm":
+            self.q_norm = layers.LayerNormalization(
+                epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
+            )
+            self.k_norm = layers.LayerNormalization(
+                epsilon=1e-6, rms_scaling=True, dtype="float32", name="k_norm"
+            )
+        elif qk_norm is not None:
+            raise NotImplementedError(
+                "Supported `qk_norm` are `'rms_norm'` and `None`. "
+                f"Received: qk_norm={qk_norm}."
+            )
         if use_projection:
             self.attention_proj = layers.Dense(
                 hidden_dim, dtype=self.dtype_policy, name="attention_proj"
@@ -413,6 +429,10 @@ class DismantledBlock(layers.Layer):
     def build(self, inputs_shape, timestep_embedding):
         self.ada_layer_norm.build(inputs_shape, timestep_embedding)
         self.attention_qkv.build(inputs_shape)
+        if self.qk_norm is not None:
+            # [batch_size, sequence_length, num_heads, head_dim]
+            self.q_norm.build([None, None, self.num_heads, self.head_dim])
+            self.k_norm.build([None, None, self.num_heads, self.head_dim])
         if self.use_projection:
             self.attention_proj.build(inputs_shape)
             self.norm2.build(inputs_shape)
@@ -435,6 +455,9 @@ class DismantledBlock(layers.Layer):
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
+            if self.qk_norm is not None:
+                q = self.q_norm(q, training=training)
+                k = self.k_norm(k, training=training)
             return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         else:
             x = self.ada_layer_norm(
@@ -445,6 +468,9 @@ class DismantledBlock(layers.Layer):
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
+            if self.qk_norm is not None:
+                q = self.q_norm(q, training=training)
+                k = self.k_norm(k, training=training)
             return (q, k, v)

     def _compute_post_attention(
@@ -494,6 +520,7 @@ class DismantledBlock(layers.Layer):
                 "hidden_dim": self.hidden_dim,
                 "mlp_ratio": self.mlp_ratio,
                 "use_projection": self.use_projection,
+                "qk_norm": self.qk_norm,
             }
         )
         return config
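For orientation, the `qk_norm="rms_norm"` option added above applies an RMS-style normalization to the query and key tensors right after the QKV projection and before the attention product, which helps keep attention logits stable in larger models. Below is a minimal standalone sketch of that pattern, not the `DismantledBlock` code itself; shapes follow the `[batch, seq, num_heads, head_dim]` layout used in `build` above, and all names are illustrative:

```python
import keras
from keras import layers, ops

num_heads, head_dim = 8, 64
hidden_dim = num_heads * head_dim

qkv_proj = layers.Dense(hidden_dim * 3)
# `rms_scaling=True` makes LayerNormalization divide by the root mean square
# without mean-centering, matching the "rms_norm" option in the diff.
q_norm = layers.LayerNormalization(epsilon=1e-6, rms_scaling=True)
k_norm = layers.LayerNormalization(epsilon=1e-6, rms_scaling=True)

x = keras.random.normal((2, 16, hidden_dim))  # [batch, seq, hidden]
qkv = ops.reshape(qkv_proj(x), (2, -1, 3, num_heads, head_dim))
q, k, v = ops.unstack(qkv, 3, axis=2)
q = q_norm(q)  # normalized over the trailing head_dim axis
k = k_norm(k)
```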
@@ -513,6 +540,8 @@ class MMDiTBlock(layers.Layer):
         mlp_ratio: float. The expansion ratio of `MLP`.
         use_context_projection: bool. Whether to use an attention projection
             layer at the end of the context block.
+        qk_norm: Optional str. Whether to normalize the query and key tensors.
+            Available options are `None` and `"rms_norm"`. Defaults to `None`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.

@@ -527,6 +556,7 @@ class MMDiTBlock(layers.Layer):
         hidden_dim,
         mlp_ratio=4.0,
         use_context_projection=True,
+        qk_norm=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -534,6 +564,7 @@ class MMDiTBlock(layers.Layer):
         self.hidden_dim = hidden_dim
         self.mlp_ratio = mlp_ratio
         self.use_context_projection = use_context_projection
+        self.qk_norm = qk_norm

         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
@@ -544,6 +575,7 @@ class MMDiTBlock(layers.Layer):
             hidden_dim=hidden_dim,
             mlp_ratio=mlp_ratio,
             use_projection=True,
+            qk_norm=qk_norm,
             dtype=self.dtype_policy,
             name="x_block",
         )
@@ -552,6 +584,7 @@ class MMDiTBlock(layers.Layer):
             hidden_dim=hidden_dim,
             mlp_ratio=mlp_ratio,
             use_projection=use_context_projection,
+            qk_norm=qk_norm,
             dtype=self.dtype_policy,
             name="context_block",
         )
@@ -629,6 +662,7 @@ class MMDiTBlock(layers.Layer):
                 "hidden_dim": self.hidden_dim,
                 "mlp_ratio": self.mlp_ratio,
                 "use_context_projection": self.use_context_projection,
+                "qk_norm": self.qk_norm,
             }
         )
         return config
@@ -705,6 +739,9 @@ class MMDiT(Backbone):
         latent_shape: tuple. The shape of the latent image.
         context_shape: tuple. The shape of the context.
         pooled_projection_shape: tuple. The shape of the pooled projection.
+        qk_norm: Optional str. Whether to normalize the query and key tensors in
+            the intermediate blocks. Available options are `None` and
+            `"rms_norm"`. Defaults to `None`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -729,6 +766,7 @@ class MMDiT(Backbone):
         latent_shape=(64, 64, 16),
         context_shape=(None, 4096),
         pooled_projection_shape=(2048,),
+        qk_norm=None,
         data_format=None,
         dtype=None,
         **kwargs,
@@ -782,6 +820,7 @@ class MMDiT(Backbone):
                 hidden_dim,
                 mlp_ratio,
                 use_context_projection=not (i == num_layers - 1),
+                qk_norm=qk_norm,
                 dtype=dtype,
                 name=f"joint_block_{i}",
             )
@@ -851,6 +890,7 @@ class MMDiT(Backbone):
         self.latent_shape = latent_shape
         self.context_shape = context_shape
         self.pooled_projection_shape = pooled_projection_shape
+        self.qk_norm = qk_norm

     def get_config(self):
         config = super().get_config()
@@ -865,6 +905,7 @@ class MMDiT(Backbone):
                 "latent_shape": self.latent_shape,
                 "context_shape": self.context_shape,
                 "pooled_projection_shape": self.pooled_projection_shape,
+                "qk_norm": self.qk_norm,
             }
         )
         return config
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py

@@ -202,6 +202,10 @@ class StableDiffusion3Backbone(Backbone):
             transformer in MMDiT.
         mmdit_position_size: int. The size of the height and width for the
             position embedding in MMDiT.
+        mmdit_qk_norm: Optional str. Whether to normalize the query and key
+            tensors for each transformer in MMDiT. Available options are `None`
+            and `"rms_norm"`. Typically, this is set to `None` for the 3.0
+            version and to `"rms_norm"` for the 3.5 version.
         vae: The VAE used for transformations between pixel space and latent
             space.
         clip_l: The CLIP text encoder for encoding the inputs.
@@ -215,8 +219,8 @@ class StableDiffusion3Backbone(Backbone):
             model. Defaults to `1000`.
         shift: float. The shift value for the timestep schedule. Defaults to
             `3.0`.
-        height: optional int. The output height of the image.
-        width: optional int. The output width of the image.
+        image_shape: tuple. The input shape without the batch size. Defaults to
+            `(1024, 1024, 3)`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -248,6 +252,7 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_hidden_dim=256,
         mmdit_depth=4,
         mmdit_position_size=192,
+        mmdit_qk_norm=None,
         vae=vae,
         clip_l=clip_l,
         clip_g=clip_g,
@@ -262,6 +267,7 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_num_layers,
         mmdit_num_heads,
         mmdit_position_size,
+        mmdit_qk_norm,
         vae,
         clip_l,
         clip_g,
@@ -270,23 +276,21 @@ class StableDiffusion3Backbone(Backbone):
         output_channels=3,
         num_train_timesteps=1000,
         shift=3.0,
-        height=None,
-        width=None,
+        image_shape=(1024, 1024, 3),
         data_format=None,
         dtype=None,
         **kwargs,
     ):
-        height = int(height or 1024)
-        width = int(width or 1024)
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(
-                "`height` and `width` must be divisible by 8. "
-                f"Received: height={height}, width={width}"
-            )
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError
-        image_shape = (height, width, int(vae.input_channels))
+        height = image_shape[0]
+        width = image_shape[1]
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(
+                "height and width in `image_shape` must be divisible by 8. "
+                f"Received: image_shape={image_shape}"
+            )
         latent_shape = (height // 8, width // 8, int(latent_channels))
         context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
         pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
@@ -314,6 +318,7 @@ class StableDiffusion3Backbone(Backbone):
             latent_shape=latent_shape,
             context_shape=context_shape,
             pooled_projection_shape=pooled_projection_shape,
+            qk_norm=mmdit_qk_norm,
             data_format=data_format,
             dtype=dtype,
             name="diffuser",
@@ -448,12 +453,12 @@ class StableDiffusion3Backbone(Backbone):
         self.mmdit_num_layers = mmdit_num_layers
         self.mmdit_num_heads = mmdit_num_heads
         self.mmdit_position_size = mmdit_position_size
+        self.mmdit_qk_norm = mmdit_qk_norm
         self.latent_channels = latent_channels
         self.output_channels = output_channels
         self.num_train_timesteps = num_train_timesteps
         self.shift = shift
-        self.height = height
-        self.width = width
+        self.image_shape = image_shape

     @property
     def latent_shape(self):
@@ -535,7 +540,7 @@ class StableDiffusion3Backbone(Backbone):
         embeddings,
         step,
         num_steps,
-        guidance_scale,
+        guidance_scale=None,
     ):
         step = ops.convert_to_tensor(step)
         next_step = ops.add(step, 1)
@@ -543,9 +548,15 @@ class StableDiffusion3Backbone(Backbone):
         next_sigma, _ = self.scheduler(next_step, num_steps)

         # Concatenation for classifier-free guidance.
-        concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
-            latents, *embeddings, timestep
-        )
+        if guidance_scale is not None:
+            concated_latents, contexts, pooled_projs, timesteps = (
+                self.cfg_concat(latents, *embeddings, timestep)
+            )
+        else:
+            timesteps = ops.broadcast_to(timestep, ops.shape(latents)[:1])
+            concated_latents = latents
+            contexts = embeddings[0]
+            pooled_projs = embeddings[2]

         # Diffusion.
         predicted_noise = self.diffuser(
@@ -559,7 +570,8 @@ class StableDiffusion3Backbone(Backbone):
         )

         # Classifier-free guidance.
-        predicted_noise = self.cfg(predicted_noise, guidance_scale)
+        if guidance_scale is not None:
+            predicted_noise = self.cfg(predicted_noise, guidance_scale)

         # Euler step.
         return self.euler_step(latents, predicted_noise, sigma, next_sigma)
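The net effect of the two guarded blocks above: when `guidance_scale` is `None`, the latent batch is no longer duplicated by `cfg_concat` and the `cfg` combination step is skipped, which is what timestep-distilled models (such as the 3.5 large turbo preset added later in this diff) expect. As a reminder of what `cfg` computes, here is a hedged sketch of the standard classifier-free guidance combination; the cond/uncond ordering within the doubled batch is an assumption, not taken from this diff:

```python
from keras import ops

def cfg_sketch(predicted_noise, guidance_scale):
    # Assumes the doubled batch stacks the [conditional, unconditional]
    # halves along axis 0 (illustrative; the real split order may differ).
    cond, uncond = ops.split(predicted_noise, 2, axis=0)
    # Extrapolate the conditional prediction away from the unconditional one.
    return uncond + guidance_scale * (cond - uncond)
```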
@@ -577,6 +589,7 @@ class StableDiffusion3Backbone(Backbone):
                 "mmdit_num_layers": self.mmdit_num_layers,
                 "mmdit_num_heads": self.mmdit_num_heads,
                 "mmdit_position_size": self.mmdit_position_size,
+                "mmdit_qk_norm": self.mmdit_qk_norm,
                 "vae": layers.serialize(self.vae),
                 "clip_l": layers.serialize(self.clip_l),
                 "clip_g": layers.serialize(self.clip_g),
@@ -585,8 +598,7 @@ class StableDiffusion3Backbone(Backbone):
                 "output_channels": self.output_channels,
                 "num_train_timesteps": self.num_train_timesteps,
                 "shift": self.shift,
-                "height": self.height,
-                "width": self.width,
+                "image_shape": self.image_shape,
             }
         )
         return config
@@ -624,4 +636,9 @@ class StableDiffusion3Backbone(Backbone):
             config["t5"] = layers.deserialize(
                 config["t5"], custom_objects=custom_objects
             )
+
+        # To maintain backward compatibility, we need to ensure that
+        # `mmdit_qk_norm` is included in the config.
+        if "mmdit_qk_norm" not in config:
+            config["mmdit_qk_norm"] = None
         return cls(**config)
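With the constructor change above, callers pass a single `image_shape` tuple instead of separate `height`/`width` arguments. A usage sketch consistent with the docstring examples updated elsewhere in this diff (the preset name is taken from those examples; kwarg overrides via `from_preset` are assumed to behave as the task docstrings show):

```python
import keras_hub

# Request 512x512 RGB generation instead of the (1024, 1024, 3) default.
# Height and width must each be divisible by 8, per the check above.
backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
```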
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py

@@ -27,7 +27,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
     Use `generate()` to do image generation.
     ```python
     image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     image_to_image.generate(
         {
@@ -158,14 +158,14 @@ class StableDiffusion3ImageToImage(ImageToImage):
         self,
         inputs,
         num_steps=50,
-        guidance_scale=7.0,
         strength=0.8,
+        guidance_scale=7.0,
         seed=None,
     ):
         return super().generate(
             inputs,
             num_steps=num_steps,
-            guidance_scale=guidance_scale,
             strength=strength,
+            guidance_scale=guidance_scale,
             seed=seed,
         )
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py

@@ -29,7 +29,7 @@ class StableDiffusion3Inpaint(Inpaint):
     reference_image = np.ones((1024, 1024, 3), dtype="float32")
     reference_mask = np.ones((1024, 1024), dtype="float32")
     inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     inpaint.generate(
         reference_image,
@@ -181,14 +181,14 @@ class StableDiffusion3Inpaint(Inpaint):
         self,
         inputs,
         num_steps=50,
-        guidance_scale=7.0,
         strength=0.6,
+        guidance_scale=7.0,
         seed=None,
     ):
         return super().generate(
             inputs,
             num_steps=num_steps,
-            guidance_scale=guidance_scale,
             strength=strength,
+            guidance_scale=guidance_scale,
             seed=seed,
         )
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py

@@ -9,10 +9,34 @@ backbone_presets = {
                "Developed by Stability AI."
            ),
            "params": 2987080931,
-           "official_name": "StableDiffusion3",
            "path": "stable_diffusion_3",
-           "model_card": "https://arxiv.org/abs/2110.00476",
        },
-       "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/1",
-   }
+       "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3",
+   },
+   "stable_diffusion_3.5_large": {
+       "metadata": {
+           "description": (
+               "9 billion parameter, including CLIP L and CLIP G text "
+               "encoders, MMDiT generative model, and VAE autoencoder. "
+               "Developed by Stability AI."
+           ),
+           "params": 9048410595,
+           "path": "stable_diffusion_3",
+       },
+       "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/1",
+   },
+   "stable_diffusion_3.5_large_turbo": {
+       "metadata": {
+           "description": (
+               "9 billion parameter, including CLIP L and CLIP G text "
+               "encoders, MMDiT generative model, and VAE autoencoder. "
+               "A timestep-distilled version that eliminates classifier-free "
+               "guidance and uses fewer steps for generation. "
+               "Developed by Stability AI."
+           ),
+           "params": 9048410595,
+           "path": "stable_diffusion_3",
+       },
+       "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/1",
+   },
 }
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py

@@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage):
     Use `generate()` to do image generation.
     ```python
     text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     text_to_image.generate(
         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
keras_hub/src/models/t5/t5_backbone.py

@@ -42,11 +42,12 @@ class T5Backbone(Backbone):
             projections in the multi-head attention layers. Defaults to
             hidden_dim / num_heads.
         dropout: float. Dropout probability for the Transformer layers.
-        activation: activation function (or activation string name). The
-            activation to be used in the inner dense blocks of the
-            Transformer layers. Defaults to `"relu"`.
+        activation: string. The activation function to use in the dense blocks
+            of the Transformer layers.
         use_gated_activation: boolean. Whether to use activation gating in
-            the inner dense blocks of the Transformer layers.
+            the inner dense blocks of the Transformer layers. When used with
+            the GELU activation function, this is referred to as GEGLU
+            (gated GLU) from https://arxiv.org/pdf/2002.05202.
             The original T5 architecture didn't use gating, but more
             recent versions do. Defaults to `True`.
         layer_norm_epsilon: float. Epsilon factor to be used in the
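To make the GEGLU reference concrete: a gated dense block computes two parallel projections of the input and multiplies the activated one by the linear one before projecting back down. A minimal illustrative sketch, not the actual T5 layer (class and attribute names are invented here):

```python
import keras
from keras import layers

class GatedDenseSketch(layers.Layer):
    """Illustrative GEGLU-style block: act(W_g x) * (W_l x), then project."""

    def __init__(self, intermediate_dim, hidden_dim, activation="gelu", **kwargs):
        super().__init__(**kwargs)
        self.gate = layers.Dense(intermediate_dim, use_bias=False)
        self.linear = layers.Dense(intermediate_dim, use_bias=False)
        self.output_proj = layers.Dense(hidden_dim, use_bias=False)
        self.activation = keras.activations.get(activation)

    def call(self, x):
        # Elementwise gating of the linear branch by the activated branch.
        return self.output_proj(self.activation(self.gate(x)) * self.linear(x))
```

With `use_gated_activation=False`, the block reduces to the single `act(W x)` feedforward of the original T5.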
keras_hub/src/models/t5/t5_presets.py

@@ -1,4 +1,4 @@
-"""XLM-RoBERTa model preset configurations."""
+"""T5 model preset configurations."""

 backbone_presets = {
     "t5_small_multi": {
@@ -8,12 +8,18 @@ backbone_presets = {
                "Corpus (C4)."
            ),
            "params": 0,
-           "official_name": "T5",
            "path": "t5",
-           "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
        },
        "kaggle_handle": "kaggle://keras/t5/keras/t5_small_multi/2",
    },
+   "t5_1.1_small": {
+       "metadata": {
+           "description": (""),
+           "params": 60511616,
+           "path": "t5",
+       },
+       "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_small/1",
+   },
    "t5_base_multi": {
        "metadata": {
            "description": (
@@ -21,12 +27,18 @@ backbone_presets = {
                "Corpus (C4)."
            ),
            "params": 0,
-           "official_name": "T5",
            "path": "t5",
-           "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
        },
        "kaggle_handle": "kaggle://keras/t5/keras/t5_base_multi/2",
    },
+   "t5_1.1_base": {
+       "metadata": {
+           "description": (""),
+           "params": 247577856,
+           "path": "t5",
+       },
+       "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_base/1",
+   },
    "t5_large_multi": {
        "metadata": {
            "description": (
@@ -34,12 +46,34 @@ backbone_presets = {
                "Corpus (C4)."
            ),
            "params": 0,
-           "official_name": "T5",
            "path": "t5",
-           "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
        },
        "kaggle_handle": "kaggle://keras/t5/keras/t5_large_multi/2",
    },
+   "t5_1.1_large": {
+       "metadata": {
+           "description": (""),
+           "params": 750251008,
+           "path": "t5",
+       },
+       "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_large/1",
+   },
+   "t5_1.1_xl": {
+       "metadata": {
+           "description": (""),
+           "params": 2849757184,
+           "path": "t5",
+       },
+       "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_xl/1",
+   },
+   "t5_1.1_xxl": {
+       "metadata": {
+           "description": (""),
+           "params": 11135332352,
+           "path": "t5",
+       },
+       "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_xxl/1",
+   },
    "flan_small_multi": {
        "metadata": {
            "description": (
@@ -47,9 +81,7 @@ backbone_presets = {
                "Corpus (C4)."
            ),
            "params": 0,
-           "official_name": "T5",
            "path": "t5",
-           "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
        },
        "kaggle_handle": "kaggle://keras/t5/keras/flan_small_multi/2",
    },
@@ -60,9 +92,7 @@ backbone_presets = {
                "Corpus (C4)."
            ),
            "params": 0,
-           "official_name": "T5",
            "path": "t5",
-           "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
        },
        "kaggle_handle": "kaggle://keras/t5/keras/flan_base_multi/2",
    },
@@ -73,9 +103,7 @@ backbone_presets = {
                "Corpus (C4)."
            ),
            "params": 0,
-           "official_name": "T5",
            "path": "t5",
-           "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
        },
        "kaggle_handle": "kaggle://keras/t5/keras/flan_large_multi/2",
    },
keras_hub/src/models/text_to_image.py

@@ -249,7 +249,7 @@ class TextToImage(Task):
         self,
         inputs,
         num_steps,
-        guidance_scale,
+        guidance_scale=None,
         seed=None,
     ):
         """Generate image based on the provided `inputs`.
@@ -283,15 +283,23 @@ class TextToImage(Task):
                 - A `tf.data.Dataset` with "prompts" and/or "negative_prompts"
                   keys
             num_steps: int. The number of diffusion steps to take.
-            guidance_scale: float. The classifier free guidance scale defined in
-                [Classifier-Free Diffusion Guidance](
+            guidance_scale: Optional float. The classifier free guidance scale
+                defined in [Classifier-Free Diffusion Guidance](
                 https://arxiv.org/abs/2207.12598). A higher scale encourages
                 generating images more closely related to the prompts, typically
-                at the cost of lower image quality.
+                at the cost of lower image quality. Note that some models don't
+                utilize classifier-free guidance.
             seed: optional int. Used as a random seed.
         """
+        num_steps = int(num_steps)
+        guidance_scale = (
+            float(guidance_scale) if guidance_scale is not None else None
+        )
         num_steps = ops.convert_to_tensor(num_steps, "int32")
-        guidance_scale = ops.convert_to_tensor(guidance_scale)
+        if guidance_scale is not None and guidance_scale > 1.0:
+            guidance_scale = ops.convert_to_tensor(guidance_scale)
+        else:
+            guidance_scale = None

         # Setup our three main passes.
         # 1. Preprocessing strings to dense integer tensors.
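Combined with the backbone change earlier in this diff, `guidance_scale` can now be omitted (or set at or below `1.0`) to disable classifier-free guidance entirely. A hedged usage sketch; the preset name matches the new 3.5 presets above, while the step count is purely illustrative:

```python
import keras_hub

text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3.5_large_turbo"
)
# Timestep-distilled model: few steps and no classifier-free guidance.
image = text_to_image.generate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    num_steps=4,
    guidance_scale=None,
)
```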
keras_hub/src/models/vgg/vgg_backbone.py

@@ -27,7 +27,7 @@ class VGGBackbone(Backbone):
     input_data = np.ones((2, 224, 224, 3), dtype="float32")

     # Pretrained VGG backbone.
-    model = keras_hub.models.VGGBackbone.from_preset("vgg16")
+    model = keras_hub.models.VGGBackbone.from_preset("vgg_16_imagenet")
     model(input_data)

     # Randomly initialized VGG backbone with a custom config.
keras_hub/src/models/vgg/vgg_presets.py

@@ -8,9 +8,7 @@ backbone_presets = {
                "at a 224x224 resolution."
            ),
            "params": 9220480,
-           "official_name": "vgg",
            "path": "vgg",
-           "model_card": "https://arxiv.org/abs/1409.1556",
        },
        "kaggle_handle": "kaggle://keras/vgg/keras/vgg_11_imagenet/1",
    },
@@ -21,9 +19,7 @@ backbone_presets = {
                "at a 224x224 resolution."
            ),
            "params": 9404992,
-           "official_name": "vgg",
            "path": "vgg",
-           "model_card": "https://arxiv.org/abs/1409.1556",
        },
        "kaggle_handle": "kaggle://keras/vgg/keras/vgg_13_imagenet/1",
    },
@@ -34,9 +30,7 @@ backbone_presets = {
                "at a 224x224 resolution."
            ),
            "params": 14714688,
-           "official_name": "vgg",
            "path": "vgg",
-           "model_card": "https://arxiv.org/abs/1409.1556",
        },
        "kaggle_handle": "kaggle://keras/vgg/keras/vgg_16_imagenet/1",
    },
@@ -47,9 +41,7 @@ backbone_presets = {
                "at a 224x224 resolution."
            ),
            "params": 20024384,
-           "official_name": "vgg",
            "path": "vgg",
-           "model_card": "https://arxiv.org/abs/1409.1556",
        },
        "kaggle_handle": "kaggle://keras/vgg/keras/vgg_19_imagenet/1",
    },
keras_hub/src/models/whisper/whisper_audio_converter.py

@@ -39,7 +39,7 @@ class WhisperAudioConverter(AudioConverter):
     audio_tensor = tf.ones((8000,), dtype="float32")

     # Compute the log-mel spectrogram.
-    audio_converter = keras_hub.models.WhisperAudioConverter.from_preset(
+    audio_converter = keras_hub.layers.WhisperAudioConverter.from_preset(
         "whisper_base_en",
     )
     audio_converter(audio_tensor)