keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py

@@ -4,13 +4,10 @@ from keras import ops
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (
+from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (  # noqa: E501
     FlowMatchEulerDiscreteScheduler,
 )
 from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
-from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
-    VAEImageDecoder,
-)
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
@@ -54,11 +51,52 @@ class CLIPProjection(layers.Layer):
         return (inputs_shape[0], self.hidden_dim)
 
 
-class
-    def
-
-
+class CLIPConcatenate(layers.Layer):
+    def call(
+        self,
+        clip_l_projection,
+        clip_g_projection,
+        clip_l_intermediate_output,
+        clip_g_intermediate_output,
+        padding,
+    ):
+        pooled_embeddings = ops.concatenate(
+            [clip_l_projection, clip_g_projection], axis=-1
+        )
+        embeddings = ops.concatenate(
+            [clip_l_intermediate_output, clip_g_intermediate_output], axis=-1
+        )
+        embeddings = ops.pad(embeddings, [[0, 0], [0, 0], [0, padding]])
+        return pooled_embeddings, embeddings
+
+
+class ImageRescaling(layers.Rescaling):
+    """Rescales inputs from image space to latent space.
+
+    The rescaling is performed using the formula: `(inputs - offset) * scale`.
+    """
 
+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) - offset) * scale
+
+
+class LatentRescaling(layers.Rescaling):
+    """Rescales inputs from latent space to image space.
+
+    The rescaling is performed using the formula: `inputs / scale + offset`.
+    """
+
+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) / scale) + offset
+
+
+class ClassifierFreeGuidanceConcatenate(layers.Layer):
     def call(
         self,
         latents,
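
The two rescaling layers added above are exact inverses for a given `scale`/`offset` pair: `ImageRescaling` maps image-space inputs into latent space with `(inputs - offset) * scale`, and `LatentRescaling` maps latents back with `inputs / scale + offset`. A minimal NumPy sketch of that round trip (not part of the package; the `scale` and `offset` values are placeholders):

import numpy as np

scale, offset = 2.0, 0.5  # placeholder values; the real ones come from the VAE
images = np.array([[-0.25, 0.0, 0.75]], dtype="float32")

latents = (images - offset) * scale   # ImageRescaling formula
recovered = latents / scale + offset  # LatentRescaling formula

assert np.allclose(recovered, images)
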
@@ -69,20 +107,16 @@ class ClassifierFreeGuidanceConcatenate(layers.Layer):
         timestep,
     ):
         timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
-        latents = ops.concatenate([latents, latents], axis=
+        latents = ops.concatenate([latents, latents], axis=0)
         contexts = ops.concatenate(
-            [positive_contexts, negative_contexts], axis=
+            [positive_contexts, negative_contexts], axis=0
         )
         pooled_projections = ops.concatenate(
-            [positive_pooled_projections, negative_pooled_projections],
-            axis=self.axis,
+            [positive_pooled_projections, negative_pooled_projections], axis=0
         )
-        timesteps = ops.concatenate([timestep, timestep], axis=
+        timesteps = ops.concatenate([timestep, timestep], axis=0)
         return latents, contexts, pooled_projections, timesteps
 
-    def get_config(self):
-        return super().get_config()
-
 
 class ClassifierFreeGuidance(layers.Layer):
     """Perform classifier free guidance.
@@ -103,9 +137,6 @@ class ClassifierFreeGuidance(layers.Layer):
     - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
     def call(self, inputs, guidance_scale):
         positive_noise, negative_noise = ops.split(inputs, 2, axis=0)
         return ops.add(
@@ -115,9 +146,6 @@ class ClassifierFreeGuidance(layers.Layer):
             ),
         )
 
-    def get_config(self):
-        return super().get_config()
-
     def compute_output_shape(self, inputs_shape):
         outputs_shape = list(inputs_shape)
         if outputs_shape[0] is not None:
@@ -145,58 +173,10 @@ class EulerStep(layers.Layer):
         https://arxiv.org/abs/2206.00364).
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
     def call(self, latents, noise_residual, sigma, sigma_next):
         sigma_diff = ops.subtract(sigma_next, sigma)
         return ops.add(latents, ops.multiply(sigma_diff, noise_residual))
 
-    def get_config(self):
-        return super().get_config()
-
-    def compute_output_shape(self, latents_shape):
-        return latents_shape
-
-
-class LatentSpaceDecoder(layers.Layer):
-    """Decoder to transform the latent space back to the original image space.
-
-    During decoding, the latents are transformed back to the original image
-    space using the equation: `latents / scale + shift`.
-
-    Args:
-        scale: float. The scaling factor.
-        shift: float. The shift factor.
-        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
-            including `name`, `dtype` etc.
-
-    Call arguments:
-        latents: The latent tensor to be transformed.
-
-    Reference:
-    - [High-Resolution Image Synthesis with Latent Diffusion Models](
-        https://arxiv.org/abs/2112.10752).
-    """
-
-    def __init__(self, scale, shift, **kwargs):
-        super().__init__(**kwargs)
-        self.scale = scale
-        self.shift = shift
-
-    def call(self, latents):
-        return ops.add(ops.divide(latents, self.scale), self.shift)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "scale": self.scale,
-                "shift": self.shift,
-            }
-        )
-        return config
-
     def compute_output_shape(self, latents_shape):
         return latents_shape
 
@@ -222,16 +202,18 @@ class StableDiffusion3Backbone(Backbone):
             transformer in MMDiT.
         mmdit_position_size: int. The size of the height and width for the
             position embedding in MMDiT.
-
-
-
-
-
-
-
-
-
-
+        mmdit_qk_norm: Optional str. Whether to normalize the query and key
+            tensors for each transformer in MMDiT. Available options are `None`
+            and `"rms_norm"`. Typically, this is set to `None` for 3.0 version
+            and to `"rms_norm"` for 3.5 version.
+        mmdit_dual_attention_indices: Optional tuple. Specifies the indices of
+            the blocks that serve as dual attention blocks. Typically, this is
+            for 3.5 version. Defaults to `None`.
+        vae: The VAE used for transformations between pixel space and latent
+            space.
+        clip_l: The CLIP text encoder for encoding the inputs.
+        clip_g: The CLIP text encoder for encoding the inputs.
+        t5: optional The T5 text encoder for encoding the inputs.
         latent_channels: int. The number of channels in the latent. Defaults to
             `16`.
         output_channels: int. The number of channels in the output. Defaults to
@@ -239,9 +221,9 @@ class StableDiffusion3Backbone(Backbone):
         num_train_timesteps: int. The number of diffusion steps to train the
             model. Defaults to `1000`.
         shift: float. The shift value for the timestep schedule. Defaults to
-            `
-
-
+            `3.0`.
+        image_shape: tuple. The input shape without the batch size. Defaults to
+            `(1024, 1024, 3)`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -264,6 +246,7 @@ class StableDiffusion3Backbone(Backbone):
     )
 
     # Randomly initialized Stable Diffusion 3 model with custom config.
+    vae = keras_hub.models.VAEBackbone(...)
    clip_l = keras_hub.models.CLIPTextEncoder(...)
    clip_g = keras_hub.models.CLIPTextEncoder(...)
    model = keras_hub.models.StableDiffusion3Backbone(
@@ -272,8 +255,9 @@ class StableDiffusion3Backbone(Backbone):
        mmdit_hidden_dim=256,
        mmdit_depth=4,
        mmdit_position_size=192,
-
-
+        mmdit_qk_norm=None,
+        mmdit_dual_attention_indices=None,
+        vae=vae,
        clip_l=clip_l,
        clip_g=clip_g,
    )
@@ -287,46 +271,48 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_num_layers,
         mmdit_num_heads,
         mmdit_position_size,
-
-
+        mmdit_qk_norm,
+        mmdit_dual_attention_indices,
+        vae,
         clip_l,
         clip_g,
         t5=None,
         latent_channels=16,
         output_channels=3,
         num_train_timesteps=1000,
-        shift=
-
-        width=None,
+        shift=3.0,
+        image_shape=(1024, 1024, 3),
         data_format=None,
         dtype=None,
         **kwargs,
     ):
-        height = int(height or 1024)
-        width = int(width or 1024)
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(
-                "`height` and `width` must be divisible by 8. "
-                f"Received: height={height}, width={width}"
-            )
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError
-
+        height = image_shape[0]
+        width = image_shape[1]
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(
+                "height and width in `image_shape` must be divisible by 8. "
+                f"Received: image_shape={image_shape}"
+            )
+        latent_shape = (height // 8, width // 8, int(latent_channels))
         context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
         pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
+        self._latent_shape = latent_shape
 
         # === Layers ===
         self.clip_l = clip_l
         self.clip_l_projection = CLIPProjection(
             clip_l.hidden_dim, dtype=dtype, name="clip_l_projection"
         )
-        self.clip_l_projection.build([None, clip_l.hidden_dim], None)
         self.clip_g = clip_g
         self.clip_g_projection = CLIPProjection(
             clip_g.hidden_dim, dtype=dtype, name="clip_g_projection"
         )
-        self.
+        self.clip_concatenate = CLIPConcatenate(
+            dtype=dtype, name="clip_concatenate"
+        )
         self.t5 = t5
         self.diffuser = MMDiT(
             mmdit_patch_size,
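
With this change the constructor takes a single `image_shape` tuple instead of separate `height`/`width` arguments and derives the latent resolution from it. A standalone sketch of that derivation (the helper name `latent_shape_for` is hypothetical; the arithmetic and the error message mirror the added lines above):

def latent_shape_for(image_shape, latent_channels=16):
    height, width = image_shape[0], image_shape[1]
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(
            "height and width in `image_shape` must be divisible by 8. "
            f"Received: image_shape={image_shape}"
        )
    # Latents are 8x smaller than the image in both spatial dimensions.
    return (height // 8, width // 8, int(latent_channels))

print(latent_shape_for((1024, 1024, 3)))  # (128, 128, 16)
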
@@ -337,18 +323,18 @@ class StableDiffusion3Backbone(Backbone):
             latent_shape=latent_shape,
             context_shape=context_shape,
             pooled_projection_shape=pooled_projection_shape,
+            qk_norm=mmdit_qk_norm,
+            dual_attention_indices=mmdit_dual_attention_indices,
             data_format=data_format,
             dtype=dtype,
             name="diffuser",
         )
-        self.
-
-
-
-
-
-            dtype=dtype,
-            name="decoder",
+        self.vae = vae
+        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
+            dtype=dtype, name="classifier_free_guidance_concat"
+        )
+        self.cfg = ClassifierFreeGuidance(
+            dtype=dtype, name="classifier_free_guidance"
         )
         # Set `dtype="float32"` to ensure the high precision for the noise
         # residual.
@@ -358,21 +344,25 @@ class StableDiffusion3Backbone(Backbone):
             dtype="float32",
             name="scheduler",
         )
-        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
-            dtype="float32", name="classifier_free_guidance_concat"
-        )
-        self.cfg = ClassifierFreeGuidance(
-            dtype="float32", name="classifier_free_guidance"
-        )
         self.euler_step = EulerStep(dtype="float32", name="euler_step")
-        self.
-            scale=self.
-
-            dtype=
-            name="
+        self.image_rescaling = ImageRescaling(
+            scale=self.vae.scale,
+            offset=self.vae.shift,
+            dtype=dtype,
+            name="image_rescaling",
+        )
+        self.latent_rescaling = LatentRescaling(
+            scale=self.vae.scale,
+            offset=self.vae.shift,
+            dtype=dtype,
+            name="latent_rescaling",
         )
 
         # === Functional Model ===
+        image_input = keras.Input(
+            shape=image_shape,
+            name="images",
+        )
         latent_input = keras.Input(
             shape=latent_shape,
             name="latents",
@@ -428,17 +418,19 @@ class StableDiffusion3Backbone(Backbone):
             dtype="float32",
             name="guidance_scale",
         )
-        embeddings = self.
+        embeddings = self.encode_text_step(token_ids, negative_token_ids)
+        latents = self.encode_image_step(image_input)
         # Use `steps=0` to define the functional model.
-
+        denoised_latents = self.denoise_step(
             latent_input,
             embeddings,
             0,
             num_step_input[0],
             guidance_scale_input[0],
         )
-
+        images = self.decode_step(denoised_latents)
         inputs = {
+            "images": image_input,
             "latents": latent_input,
             "clip_l_token_ids": clip_l_token_id_input,
             "clip_l_negative_token_ids": clip_l_negative_token_id_input,
@@ -447,6 +439,10 @@ class StableDiffusion3Backbone(Backbone):
             "num_steps": num_step_input,
             "guidance_scale": guidance_scale_input,
         }
+        outputs = {
+            "latents": latents,
+            "images": images,
+        }
         if self.t5 is not None:
             inputs["t5_token_ids"] = t5_token_id_input
             inputs["t5_negative_token_ids"] = t5_negative_token_id_input
@@ -463,18 +459,17 @@ class StableDiffusion3Backbone(Backbone):
         self.mmdit_num_layers = mmdit_num_layers
         self.mmdit_num_heads = mmdit_num_heads
         self.mmdit_position_size = mmdit_position_size
-        self.
-        self.
+        self.mmdit_qk_norm = mmdit_qk_norm
+        self.mmdit_dual_attention_indices = mmdit_dual_attention_indices
         self.latent_channels = latent_channels
         self.output_channels = output_channels
         self.num_train_timesteps = num_train_timesteps
         self.shift = shift
-        self.
-        self.width = width
+        self.image_shape = image_shape
 
     @property
     def latent_shape(self):
-        return (None,) +
+        return (None,) + self._latent_shape
 
     @property
     def clip_hidden_dim(self):
@@ -484,13 +479,17 @@ class StableDiffusion3Backbone(Backbone):
     def t5_hidden_dim(self):
         return 4096 if self.t5 is None else self.t5.hidden_dim
 
-    def
+    def encode_text_step(self, token_ids, negative_token_ids):
         clip_hidden_dim = self.clip_hidden_dim
         t5_hidden_dim = self.t5_hidden_dim
 
         def encode(token_ids):
-            clip_l_outputs = self.clip_l(
-
+            clip_l_outputs = self.clip_l(
+                {"token_ids": token_ids["clip_l"]}, training=False
+            )
+            clip_g_outputs = self.clip_g(
+                {"token_ids": token_ids["clip_g"]}, training=False
+            )
             clip_l_projection = self.clip_l_projection(
                 clip_l_outputs["sequence_output"],
                 token_ids["clip_l"],
@@ -501,23 +500,21 @@ class StableDiffusion3Backbone(Backbone):
                 token_ids["clip_g"],
                 training=False,
             )
-            pooled_embeddings =
-
-
-
-
-
-                clip_l_outputs["intermediate_output"],
-                clip_g_outputs["intermediate_output"],
-                ],
-                axis=-1,
-            )
-            embeddings = ops.pad(
-                embeddings,
-                [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]],
+            pooled_embeddings, embeddings = self.clip_concatenate(
+                clip_l_projection,
+                clip_g_projection,
+                clip_l_outputs["intermediate_output"],
+                clip_g_outputs["intermediate_output"],
+                padding=t5_hidden_dim - clip_hidden_dim,
             )
             if self.t5 is not None:
-                t5_outputs = self.t5(
+                t5_outputs = self.t5(
+                    {
+                        "token_ids": token_ids["t5"],
+                        "padding_mask": ops.ones_like(token_ids["t5"]),
+                    },
+                    training=False,
+                )
                 embeddings = ops.concatenate([embeddings, t5_outputs], axis=-2)
             else:
                 padded_size = self.clip_l.max_sequence_length
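
The new `CLIPConcatenate` layer replaces the inline `ops.concatenate`/`ops.pad` calls removed above: it concatenates the two CLIP sequence outputs along the feature axis and zero-pads that axis up to the context width (4096 when no T5 encoder is attached, per the `context_shape` and `t5_hidden_dim` definitions earlier in this diff). A NumPy sketch of the same shape arithmetic; the hidden sizes and sequence length below are illustrative placeholders, and only the 4096 figure appears in the diff:

import numpy as np

clip_l_dim, clip_g_dim, t5_dim = 768, 1280, 4096  # placeholders except 4096
seq_len = 77  # placeholder sequence length

clip_l_seq = np.zeros((1, seq_len, clip_l_dim), dtype="float32")
clip_g_seq = np.zeros((1, seq_len, clip_g_dim), dtype="float32")

# Concatenate along the feature axis, then pad up to the context width.
embeddings = np.concatenate([clip_l_seq, clip_g_seq], axis=-1)
padding = t5_dim - embeddings.shape[-1]
embeddings = np.pad(embeddings, [[0, 0], [0, 0], [0, padding]])
print(embeddings.shape)  # (1, 77, 4096)
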
@@ -537,23 +534,36 @@ class StableDiffusion3Backbone(Backbone):
             negative_pooled_embeddings,
         )
 
+    def encode_image_step(self, images):
+        latents = self.vae.encode(images)
+        return self.image_rescaling(latents)
+
+    def add_noise_step(self, latents, noises, step, num_steps):
+        return self.scheduler.add_noise(latents, noises, step, num_steps)
+
     def denoise_step(
         self,
         latents,
         embeddings,
-
+        step,
         num_steps,
-        guidance_scale,
+        guidance_scale=None,
     ):
-
-
-        sigma, timestep = self.scheduler(
-
+        step = ops.convert_to_tensor(step)
+        next_step = ops.add(step, 1)
+        sigma, timestep = self.scheduler(step, num_steps)
+        next_sigma, _ = self.scheduler(next_step, num_steps)
 
         # Concatenation for classifier-free guidance.
-
-
-
+        if guidance_scale is not None:
+            concated_latents, contexts, pooled_projs, timesteps = (
+                self.cfg_concat(latents, *embeddings, timestep)
+            )
+        else:
+            timesteps = ops.broadcast_to(timestep, ops.shape(latents)[:1])
+            concated_latents = latents
+            contexts = embeddings[0]
+            pooled_projs = embeddings[2]
 
         # Diffusion.
         predicted_noise = self.diffuser(
@@ -567,14 +577,15 @@ class StableDiffusion3Backbone(Backbone):
         )
 
         # Classifier-free guidance.
-
+        if guidance_scale is not None:
+            predicted_noise = self.cfg(predicted_noise, guidance_scale)
 
         # Euler step.
-        return self.euler_step(latents, predicted_noise, sigma,
+        return self.euler_step(latents, predicted_noise, sigma, next_sigma)
 
     def decode_step(self, latents):
-        latents = self.
-        return self.
+        latents = self.latent_rescaling(latents)
+        return self.vae.decode(latents, training=False)
 
     def get_config(self):
         config = super().get_config()
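
`denoise_step` now applies classifier-free guidance only when a `guidance_scale` is passed, and `ClassifierFreeGuidance.call` splits the batched prediction into positive and negative halves (`ops.split(inputs, 2, axis=0)` in an earlier hunk). The combine step itself is partly elided in this rendering, so the snippet below is only the standard formulation from the paper cited in the layer's docstring, written with NumPy, not a verbatim copy of the layer:

import numpy as np

def classifier_free_guidance(batched_noise, guidance_scale):
    # Batch layout is [positive | negative], matching the axis=0 concatenation
    # performed by ClassifierFreeGuidanceConcatenate.
    positive_noise, negative_noise = np.split(batched_noise, 2, axis=0)
    return negative_noise + guidance_scale * (positive_noise - negative_noise)

noise = np.random.randn(2, 4, 4, 16).astype("float32")  # hypothetical shape
print(classifier_free_guidance(noise, guidance_scale=7.0).shape)  # (1, 4, 4, 16)
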
@@ -585,8 +596,11 @@ class StableDiffusion3Backbone(Backbone):
                 "mmdit_num_layers": self.mmdit_num_layers,
                 "mmdit_num_heads": self.mmdit_num_heads,
                 "mmdit_position_size": self.mmdit_position_size,
-                "
-                "
+                "mmdit_qk_norm": self.mmdit_qk_norm,
+                "mmdit_dual_attention_indices": (
+                    self.mmdit_dual_attention_indices
+                ),
+                "vae": layers.serialize(self.vae),
                 "clip_l": layers.serialize(self.clip_l),
                 "clip_g": layers.serialize(self.clip_g),
                 "t5": layers.serialize(self.t5),
@@ -594,8 +608,7 @@ class StableDiffusion3Backbone(Backbone):
                 "output_channels": self.output_channels,
                 "num_train_timesteps": self.num_train_timesteps,
                 "shift": self.shift,
-                "
-                "width": self.width,
+                "image_shape": self.image_shape,
             }
         )
         return config
@@ -607,6 +620,8 @@ class StableDiffusion3Backbone(Backbone):
         # Propagate `dtype` to text encoders if needed.
         if "dtype" in config and config["dtype"] is not None:
             dtype_config = config["dtype"]
+            if "dtype" not in config["vae"]["config"]:
+                config["vae"]["config"]["dtype"] = dtype_config
             if "dtype" not in config["clip_l"]["config"]:
                 config["clip_l"]["config"]["dtype"] = dtype_config
             if "dtype" not in config["clip_g"]["config"]:
@@ -617,7 +632,10 @@ class StableDiffusion3Backbone(Backbone):
             ):
                 config["t5"]["config"]["dtype"] = dtype_config
 
-        # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
+        # We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
+        config["vae"] = layers.deserialize(
+            config["vae"], custom_objects=custom_objects
+        )
         config["clip_l"] = layers.deserialize(
             config["clip_l"], custom_objects=custom_objects
         )
@@ -628,4 +646,12 @@ class StableDiffusion3Backbone(Backbone):
         config["t5"] = layers.deserialize(
             config["t5"], custom_objects=custom_objects
         )
+
+        # To maintain backward compatibility, we need to ensure that
+        # `mmdit_qk_norm` and `mmdit_dual_attention_indices` is included in the
+        # config.
+        if "mmdit_qk_norm" not in config:
+            config["mmdit_qk_norm"] = None
+        if "mmdit_dual_attention_indices" not in config:
+            config["mmdit_dual_attention_indices"] = None
         return cls(**config)