keras-hub-nightly 0.19.0.dev202501080345__py3-none-any.whl → 0.19.0.dev202501090358__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/src/models/stable_diffusion_3/mmdit.py +254 -58
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +15 -2
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +12 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501090358.dist-info}/METADATA +13 -2
- {keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501090358.dist-info}/RECORD +8 -8
- {keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501090358.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501090358.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/mmdit.py CHANGED

@@ -15,9 +15,8 @@ class AdaptiveLayerNormalization(layers.Layer):
 
     Args:
         embedding_dim: int. The size of each embedding vector.
-
-
-            transformers. Defaults to `False`.
+        num_modulations: int. The number of the modulation parameters. The
+            available values are `2`, `6` and `9`. Defaults to `2`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.
 
@@ -28,11 +27,17 @@ class AdaptiveLayerNormalization(layers.Layer):
     https://arxiv.org/abs/2212.09748).
     """
 
-    def __init__(self, hidden_dim,
+    def __init__(self, hidden_dim, num_modulations=2, **kwargs):
         super().__init__(**kwargs)
-
-
-        num_modulations
+        hidden_dim = int(hidden_dim)
+        num_modulations = int(num_modulations)
+        if num_modulations not in (2, 6, 9):
+            raise ValueError(
+                "`num_modulations` must be `2`, `6` or `9`. "
+                f"Received: num_modulations={num_modulations}"
+            )
+        self.hidden_dim = hidden_dim
+        self.num_modulations = num_modulations
 
         self.silu = layers.Activation("silu", dtype=self.dtype_policy)
         self.dense = layers.Dense(
@@ -52,40 +57,84 @@ class AdaptiveLayerNormalization(layers.Layer):
         self.norm.build(inputs_shape)
 
     def call(self, inputs, embeddings, training=None):
-
+        hidden_states = inputs
         emb = self.dense(self.silu(embeddings), training=training)
-        if self.
-
-
-
+        if self.num_modulations == 9:
+            (
+                shift_msa,
+                scale_msa,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                shift_msa2,
+                scale_msa2,
+                gate_msa2,
+            ) = ops.split(emb, self.num_modulations, axis=1)
+        elif self.num_modulations == 6:
+            (
+                shift_msa,
+                scale_msa,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+            ) = ops.split(emb, self.num_modulations, axis=1)
         else:
-            shift_msa, scale_msa = ops.split(emb,
+            shift_msa, scale_msa = ops.split(emb, self.num_modulations, axis=1)
+
         scale_msa = ops.expand_dims(scale_msa, axis=1)
         shift_msa = ops.expand_dims(shift_msa, axis=1)
-
-
-
-
-            ),
-            shift_msa,
+        norm_hidden_states = ops.cast(
+            self.norm(hidden_states, training=training), scale_msa.dtype
+        )
+        hidden_states = ops.add(
+            ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa)), shift_msa
         )
-
-
+
+        if self.num_modulations == 9:
+            scale_msa2 = ops.expand_dims(scale_msa2, axis=1)
+            shift_msa2 = ops.expand_dims(shift_msa2, axis=1)
+            hidden_states2 = ops.add(
+                ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa2)),
+                shift_msa2,
+            )
+            return (
+                hidden_states,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                hidden_states2,
+                gate_msa2,
+            )
+        elif self.num_modulations == 6:
+            return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp
         else:
-            return
+            return hidden_states
 
     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "hidden_dim": self.hidden_dim,
-                "
+                "num_modulations": self.num_modulations,
             }
         )
         return config
 
     def compute_output_shape(self, inputs_shape, embeddings_shape):
-        if self.
+        if self.num_modulations == 9:
+            return (
+                inputs_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                inputs_shape,
+                embeddings_shape,
+            )
+        elif self.num_modulations == 6:
             return (
                 inputs_shape,
                 embeddings_shape,
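For orientation, the reworked layer projects the timestep embedding once through a Dense layer and then splits it into `num_modulations` equal chunks (2, 6, or 9), which is why only those three values are accepted. Below is a minimal standalone sketch of that split using `keras.ops`; the batch size and `hidden_dim` are made up for illustration and are not taken from this package.

```python
import numpy as np
from keras import ops

hidden_dim, num_modulations = 64, 9  # hypothetical sizes
# The layer's Dense projection emits `num_modulations * hidden_dim` features.
emb = ops.convert_to_tensor(
    np.random.rand(2, num_modulations * hidden_dim).astype("float32")
)
# `ops.split` with an int divides the feature axis into equal chunks:
# shift/scale/gate for MSA and MLP, plus a second MSA set when num_modulations == 9.
chunks = ops.split(emb, num_modulations, axis=1)
print(len(chunks), chunks[0].shape)  # 9 (2, 64)
```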
@@ -345,6 +394,27 @@ class TimestepEmbedding(layers.Layer):
         return output_shape
 
 
+def get_qk_norm(qk_norm=None, q_norm_name="q_norm", k_norm_name="k_norm"):
+    """Helper function to instantiate `LayerNormalization` layers."""
+    q_norm = None
+    k_norm = None
+    if qk_norm is None:
+        pass
+    elif qk_norm == "rms_norm":
+        q_norm = layers.LayerNormalization(
+            epsilon=1e-6, rms_scaling=True, dtype="float32", name=q_norm_name
+        )
+        k_norm = layers.LayerNormalization(
+            epsilon=1e-6, rms_scaling=True, dtype="float32", name=k_norm_name
+        )
+    else:
+        raise NotImplementedError(
+            "Supported `qk_norm` are `'rms_norm'` and `None`. "
+            f"Received: qk_norm={qk_norm}."
+        )
+    return q_norm, k_norm
+
+
 class DismantledBlock(layers.Layer):
     """A dismantled block used to compute pre- and post-attention.
 
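The new `get_qk_norm` helper centralizes the query/key normalization setup that was previously inlined in `DismantledBlock.__init__` (see the hunk at `@@ -395,18 +469,10 @@` below). A hedged usage sketch, assuming the nightly wheel above is installed; the module path is taken from this package's RECORD.

```python
from keras_hub.src.models.stable_diffusion_3.mmdit import get_qk_norm

# "rms_norm" yields two float32 LayerNormalization layers with rms_scaling=True;
# None yields (None, None); any other value raises NotImplementedError.
q_norm, k_norm = get_qk_norm("rms_norm", q_norm_name="q_norm", k_norm_name="k_norm")
print(type(q_norm).__name__, k_norm.name)  # LayerNormalization k_norm
print(get_qk_norm(None))  # (None, None)
```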
@@ -356,6 +426,8 @@ class DismantledBlock(layers.Layer):
             the end of the block.
         qk_norm: Optional str. Whether to normalize the query and key tensors.
             Available options are `None` and `"rms_norm"`. Defaults to `None`.
+        use_dual_attention: bool. Whether to use a dual attention in the
+            block. Defaults to `False`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.
     """
@@ -367,6 +439,7 @@ class DismantledBlock(layers.Layer):
         mlp_ratio=4.0,
         use_projection=True,
         qk_norm=None,
+        use_dual_attention=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -375,6 +448,7 @@ class DismantledBlock(layers.Layer):
         self.mlp_ratio = mlp_ratio
         self.use_projection = use_projection
         self.qk_norm = qk_norm
+        self.use_dual_attention = use_dual_attention
 
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
@@ -384,7 +458,7 @@ class DismantledBlock(layers.Layer):
         if use_projection:
             self.ada_layer_norm = AdaptiveLayerNormalization(
                 hidden_dim,
-
+                num_modulations=9 if use_dual_attention else 6,
                 dtype=self.dtype_policy,
                 name="ada_layer_norm",
             )
@@ -395,18 +469,10 @@ class DismantledBlock(layers.Layer):
         self.attention_qkv = layers.Dense(
             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
         )
-
-
-
-
-            self.k_norm = layers.LayerNormalization(
-                epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm"
-            )
-        elif qk_norm is not None:
-            raise NotImplementedError(
-                "Supported `qk_norm` are `'rms_norm'` and `None`. "
-                f"Received: qk_norm={qk_norm}."
-            )
+        q_norm, k_norm = get_qk_norm(qk_norm)
+        if q_norm is not None:
+            self.q_norm = q_norm
+            self.k_norm = k_norm
         if use_projection:
             self.attention_proj = layers.Dense(
                 hidden_dim, dtype=self.dtype_policy, name="attention_proj"
@@ -426,6 +492,19 @@ class DismantledBlock(layers.Layer):
             name="mlp",
         )
 
+        if use_dual_attention:
+            self.attention_qkv2 = layers.Dense(
+                hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv2"
+            )
+            q_norm2, k_norm2 = get_qk_norm(qk_norm, "q_norm2", "k_norm2")
+            if q_norm is not None:
+                self.q_norm2 = q_norm2
+                self.k_norm2 = k_norm2
+            if use_projection:
+                self.attention_proj2 = layers.Dense(
+                    hidden_dim, dtype=self.dtype_policy, name="attention_proj2"
+                )
+
     def build(self, inputs_shape, timestep_embedding):
         self.ada_layer_norm.build(inputs_shape, timestep_embedding)
         self.attention_qkv.build(inputs_shape)
@@ -437,6 +516,13 @@ class DismantledBlock(layers.Layer):
             self.attention_proj.build(inputs_shape)
         self.norm2.build(inputs_shape)
         self.mlp.build(inputs_shape)
+        if self.use_dual_attention:
+            self.attention_qkv2.build(inputs_shape)
+            if self.qk_norm is not None:
+                self.q_norm2.build([None, None, self.num_heads, self.head_dim])
+                self.k_norm2.build([None, None, self.num_heads, self.head_dim])
+            if self.use_projection:
+                self.attention_proj2.build(inputs_shape)
 
     def _modulate(self, inputs, shift, scale):
         inputs = ops.cast(inputs, self.compute_dtype)
@@ -456,8 +542,12 @@ class DismantledBlock(layers.Layer):
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
             if self.qk_norm is not None:
-                q =
-
+                q = ops.cast(
+                    self.q_norm(q, training=training), self.compute_dtype
+                )
+                k = ops.cast(
+                    self.k_norm(k, training=training), self.compute_dtype
+                )
             return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         else:
             x = self.ada_layer_norm(
@@ -469,8 +559,12 @@ class DismantledBlock(layers.Layer):
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
             if self.qk_norm is not None:
-                q =
-
+                q = ops.cast(
+                    self.q_norm(q, training=training), self.compute_dtype
+                )
+                k = ops.cast(
+                    self.k_norm(k, training=training), self.compute_dtype
+                )
             return (q, k, v)
 
     def _compute_post_attention(
@@ -495,22 +589,95 @@ class DismantledBlock(layers.Layer):
         )
         return x
 
+    def _compute_pre_attention_with_dual_attention(
+        self, inputs, timestep_embedding, training=None
+    ):
+        batch_size = ops.shape(inputs)[0]
+        x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2 = (
+            self.ada_layer_norm(inputs, timestep_embedding, training=training)
+        )
+        # Compute the main attention
+        qkv = self.attention_qkv(x, training=training)
+        qkv = ops.reshape(
+            qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
+        )
+        q, k, v = ops.unstack(qkv, 3, axis=2)
+        if self.qk_norm is not None:
+            q = ops.cast(self.q_norm(q, training=training), self.compute_dtype)
+            k = ops.cast(self.k_norm(k, training=training), self.compute_dtype)
+        # Compute the dual attention
+        qkv2 = self.attention_qkv2(x2, training=training)
+        qkv2 = ops.reshape(
+            qkv2, (batch_size, -1, 3, self.num_heads, self.head_dim)
+        )
+        q2, k2, v2 = ops.unstack(qkv2, 3, axis=2)
+        if self.qk_norm is not None:
+            q2 = ops.cast(
+                self.q_norm2(q2, training=training), self.compute_dtype
+            )
+            k2 = ops.cast(
+                self.k_norm2(k2, training=training), self.compute_dtype
+            )
+        return (
+            (q, k, v),
+            (q2, k2, v2),
+            (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2),
+        )
+
+    def _compute_post_attention_with_dual_attention(
+        self, inputs, inputs2, inputs_intermediates, training=None
+    ):
+        x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2 = (
+            inputs_intermediates
+        )
+        gate_msa = ops.expand_dims(gate_msa, axis=1)
+        shift_mlp = ops.expand_dims(shift_mlp, axis=1)
+        scale_mlp = ops.expand_dims(scale_mlp, axis=1)
+        gate_mlp = ops.expand_dims(gate_mlp, axis=1)
+        gate_msa2 = ops.expand_dims(gate_msa2, axis=1)
+        attn = self.attention_proj(inputs, training=training)
+        x = ops.add(x, ops.multiply(gate_msa, attn))
+        attn2 = self.attention_proj2(inputs2, training=training)
+        x = ops.add(x, ops.multiply(gate_msa2, attn2))
+        x = ops.add(
+            x,
+            ops.multiply(
+                gate_mlp,
+                self.mlp(
+                    self._modulate(self.norm2(x), shift_mlp, scale_mlp),
+                    training=training,
+                ),
+            ),
+        )
+        return x
+
     def call(
         self,
         inputs,
         timestep_embedding=None,
         inputs_intermediates=None,
+        inputs2=None,  # For the dual attention.
         pre_attention=True,
         training=None,
     ):
         if pre_attention:
-
-
-
+            if self.use_dual_attention:
+                return self._compute_pre_attention_with_dual_attention(
+                    inputs, timestep_embedding, training=training
+                )
+            else:
+                return self._compute_pre_attention(
+                    inputs, timestep_embedding, training=training
+                )
         else:
-
-
-
+            if self.use_dual_attention:
+                return self._compute_post_attention_with_dual_attention(
+                    inputs, inputs2, inputs_intermediates, training=training
+                )
+            else:
+                return self._compute_post_attention(
+                    inputs, inputs_intermediates, training=training
+                )
 
     def get_config(self):
         config = super().get_config()
@@ -521,6 +688,7 @@ class DismantledBlock(layers.Layer):
                 "mlp_ratio": self.mlp_ratio,
                 "use_projection": self.use_projection,
                 "qk_norm": self.qk_norm,
+                "use_dual_attention": self.use_dual_attention,
             }
         )
         return config
@@ -542,6 +710,8 @@ class MMDiTBlock(layers.Layer):
             layer at the end of the context block.
         qk_norm: Optional str. Whether to normalize the query and key tensors.
             Available options are `None` and `"rms_norm"`. Defaults to `None`.
+        use_dual_attention: bool. Whether to use a dual attention in the
+            block. Defaults to `False`.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
             including `name`, `dtype` etc.
 
@@ -557,6 +727,7 @@ class MMDiTBlock(layers.Layer):
         mlp_ratio=4.0,
         use_context_projection=True,
         qk_norm=None,
+        use_dual_attention=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -565,6 +736,7 @@ class MMDiTBlock(layers.Layer):
         self.mlp_ratio = mlp_ratio
         self.use_context_projection = use_context_projection
         self.qk_norm = qk_norm
+        self.use_dual_attention = use_dual_attention
 
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
@@ -576,6 +748,7 @@ class MMDiTBlock(layers.Layer):
             mlp_ratio=mlp_ratio,
             use_projection=True,
             qk_norm=qk_norm,
+            use_dual_attention=use_dual_attention,
             dtype=self.dtype_policy,
             name="x_block",
         )
@@ -602,8 +775,6 @@ class MMDiTBlock(layers.Layer):
         if hasattr(ops, "dot_product_attention") and hasattr(
             keras.config, "is_flash_attention_enabled"
         ):
-            # `ops.dot_product_attention` is slower than the vanilla
-            # implementation in the tensorflow backend.
             encoded = ops.dot_product_attention(
                 query,
                 key,
@@ -643,9 +814,14 @@ class MMDiTBlock(layers.Layer):
             training=training,
         )
         context_len = ops.shape(context_qkv[0])[1]
-
-
-
+        if self.x_block.use_dual_attention:
+            x_qkv, x_qkv2, x_intermediates = self.x_block(
+                x, timestep_embedding=timestep_embedding, training=training
+            )
+        else:
+            x_qkv, x_intermediates = self.x_block(
+                x, timestep_embedding=timestep_embedding, training=training
+            )
         q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
         k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
         v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
@@ -656,12 +832,23 @@ class MMDiTBlock(layers.Layer):
         x_attention = attention[:, context_len:]
 
         # Compute post-attention.
-
-
-
-
-
-
+        if self.x_block.use_dual_attention:
+            q2, k2, v2 = x_qkv2
+            x_attention2 = self._compute_attention(q2, k2, v2)
+            x = self.x_block(
+                x_attention,
+                inputs_intermediates=x_intermediates,
+                inputs2=x_attention2,
+                pre_attention=False,
+                training=training,
+            )
+        else:
+            x = self.x_block(
+                x_attention,
+                inputs_intermediates=x_intermediates,
+                pre_attention=False,
+                training=training,
+            )
         if self.use_context_projection:
             context = self.context_block(
                 context_attention,
@@ -682,6 +869,7 @@ class MMDiTBlock(layers.Layer):
                 "mlp_ratio": self.mlp_ratio,
                 "use_context_projection": self.use_context_projection,
                 "qk_norm": self.qk_norm,
+                "use_dual_attention": self.use_dual_attention,
             }
         )
         return config
@@ -761,6 +949,9 @@ class MMDiT(Backbone):
         qk_norm: Optional str. Whether to normalize the query and key tensors in
             the intermediate blocks. Available options are `None` and
             `"rms_norm"`. Defaults to `None`.
+        dual_attention_indices: Optional tuple. Specifies the indices of
+            the blocks that serve as dual attention blocks. Typically, this is
+            for 3.5 version. Defaults to `None`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -786,6 +977,7 @@ class MMDiT(Backbone):
         context_shape=(None, 4096),
         pooled_projection_shape=(2048,),
         qk_norm=None,
+        dual_attention_indices=None,
         data_format=None,
         dtype=None,
         **kwargs,
@@ -799,6 +991,7 @@ class MMDiT(Backbone):
         image_width = latent_shape[1] // patch_size
         output_dim = latent_shape[-1]
         output_dim_in_final = patch_size**2 * output_dim
+        dual_attention_indices = dual_attention_indices or ()
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError(
@@ -840,6 +1033,7 @@ class MMDiT(Backbone):
                 mlp_ratio,
                 use_context_projection=not (i == num_layers - 1),
                 qk_norm=qk_norm,
+                use_dual_attention=i in dual_attention_indices,
                 dtype=dtype,
                 name=f"joint_block_{i}",
             )
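The last hunk above shows how `dual_attention_indices` is consumed: each joint block is flagged with `use_dual_attention=i in dual_attention_indices`. A tiny illustration of that mapping with hypothetical values (the real depth and indices come from the preset configuration, not from this diff):

```python
num_layers = 4                   # hypothetical depth
dual_attention_indices = (0, 1)  # hypothetical MMDiT-X block indices
flags = [i in dual_attention_indices for i in range(num_layers)]
print(flags)  # [True, True, False, False]
```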
@@ -910,6 +1104,7 @@ class MMDiT(Backbone):
         self.context_shape = context_shape
         self.pooled_projection_shape = pooled_projection_shape
         self.qk_norm = qk_norm
+        self.dual_attention_indices = dual_attention_indices
 
     def get_config(self):
         config = super().get_config()
@@ -925,6 +1120,7 @@ class MMDiT(Backbone):
                 "context_shape": self.context_shape,
                 "pooled_projection_shape": self.pooled_projection_shape,
                 "qk_norm": self.qk_norm,
+                "dual_attention_indices": self.dual_attention_indices,
             }
         )
         return config
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py CHANGED

@@ -205,7 +205,10 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_qk_norm: Optional str. Whether to normalize the query and key
             tensors for each transformer in MMDiT. Available options are `None`
             and `"rms_norm"`. Typically, this is set to `None` for 3.0 version
-            and to `"rms_norm" for 3.5 version.
+            and to `"rms_norm"` for 3.5 version.
+        mmdit_dual_attention_indices: Optional tuple. Specifies the indices of
+            the blocks that serve as dual attention blocks. Typically, this is
+            for 3.5 version. Defaults to `None`.
         vae: The VAE used for transformations between pixel space and latent
             space.
         clip_l: The CLIP text encoder for encoding the inputs.
@@ -253,6 +256,7 @@ class StableDiffusion3Backbone(Backbone):
        mmdit_depth=4,
        mmdit_position_size=192,
        mmdit_qk_norm=None,
+       mmdit_dual_attention_indices=None,
        vae=vae,
        clip_l=clip_l,
        clip_g=clip_g,
@@ -268,6 +272,7 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_num_heads,
         mmdit_position_size,
         mmdit_qk_norm,
+        mmdit_dual_attention_indices,
         vae,
         clip_l,
         clip_g,
@@ -319,6 +324,7 @@ class StableDiffusion3Backbone(Backbone):
             context_shape=context_shape,
             pooled_projection_shape=pooled_projection_shape,
             qk_norm=mmdit_qk_norm,
+            dual_attention_indices=mmdit_dual_attention_indices,
             data_format=data_format,
             dtype=dtype,
             name="diffuser",
@@ -454,6 +460,7 @@ class StableDiffusion3Backbone(Backbone):
         self.mmdit_num_heads = mmdit_num_heads
         self.mmdit_position_size = mmdit_position_size
         self.mmdit_qk_norm = mmdit_qk_norm
+        self.mmdit_dual_attention_indices = mmdit_dual_attention_indices
         self.latent_channels = latent_channels
         self.output_channels = output_channels
         self.num_train_timesteps = num_train_timesteps
@@ -590,6 +597,9 @@ class StableDiffusion3Backbone(Backbone):
                 "mmdit_num_heads": self.mmdit_num_heads,
                 "mmdit_position_size": self.mmdit_position_size,
                 "mmdit_qk_norm": self.mmdit_qk_norm,
+                "mmdit_dual_attention_indices": (
+                    self.mmdit_dual_attention_indices
+                ),
                 "vae": layers.serialize(self.vae),
                 "clip_l": layers.serialize(self.clip_l),
                 "clip_g": layers.serialize(self.clip_g),
@@ -638,7 +648,10 @@ class StableDiffusion3Backbone(Backbone):
         )
 
         # To maintain backward compatibility, we need to ensure that
-        # `mmdit_qk_norm` is included in the
+        # `mmdit_qk_norm` and `mmdit_dual_attention_indices` is included in the
+        # config.
         if "mmdit_qk_norm" not in config:
             config["mmdit_qk_norm"] = None
+        if "mmdit_dual_attention_indices" not in config:
+            config["mmdit_dual_attention_indices"] = None
         return cls(**config)
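The `from_config` hunk above backfills the new key so that configs serialized before this nightly still deserialize. A minimal sketch of that backward-compatibility path (the dictionary contents are made up; only the key handling mirrors the diff):

```python
# Pretend this config was saved by an older keras-hub-nightly build.
old_config = {"mmdit_qk_norm": None}

# Mirrors the added `from_config` logic: missing keys default to None.
if "mmdit_qk_norm" not in old_config:
    old_config["mmdit_qk_norm"] = None
if "mmdit_dual_attention_indices" not in old_config:
    old_config["mmdit_dual_attention_indices"] = None

print(old_config)
# {'mmdit_qk_norm': None, 'mmdit_dual_attention_indices': None}
```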
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py CHANGED

@@ -13,6 +13,18 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/4",
     },
+    "stable_diffusion_3.5_medium": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, including CLIP L and CLIP G text "
+                "encoders, MMDiT-X generative model, and VAE autoencoder. "
+                "Developed by Stability AI."
+            ),
+            "params": 3371793763,
+            "path": "stable_diffusion_3",
+        },
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3.5_medium/1",
+    },
     "stable_diffusion_3.5_large": {
         "metadata": {
             "description": (
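With the `stable_diffusion_3.5_medium` entry registered above, the preset should be loadable by name through the usual KerasHub task API. A hedged usage sketch (standard `from_preset`/`generate` calls; the prompt is illustrative and weights are downloaded from the Kaggle handle in the preset):

```python
import keras_hub

text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3.5_medium"
)
image = text_to_image.generate("a photograph of an astronaut riding a horse")
```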
keras_hub/src/version_utils.py CHANGED

{keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501090358.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.2
 Name: keras-hub-nightly
-Version: 0.19.0.dev202501080345
+Version: 0.19.0.dev202501090358
 Summary: Industry-strength Natural Language Processing extensions for Keras.
 Home-page: https://github.com/keras-team/keras-hub
 Author: Keras team
@@ -31,6 +31,17 @@ Requires-Dist: tensorflow-text
 Provides-Extra: extras
 Requires-Dist: rouge-score; extra == "extras"
 Requires-Dist: sentencepiece; extra == "extras"
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 # KerasHub: Multi-framework Pretrained Models
 [](https://github.com/keras-team/keras-hub/actions?query=workflow%3ATests+branch%3Amaster)

{keras_hub_nightly-0.19.0.dev202501080345.dist-info → keras_hub_nightly-0.19.0.dev202501090358.dist-info}/RECORD CHANGED

@@ -9,7 +9,7 @@ keras_hub/api/tokenizers/__init__.py,sha256=mtJgQy1spfQnPAkeLoeinsT_W9iCWHlJXwzc
 keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
 keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
-keras_hub/src/version_utils.py,sha256=
+keras_hub/src/version_utils.py,sha256=1azFcDsz3l9ou6a0Z5UvGUYOSfX9oH5PRsuYxd9JBI8,222
 keras_hub/src/bounding_box/__init__.py,sha256=7i6KnGupN4AVivR_dFjQyuuTbI0GkHy8d-aMXeqZdU8,95
 keras_hub/src/bounding_box/converters.py,sha256=UUp1hwegpDZyIo8sh9TLNy1v6JjwmvwzL6wmHFMAtbk,21916
 keras_hub/src/bounding_box/formats.py,sha256=YmskOz2BOSat7NaE__J9VfpSNGPJJR0znSzA4lp8MMI,3868
@@ -314,11 +314,11 @@ keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py,sha256=
 keras_hub/src/models/segformer/segformer_presets.py,sha256=ET39ospixkTaCsjoMLdJrr3wlGvTAQu5prleVC5lMZI,4793
 keras_hub/src/models/stable_diffusion_3/__init__.py,sha256=ZKYQuaRObyhKq8GVAHmoRvlXp6FpU8ChvutVCHyXKuc,343
 keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py,sha256=vtVhieAv277mAiZj7Kvvqg_Ba7klfQxZVk4PPxNNQ0s,3062
-keras_hub/src/models/stable_diffusion_3/mmdit.py,sha256=
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=
+keras_hub/src/models/stable_diffusion_3/mmdit.py,sha256=0gq2tcIqcbiGKKDDj3vrRsF67U3qE9g706XPs2BfCOY,40979
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py,sha256=w8lsMampk34M9xQi96mEnXmkaKQqFQtoFTW8zP7ilEA,24078
 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py,sha256=oQcVCWOwrdUTrr_JNekoMqdSlKYMGz5tG6v8uD25lTc,5479
 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py,sha256=aZMIC-GYjLhdU_yM7fJEznApCo1zwRAgwQbW0tCW0xY,6399
-keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=
+keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py,sha256=x7Ez4L955MJE4ABtBy-63YpU9XpR0Ro8QWPzYYJs1yE,2167
 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py,sha256=Yt-UIatVKANjjKFCFEj1rIHhOrt8hqefKKQJIAWcTLc,4567
 keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py,sha256=m5PdVSgTcYuqd7jOQ8wD4PAnMa7wY2WdhwpK3hdydhM,2756
 keras_hub/src/models/stable_diffusion_3/t5_encoder.py,sha256=oV7P1uwCKdGiD93zXq7kmqX0elMZQU4UvBa8wg6P1hs,5113
@@ -417,7 +417,7 @@ keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYum
 keras_hub/src/utils/transformers/convert_vit.py,sha256=9SUZ9utNJhW_5cj3acMn9cRy47u2eIcDsrhmzj77o9k,5187
 keras_hub/src/utils/transformers/preset_loader.py,sha256=DgGJXbTSB9Na8FIR-YWWVqQPOFxHwWrGm41EwcS_EFs,3797
 keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
-keras_hub_nightly-0.19.0.
-keras_hub_nightly-0.19.0.
-keras_hub_nightly-0.19.0.
-keras_hub_nightly-0.19.0.
+keras_hub_nightly-0.19.0.dev202501090358.dist-info/METADATA,sha256=ywWExWZy14kzevtOFQZcdFDiqRJ2I72oWaeiFbjpZZE,7498
+keras_hub_nightly-0.19.0.dev202501090358.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+keras_hub_nightly-0.19.0.dev202501090358.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
+keras_hub_nightly-0.19.0.dev202501090358.dist-info/RECORD,,