keras-hub-nightly 0.19.0.dev202501080345__py3-none-any.whl → 0.19.0.dev202501090358__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,9 +15,8 @@ class AdaptiveLayerNormalization(layers.Layer):
15
15
 
16
16
  Args:
17
17
  hidden_dim: int. The size of each embedding vector.
18
- residual_modulation: bool. Whether to output the modulation parameters
19
- of the residual connection within the block of the diffusion
20
- transformers. Defaults to `False`.
18
+ num_modulations: int. The number of modulation parameters. The
19
+ available values are `2`, `6` and `9`. Defaults to `2`.
21
20
  **kwargs: other keyword arguments passed to `keras.layers.Layer`,
22
21
  including `name`, `dtype` etc.
23
22
 
@@ -28,11 +27,17 @@ class AdaptiveLayerNormalization(layers.Layer):
28
27
  https://arxiv.org/abs/2212.09748).
29
28
  """
30
29
 
31
- def __init__(self, hidden_dim, residual_modulation=False, **kwargs):
30
+ def __init__(self, hidden_dim, num_modulations=2, **kwargs):
32
31
  super().__init__(**kwargs)
33
- self.hidden_dim = int(hidden_dim)
34
- self.residual_modulation = bool(residual_modulation)
35
- num_modulations = 6 if self.residual_modulation else 2
32
+ hidden_dim = int(hidden_dim)
33
+ num_modulations = int(num_modulations)
34
+ if num_modulations not in (2, 6, 9):
35
+ raise ValueError(
36
+ "`num_modulations` must be `2`, `6` or `9`. "
37
+ f"Received: num_modulations={num_modulations}"
38
+ )
39
+ self.hidden_dim = hidden_dim
40
+ self.num_modulations = num_modulations
36
41
 
37
42
  self.silu = layers.Activation("silu", dtype=self.dtype_policy)
38
43
  self.dense = layers.Dense(
@@ -52,40 +57,84 @@ class AdaptiveLayerNormalization(layers.Layer):
52
57
  self.norm.build(inputs_shape)
53
58
 
54
59
  def call(self, inputs, embeddings, training=None):
55
- x = inputs
60
+ hidden_states = inputs
56
61
  emb = self.dense(self.silu(embeddings), training=training)
57
- if self.residual_modulation:
58
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
59
- ops.split(emb, 6, axis=1)
60
- )
62
+ if self.num_modulations == 9:
63
+ (
64
+ shift_msa,
65
+ scale_msa,
66
+ gate_msa,
67
+ shift_mlp,
68
+ scale_mlp,
69
+ gate_mlp,
70
+ shift_msa2,
71
+ scale_msa2,
72
+ gate_msa2,
73
+ ) = ops.split(emb, self.num_modulations, axis=1)
74
+ elif self.num_modulations == 6:
75
+ (
76
+ shift_msa,
77
+ scale_msa,
78
+ gate_msa,
79
+ shift_mlp,
80
+ scale_mlp,
81
+ gate_mlp,
82
+ ) = ops.split(emb, self.num_modulations, axis=1)
61
83
  else:
62
- shift_msa, scale_msa = ops.split(emb, 2, axis=1)
84
+ shift_msa, scale_msa = ops.split(emb, self.num_modulations, axis=1)
85
+
63
86
  scale_msa = ops.expand_dims(scale_msa, axis=1)
64
87
  shift_msa = ops.expand_dims(shift_msa, axis=1)
65
- x = ops.add(
66
- ops.multiply(
67
- self.norm(x, training=training),
68
- ops.add(1.0, scale_msa),
69
- ),
70
- shift_msa,
88
+ norm_hidden_states = ops.cast(
89
+ self.norm(hidden_states, training=training), scale_msa.dtype
90
+ )
91
+ hidden_states = ops.add(
92
+ ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa)), shift_msa
71
93
  )
72
- if self.residual_modulation:
73
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
94
+
95
+ if self.num_modulations == 9:
96
+ scale_msa2 = ops.expand_dims(scale_msa2, axis=1)
97
+ shift_msa2 = ops.expand_dims(shift_msa2, axis=1)
98
+ hidden_states2 = ops.add(
99
+ ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa2)),
100
+ shift_msa2,
101
+ )
102
+ return (
103
+ hidden_states,
104
+ gate_msa,
105
+ shift_mlp,
106
+ scale_mlp,
107
+ gate_mlp,
108
+ hidden_states2,
109
+ gate_msa2,
110
+ )
111
+ elif self.num_modulations == 6:
112
+ return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp
74
113
  else:
75
- return x
114
+ return hidden_states
76
115
 
77
116
  def get_config(self):
78
117
  config = super().get_config()
79
118
  config.update(
80
119
  {
81
120
  "hidden_dim": self.hidden_dim,
82
- "residual_modulation": self.residual_modulation,
121
+ "num_modulations": self.num_modulations,
83
122
  }
84
123
  )
85
124
  return config
86
125
 
87
126
  def compute_output_shape(self, inputs_shape, embeddings_shape):
88
- if self.residual_modulation:
127
+ if self.num_modulations == 9:
128
+ return (
129
+ inputs_shape,
130
+ embeddings_shape,
131
+ embeddings_shape,
132
+ embeddings_shape,
133
+ embeddings_shape,
134
+ inputs_shape,
135
+ embeddings_shape,
136
+ )
137
+ elif self.num_modulations == 6:
89
138
  return (
90
139
  inputs_shape,
91
140
  embeddings_shape,
@@ -345,6 +394,27 @@ class TimestepEmbedding(layers.Layer):
345
394
  return output_shape
346
395
 
347
396
 
397
+ def get_qk_norm(qk_norm=None, q_norm_name="q_norm", k_norm_name="k_norm"):
398
+ """Helper function to instantiate `LayerNormalization` layers."""
399
+ q_norm = None
400
+ k_norm = None
401
+ if qk_norm is None:
402
+ pass
403
+ elif qk_norm == "rms_norm":
404
+ q_norm = layers.LayerNormalization(
405
+ epsilon=1e-6, rms_scaling=True, dtype="float32", name=q_norm_name
406
+ )
407
+ k_norm = layers.LayerNormalization(
408
+ epsilon=1e-6, rms_scaling=True, dtype="float32", name=k_norm_name
409
+ )
410
+ else:
411
+ raise NotImplementedError(
412
+ "Supported `qk_norm` are `'rms_norm'` and `None`. "
413
+ f"Received: qk_norm={qk_norm}."
414
+ )
415
+ return q_norm, k_norm
416
+
417
+
348
418
  class DismantledBlock(layers.Layer):
349
419
  """A dismantled block used to compute pre- and post-attention.
350
420
 
@@ -356,6 +426,8 @@ class DismantledBlock(layers.Layer):
356
426
  the end of the block.
357
427
  qk_norm: Optional str. Whether to normalize the query and key tensors.
358
428
  Available options are `None` and `"rms_norm"`. Defaults to `None`.
429
+ use_dual_attention: bool. Whether to use a dual attention in the
430
+ block. Defaults to `False`.
359
431
  **kwargs: other keyword arguments passed to `keras.layers.Layer`,
360
432
  including `name`, `dtype` etc.
361
433
  """
@@ -367,6 +439,7 @@ class DismantledBlock(layers.Layer):
367
439
  mlp_ratio=4.0,
368
440
  use_projection=True,
369
441
  qk_norm=None,
442
+ use_dual_attention=False,
370
443
  **kwargs,
371
444
  ):
372
445
  super().__init__(**kwargs)
@@ -375,6 +448,7 @@ class DismantledBlock(layers.Layer):
375
448
  self.mlp_ratio = mlp_ratio
376
449
  self.use_projection = use_projection
377
450
  self.qk_norm = qk_norm
451
+ self.use_dual_attention = use_dual_attention
378
452
 
379
453
  head_dim = hidden_dim // num_heads
380
454
  self.head_dim = head_dim
@@ -384,7 +458,7 @@ class DismantledBlock(layers.Layer):
384
458
  if use_projection:
385
459
  self.ada_layer_norm = AdaptiveLayerNormalization(
386
460
  hidden_dim,
387
- residual_modulation=True,
461
+ num_modulations=9 if use_dual_attention else 6,
388
462
  dtype=self.dtype_policy,
389
463
  name="ada_layer_norm",
390
464
  )
@@ -395,18 +469,10 @@ class DismantledBlock(layers.Layer):
395
469
  self.attention_qkv = layers.Dense(
396
470
  hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
397
471
  )
398
- if qk_norm is not None and qk_norm == "rms_norm":
399
- self.q_norm = layers.LayerNormalization(
400
- epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
401
- )
402
- self.k_norm = layers.LayerNormalization(
403
- epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
404
- )
405
- elif qk_norm is not None:
406
- raise NotImplementedError(
407
- "Supported `qk_norm` are `'rms_norm'` and `None`. "
408
- f"Received: qk_norm={qk_norm}."
409
- )
472
+ q_norm, k_norm = get_qk_norm(qk_norm)
473
+ if q_norm is not None:
474
+ self.q_norm = q_norm
475
+ self.k_norm = k_norm
410
476
  if use_projection:
411
477
  self.attention_proj = layers.Dense(
412
478
  hidden_dim, dtype=self.dtype_policy, name="attention_proj"
@@ -426,6 +492,19 @@ class DismantledBlock(layers.Layer):
426
492
  name="mlp",
427
493
  )
428
494
 
495
+ if use_dual_attention:
496
+ self.attention_qkv2 = layers.Dense(
497
+ hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv2"
498
+ )
499
+ q_norm2, k_norm2 = get_qk_norm(qk_norm, "q_norm2", "k_norm2")
500
+ if q_norm is not None:
501
+ self.q_norm2 = q_norm2
502
+ self.k_norm2 = k_norm2
503
+ if use_projection:
504
+ self.attention_proj2 = layers.Dense(
505
+ hidden_dim, dtype=self.dtype_policy, name="attention_proj2"
506
+ )
507
+
429
508
  def build(self, inputs_shape, timestep_embedding):
430
509
  self.ada_layer_norm.build(inputs_shape, timestep_embedding)
431
510
  self.attention_qkv.build(inputs_shape)
@@ -437,6 +516,13 @@ class DismantledBlock(layers.Layer):
437
516
  self.attention_proj.build(inputs_shape)
438
517
  self.norm2.build(inputs_shape)
439
518
  self.mlp.build(inputs_shape)
519
+ if self.use_dual_attention:
520
+ self.attention_qkv2.build(inputs_shape)
521
+ if self.qk_norm is not None:
522
+ self.q_norm2.build([None, None, self.num_heads, self.head_dim])
523
+ self.k_norm2.build([None, None, self.num_heads, self.head_dim])
524
+ if self.use_projection:
525
+ self.attention_proj2.build(inputs_shape)
440
526
 
441
527
  def _modulate(self, inputs, shift, scale):
442
528
  inputs = ops.cast(inputs, self.compute_dtype)
@@ -456,8 +542,12 @@ class DismantledBlock(layers.Layer):
456
542
  )
457
543
  q, k, v = ops.unstack(qkv, 3, axis=2)
458
544
  if self.qk_norm is not None:
459
- q = self.q_norm(q, training=training)
460
- k = self.k_norm(k, training=training)
545
+ q = ops.cast(
546
+ self.q_norm(q, training=training), self.compute_dtype
547
+ )
548
+ k = ops.cast(
549
+ self.k_norm(k, training=training), self.compute_dtype
550
+ )
461
551
  return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
462
552
  else:
463
553
  x = self.ada_layer_norm(
@@ -469,8 +559,12 @@ class DismantledBlock(layers.Layer):
469
559
  )
470
560
  q, k, v = ops.unstack(qkv, 3, axis=2)
471
561
  if self.qk_norm is not None:
472
- q = self.q_norm(q, training=training)
473
- k = self.k_norm(k, training=training)
562
+ q = ops.cast(
563
+ self.q_norm(q, training=training), self.compute_dtype
564
+ )
565
+ k = ops.cast(
566
+ self.k_norm(k, training=training), self.compute_dtype
567
+ )
474
568
  return (q, k, v)
475
569
 
476
570
  def _compute_post_attention(
@@ -495,22 +589,95 @@ class DismantledBlock(layers.Layer):
495
589
  )
496
590
  return x
497
591
 
592
+ def _compute_pre_attention_with_dual_attention(
593
+ self, inputs, timestep_embedding, training=None
594
+ ):
595
+ batch_size = ops.shape(inputs)[0]
596
+ x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2 = (
597
+ self.ada_layer_norm(inputs, timestep_embedding, training=training)
598
+ )
599
+ # Compute the main attention
600
+ qkv = self.attention_qkv(x, training=training)
601
+ qkv = ops.reshape(
602
+ qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
603
+ )
604
+ q, k, v = ops.unstack(qkv, 3, axis=2)
605
+ if self.qk_norm is not None:
606
+ q = ops.cast(self.q_norm(q, training=training), self.compute_dtype)
607
+ k = ops.cast(self.k_norm(k, training=training), self.compute_dtype)
608
+ # Compute the dual attention
609
+ qkv2 = self.attention_qkv2(x2, training=training)
610
+ qkv2 = ops.reshape(
611
+ qkv2, (batch_size, -1, 3, self.num_heads, self.head_dim)
612
+ )
613
+ q2, k2, v2 = ops.unstack(qkv2, 3, axis=2)
614
+ if self.qk_norm is not None:
615
+ q2 = ops.cast(
616
+ self.q_norm2(q2, training=training), self.compute_dtype
617
+ )
618
+ k2 = ops.cast(
619
+ self.k_norm2(k2, training=training), self.compute_dtype
620
+ )
621
+ return (
622
+ (q, k, v),
623
+ (q2, k2, v2),
624
+ (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2),
625
+ )
626
+
627
+ def _compute_post_attention_with_dual_attention(
628
+ self, inputs, inputs2, inputs_intermediates, training=None
629
+ ):
630
+ x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2 = (
631
+ inputs_intermediates
632
+ )
633
+ gate_msa = ops.expand_dims(gate_msa, axis=1)
634
+ shift_mlp = ops.expand_dims(shift_mlp, axis=1)
635
+ scale_mlp = ops.expand_dims(scale_mlp, axis=1)
636
+ gate_mlp = ops.expand_dims(gate_mlp, axis=1)
637
+ gate_msa2 = ops.expand_dims(gate_msa2, axis=1)
638
+ attn = self.attention_proj(inputs, training=training)
639
+ x = ops.add(x, ops.multiply(gate_msa, attn))
640
+ attn2 = self.attention_proj2(inputs2, training=training)
641
+ x = ops.add(x, ops.multiply(gate_msa2, attn2))
642
+ x = ops.add(
643
+ x,
644
+ ops.multiply(
645
+ gate_mlp,
646
+ self.mlp(
647
+ self._modulate(self.norm2(x), shift_mlp, scale_mlp),
648
+ training=training,
649
+ ),
650
+ ),
651
+ )
652
+ return x
653
+
498
654
  def call(
499
655
  self,
500
656
  inputs,
501
657
  timestep_embedding=None,
502
658
  inputs_intermediates=None,
659
+ inputs2=None, # For the dual attention.
503
660
  pre_attention=True,
504
661
  training=None,
505
662
  ):
506
663
  if pre_attention:
507
- return self._compute_pre_attention(
508
- inputs, timestep_embedding, training=training
509
- )
664
+ if self.use_dual_attention:
665
+ return self._compute_pre_attention_with_dual_attention(
666
+ inputs, timestep_embedding, training=training
667
+ )
668
+ else:
669
+ return self._compute_pre_attention(
670
+ inputs, timestep_embedding, training=training
671
+ )
510
672
  else:
511
- return self._compute_post_attention(
512
- inputs, inputs_intermediates, training=training
513
- )
673
+ if self.use_dual_attention:
674
+ return self._compute_post_attention_with_dual_attention(
675
+ inputs, inputs2, inputs_intermediates, training=training
676
+ )
677
+ else:
678
+ return self._compute_post_attention(
679
+ inputs, inputs_intermediates, training=training
680
+ )
514
681
 
515
682
  def get_config(self):
516
683
  config = super().get_config()
@@ -521,6 +688,7 @@ class DismantledBlock(layers.Layer):
521
688
  "mlp_ratio": self.mlp_ratio,
522
689
  "use_projection": self.use_projection,
523
690
  "qk_norm": self.qk_norm,
691
+ "use_dual_attention": self.use_dual_attention,
524
692
  }
525
693
  )
526
694
  return config
@@ -542,6 +710,8 @@ class MMDiTBlock(layers.Layer):
542
710
  layer at the end of the context block.
543
711
  qk_norm: Optional str. Whether to normalize the query and key tensors.
544
712
  Available options are `None` and `"rms_norm"`. Defaults to `None`.
713
+ use_dual_attention: bool. Whether to use a dual attention in the
714
+ block. Defaults to `False`.
545
715
  **kwargs: other keyword arguments passed to `keras.layers.Layer`,
546
716
  including `name`, `dtype` etc.
547
717
 
@@ -557,6 +727,7 @@ class MMDiTBlock(layers.Layer):
557
727
  mlp_ratio=4.0,
558
728
  use_context_projection=True,
559
729
  qk_norm=None,
730
+ use_dual_attention=False,
560
731
  **kwargs,
561
732
  ):
562
733
  super().__init__(**kwargs)
@@ -565,6 +736,7 @@ class MMDiTBlock(layers.Layer):
565
736
  self.mlp_ratio = mlp_ratio
566
737
  self.use_context_projection = use_context_projection
567
738
  self.qk_norm = qk_norm
739
+ self.use_dual_attention = use_dual_attention
568
740
 
569
741
  head_dim = hidden_dim // num_heads
570
742
  self.head_dim = head_dim
@@ -576,6 +748,7 @@ class MMDiTBlock(layers.Layer):
576
748
  mlp_ratio=mlp_ratio,
577
749
  use_projection=True,
578
750
  qk_norm=qk_norm,
751
+ use_dual_attention=use_dual_attention,
579
752
  dtype=self.dtype_policy,
580
753
  name="x_block",
581
754
  )
@@ -602,8 +775,6 @@ class MMDiTBlock(layers.Layer):
602
775
  if hasattr(ops, "dot_product_attention") and hasattr(
603
776
  keras.config, "is_flash_attention_enabled"
604
777
  ):
605
- # `ops.dot_product_attention` is slower than the vanilla
606
- # implementation in the tensorflow backend.
607
778
  encoded = ops.dot_product_attention(
608
779
  query,
609
780
  key,
@@ -643,9 +814,14 @@ class MMDiTBlock(layers.Layer):
643
814
  training=training,
644
815
  )
645
816
  context_len = ops.shape(context_qkv[0])[1]
646
- x_qkv, x_intermediates = self.x_block(
647
- x, timestep_embedding=timestep_embedding, training=training
648
- )
817
+ if self.x_block.use_dual_attention:
818
+ x_qkv, x_qkv2, x_intermediates = self.x_block(
819
+ x, timestep_embedding=timestep_embedding, training=training
820
+ )
821
+ else:
822
+ x_qkv, x_intermediates = self.x_block(
823
+ x, timestep_embedding=timestep_embedding, training=training
824
+ )
649
825
  q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
650
826
  k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
651
827
  v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
@@ -656,12 +832,23 @@ class MMDiTBlock(layers.Layer):
656
832
  x_attention = attention[:, context_len:]
657
833
 
658
834
  # Compute post-attention.
659
- x = self.x_block(
660
- x_attention,
661
- inputs_intermediates=x_intermediates,
662
- pre_attention=False,
663
- training=training,
664
- )
835
+ if self.x_block.use_dual_attention:
836
+ q2, k2, v2 = x_qkv2
837
+ x_attention2 = self._compute_attention(q2, k2, v2)
838
+ x = self.x_block(
839
+ x_attention,
840
+ inputs_intermediates=x_intermediates,
841
+ inputs2=x_attention2,
842
+ pre_attention=False,
843
+ training=training,
844
+ )
845
+ else:
846
+ x = self.x_block(
847
+ x_attention,
848
+ inputs_intermediates=x_intermediates,
849
+ pre_attention=False,
850
+ training=training,
851
+ )
665
852
  if self.use_context_projection:
666
853
  context = self.context_block(
667
854
  context_attention,
@@ -682,6 +869,7 @@ class MMDiTBlock(layers.Layer):
682
869
  "mlp_ratio": self.mlp_ratio,
683
870
  "use_context_projection": self.use_context_projection,
684
871
  "qk_norm": self.qk_norm,
872
+ "use_dual_attention": self.use_dual_attention,
685
873
  }
686
874
  )
687
875
  return config
@@ -761,6 +949,9 @@ class MMDiT(Backbone):
761
949
  qk_norm: Optional str. Whether to normalize the query and key tensors in
762
950
  the intermediate blocks. Available options are `None` and
763
951
  `"rms_norm"`. Defaults to `None`.
952
+ dual_attention_indices: Optional tuple. Specifies the indices of
953
+ the blocks that serve as dual attention blocks. Typically, this is
954
+ for 3.5 version. Defaults to `None`.
764
955
  data_format: `None` or str. If specified, either `"channels_last"` or
765
956
  `"channels_first"`. The ordering of the dimensions in the
766
957
  inputs. `"channels_last"` corresponds to inputs with shape
@@ -786,6 +977,7 @@ class MMDiT(Backbone):
786
977
  context_shape=(None, 4096),
787
978
  pooled_projection_shape=(2048,),
788
979
  qk_norm=None,
980
+ dual_attention_indices=None,
789
981
  data_format=None,
790
982
  dtype=None,
791
983
  **kwargs,
@@ -799,6 +991,7 @@ class MMDiT(Backbone):
799
991
  image_width = latent_shape[1] // patch_size
800
992
  output_dim = latent_shape[-1]
801
993
  output_dim_in_final = patch_size**2 * output_dim
994
+ dual_attention_indices = dual_attention_indices or ()
802
995
  data_format = standardize_data_format(data_format)
803
996
  if data_format != "channels_last":
804
997
  raise NotImplementedError(
@@ -840,6 +1033,7 @@ class MMDiT(Backbone):
840
1033
  mlp_ratio,
841
1034
  use_context_projection=not (i == num_layers - 1),
842
1035
  qk_norm=qk_norm,
1036
+ use_dual_attention=i in dual_attention_indices,
843
1037
  dtype=dtype,
844
1038
  name=f"joint_block_{i}",
845
1039
  )
@@ -910,6 +1104,7 @@ class MMDiT(Backbone):
910
1104
  self.context_shape = context_shape
911
1105
  self.pooled_projection_shape = pooled_projection_shape
912
1106
  self.qk_norm = qk_norm
1107
+ self.dual_attention_indices = dual_attention_indices
913
1108
 
914
1109
  def get_config(self):
915
1110
  config = super().get_config()
@@ -925,6 +1120,7 @@ class MMDiT(Backbone):
925
1120
  "context_shape": self.context_shape,
926
1121
  "pooled_projection_shape": self.pooled_projection_shape,
927
1122
  "qk_norm": self.qk_norm,
1123
+ "dual_attention_indices": self.dual_attention_indices,
928
1124
  }
929
1125
  )
930
1126
  return config
@@ -205,7 +205,10 @@ class StableDiffusion3Backbone(Backbone):
205
205
  mmdit_qk_norm: Optional str. Whether to normalize the query and key
206
206
  tensors for each transformer in MMDiT. Available options are `None`
207
207
  and `"rms_norm"`. Typically, this is set to `None` for 3.0 version
208
- and to `"rms_norm" for 3.5 version.
208
+ and to `"rms_norm"` for 3.5 version.
209
+ mmdit_dual_attention_indices: Optional tuple. Specifies the indices of
210
+ the blocks that serve as dual attention blocks. Typically, this is
211
+ for 3.5 version. Defaults to `None`.
209
212
  vae: The VAE used for transformations between pixel space and latent
210
213
  space.
211
214
  clip_l: The CLIP text encoder for encoding the inputs.
@@ -253,6 +256,7 @@ class StableDiffusion3Backbone(Backbone):
253
256
  mmdit_depth=4,
254
257
  mmdit_position_size=192,
255
258
  mmdit_qk_norm=None,
259
+ mmdit_dual_attention_indices=None,
256
260
  vae=vae,
257
261
  clip_l=clip_l,
258
262
  clip_g=clip_g,
@@ -268,6 +272,7 @@ class StableDiffusion3Backbone(Backbone):
268
272
  mmdit_num_heads,
269
273
  mmdit_position_size,
270
274
  mmdit_qk_norm,
275
+ mmdit_dual_attention_indices,
271
276
  vae,
272
277
  clip_l,
273
278
  clip_g,
@@ -319,6 +324,7 @@ class StableDiffusion3Backbone(Backbone):
319
324
  context_shape=context_shape,
320
325
  pooled_projection_shape=pooled_projection_shape,
321
326
  qk_norm=mmdit_qk_norm,
327
+ dual_attention_indices=mmdit_dual_attention_indices,
322
328
  data_format=data_format,
323
329
  dtype=dtype,
324
330
  name="diffuser",
@@ -454,6 +460,7 @@ class StableDiffusion3Backbone(Backbone):
454
460
  self.mmdit_num_heads = mmdit_num_heads
455
461
  self.mmdit_position_size = mmdit_position_size
456
462
  self.mmdit_qk_norm = mmdit_qk_norm
463
+ self.mmdit_dual_attention_indices = mmdit_dual_attention_indices
457
464
  self.latent_channels = latent_channels
458
465
  self.output_channels = output_channels
459
466
  self.num_train_timesteps = num_train_timesteps
@@ -590,6 +597,9 @@ class StableDiffusion3Backbone(Backbone):
590
597
  "mmdit_num_heads": self.mmdit_num_heads,
591
598
  "mmdit_position_size": self.mmdit_position_size,
592
599
  "mmdit_qk_norm": self.mmdit_qk_norm,
600
+ "mmdit_dual_attention_indices": (
601
+ self.mmdit_dual_attention_indices
602
+ ),
593
603
  "vae": layers.serialize(self.vae),
594
604
  "clip_l": layers.serialize(self.clip_l),
595
605
  "clip_g": layers.serialize(self.clip_g),
@@ -638,7 +648,10 @@ class StableDiffusion3Backbone(Backbone):
638
648
  )
639
649
 
640
650
  # To maintain backward compatibility, we need to ensure that
641
- # `mmdit_qk_norm` is included in the config.
651
+ # `mmdit_qk_norm` and `mmdit_dual_attention_indices` are included in the
652
+ # config.
642
653
  if "mmdit_qk_norm" not in config:
643
654
  config["mmdit_qk_norm"] = None
655
+ if "mmdit_dual_attention_indices" not in config:
656
+ config["mmdit_dual_attention_indices"] = None
644
657
  return cls(**config)
@@ -13,6 +13,18 @@ backbone_presets = {
13
13
  },
14
14
  "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/4",
15
15
  },
16
+ "stable_diffusion_3.5_medium": {
17
+ "metadata": {
18
+ "description": (
19
+ "3 billion parameter, including CLIP L and CLIP G text "
20
+ "encoders, MMDiT-X generative model, and VAE autoencoder. "
21
+ "Developed by Stability AI."
22
+ ),
23
+ "params": 3371793763,
24
+ "path": "stable_diffusion_3",
25
+ },
26
+ "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3.5_medium/1",
27
+ },
16
28
  "stable_diffusion_3.5_large": {
17
29
  "metadata": {
18
30
  "description": (
@@ -1,7 +1,7 @@
1
1
  from keras_hub.src.api_export import keras_hub_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.19.0.dev202501080345"
4
+ __version__ = "0.19.0.dev202501090358"
5
5
 
6
6
 
7
7
  @keras_hub_export("keras_hub.version")
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: keras-hub-nightly
3
- Version: 0.19.0.dev202501080345
3
+ Version: 0.19.0.dev202501090358
4
4
  Summary: Industry-strength Natural Language Processing extensions for Keras.
5
5
  Home-page: https://github.com/keras-team/keras-hub
6
6
  Author: Keras team
@@ -31,6 +31,17 @@ Requires-Dist: tensorflow-text
31
31
  Provides-Extra: extras
32
32
  Requires-Dist: rouge-score; extra == "extras"
33
33
  Requires-Dist: sentencepiece; extra == "extras"
34
+ Dynamic: author
35
+ Dynamic: author-email
36
+ Dynamic: classifier
37
+ Dynamic: description
38
+ Dynamic: description-content-type
39
+ Dynamic: home-page
40
+ Dynamic: license
41
+ Dynamic: provides-extra
42
+ Dynamic: requires-dist
43
+ Dynamic: requires-python
44
+ Dynamic: summary
34
45
 
35
46
  # KerasHub: Multi-framework Pretrained Models
36
47
  [![](https://github.com/keras-team/keras-hub/workflows/Tests/badge.svg?branch=master)](https://github.com/keras-team/keras-hub/actions?query=workflow%3ATests+branch%3Amaster)
@@ -9,7 +9,7 @@ keras_hub/api/tokenizers/__init__.py,sha256=mtJgQy1spfQnPAkeLoeinsT_W9iCWHlJXwzc
9
9
  keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
10
10
  keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
12
- keras_hub/src/version_utils.py,sha256=PULLssHhM5UCqNcvlGVVyAQo9rDMkhLThc6DLU9Kz2g,222
12
+ keras_hub/src/version_utils.py,sha256=1azFcDsz3l9ou6a0Z5UvGUYOSfX9oH5PRsuYxd9JBI8,222
13
13
  keras_hub/src/bounding_box/__init__.py,sha256=7i6KnGupN4AVivR_dFjQyuuTbI0GkHy8d-aMXeqZdU8,95
14
14
  keras_hub/src/bounding_box/converters.py,sha256=UUp1hwegpDZyIo8sh9TLNy1v6JjwmvwzL6wmHFMAtbk,21916
15
15
  keras_hub/src/bounding_box/formats.py,sha256=YmskOz2BOSat7NaE__J9VfpSNGPJJR0znSzA4lp8MMI,3868
@@ -314,11 +314,11 @@ keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py,sha256=
314
314
  keras_hub/src/models/segformer/segformer_presets.py,sha256=ET39ospixkTaCsjoMLdJrr3wlGvTAQu5prleVC5lMZI,4793
315
315
  keras_hub/src/models/stable_diffusion_3/__init__.py,sha256=ZKYQuaRObyhKq8GVAHmoRvlXp6FpU8ChvutVCHyXKuc,343
316
316
  keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py,sha256=vtVhieAv277mAiZj7Kvvqg_Ba7klfQxZVk4PPxNNQ0s,3062
317
- keras_hub/src/models/stable_diffusion_3/mmdit.py,sha256=poJlz-xt06hgOtn_Bw5YQDxZtDBc9L4Vo0ahhGwPly4,33340
318
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=u0Wwtbl5b-1z_vn07TRw4jpkVYrReZeHbWqQIrZjyCA,23368
317
+ keras_hub/src/models/stable_diffusion_3/mmdit.py,sha256=0gq2tcIqcbiGKKDDj3vrRsF67U3qE9g706XPs2BfCOY,40979
318
+ keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=w8lsMampk34M9xQi96mEnXmkaKQqFQtoFTW8zP7ilEA,24078
319
319
  keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py,sha256=oQcVCWOwrdUTrr_JNekoMqdSlKYMGz5tG6v8uD25lTc,5479
320
320
  keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py,sha256=aZMIC-GYjLhdU_yM7fJEznApCo1zwRAgwQbW0tCW0xY,6399
321
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=z6wrfv8rCqLBzn7_edRcKCIDQRTNUgLqyr-LLp55-IE,1680
321
+ keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=x7Ez4L955MJE4ABtBy-63YpU9XpR0Ro8QWPzYYJs1yE,2167
322
322
  keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py,sha256=Yt-UIatVKANjjKFCFEj1rIHhOrt8hqefKKQJIAWcTLc,4567
323
323
  keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py,sha256=m5PdVSgTcYuqd7jOQ8wD4PAnMa7wY2WdhwpK3hdydhM,2756
324
324
  keras_hub/src/models/stable_diffusion_3/t5_encoder.py,sha256=oV7P1uwCKdGiD93zXq7kmqX0elMZQU4UvBa8wg6P1hs,5113
@@ -417,7 +417,7 @@ keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYum
417
417
  keras_hub/src/utils/transformers/convert_vit.py,sha256=9SUZ9utNJhW_5cj3acMn9cRy47u2eIcDsrhmzj77o9k,5187
418
418
  keras_hub/src/utils/transformers/preset_loader.py,sha256=DgGJXbTSB9Na8FIR-YWWVqQPOFxHwWrGm41EwcS_EFs,3797
419
419
  keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
420
- keras_hub_nightly-0.19.0.dev202501080345.dist-info/METADATA,sha256=WWsYYpkd-P_ryoA3jId3bNDKaMQOJfy7eBeYQ7N_D6w,7260
421
- keras_hub_nightly-0.19.0.dev202501080345.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
422
- keras_hub_nightly-0.19.0.dev202501080345.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
423
- keras_hub_nightly-0.19.0.dev202501080345.dist-info/RECORD,,
420
+ keras_hub_nightly-0.19.0.dev202501090358.dist-info/METADATA,sha256=ywWExWZy14kzevtOFQZcdFDiqRJ2I72oWaeiFbjpZZE,7498
421
+ keras_hub_nightly-0.19.0.dev202501090358.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
422
+ keras_hub_nightly-0.19.0.dev202501090358.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
423
+ keras_hub_nightly-0.19.0.dev202501090358.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.7.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5