keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
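The only file expanded in this section is the MMDiT backbone rewrite in keras_hub/src/models/stable_diffusion_3/mmdit.py (+577 -190), shown below. Its central new piece, `AdaptiveLayerNormalization`, applies the adaLN modulation `norm(x) * (1 + scale) + shift`, with `shift` and `scale` projected from a conditioning embedding. The sketch below reproduces just that operation with `keras.ops`; the standalone `adaln_modulate` helper, the shapes, and the freshly created layer instances are illustrative stand-ins, not the library API.

import numpy as np
from keras import layers, ops

def adaln_modulate(x, embedding, dense, norm):
    # Project the conditioning embedding into (shift, scale), then apply
    # norm(x) * (1 + scale) + shift -- the 2-modulation branch of adaLN.
    emb = dense(ops.silu(embedding))
    shift, scale = ops.split(emb, 2, axis=1)
    shift = ops.expand_dims(shift, axis=1)
    scale = ops.expand_dims(scale, axis=1)
    return ops.add(ops.multiply(norm(x), ops.add(1.0, scale)), shift)

hidden_dim = 8
x = np.random.rand(2, 16, hidden_dim).astype("float32")  # token features
emb = np.random.rand(2, hidden_dim).astype("float32")  # conditioning vector
dense = layers.Dense(2 * hidden_dim)  # projects embedding to shift and scale
norm = layers.LayerNormalization(epsilon=1e-6, center=False, scale=False)
print(adaln_modulate(x, emb, dense, norm).shape)  # (2, 16, 8)

Using `center=False, scale=False` on the normalization mirrors the diff: the affine part of the transform comes entirely from the conditioning embedding rather than from learned per-feature parameters.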
keras_hub/src/models/stable_diffusion_3/mmdit.py

@@ -2,7 +2,6 @@ import math
 
 import keras
 from keras import layers
-from keras import models
 from keras import ops
 
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
@@ -11,7 +10,216 @@ from keras_hub.src.utils.keras_utils import gelu_approximate
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
+class AdaptiveLayerNormalization(layers.Layer):
+    """Adaptive layer normalization.
+
+    Args:
+        embedding_dim: int. The size of each embedding vector.
+        num_modulations: int. The number of the modulation parameters. The
+            available values are `2`, `6` and `9`. Defaults to `2`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    References:
+        - [FiLM: Visual Reasoning with a General Conditioning Layer](
+        https://arxiv.org/abs/1709.07871).
+        - [Scalable Diffusion Models with Transformers](
+        https://arxiv.org/abs/2212.09748).
+    """
+
+    def __init__(self, hidden_dim, num_modulations=2, **kwargs):
+        super().__init__(**kwargs)
+        hidden_dim = int(hidden_dim)
+        num_modulations = int(num_modulations)
+        if num_modulations not in (2, 6, 9):
+            raise ValueError(
+                "`num_modulations` must be `2`, `6` or `9`. "
+                f"Received: num_modulations={num_modulations}"
+            )
+        self.hidden_dim = hidden_dim
+        self.num_modulations = num_modulations
+
+        self.silu = layers.Activation("silu", dtype=self.dtype_policy)
+        self.dense = layers.Dense(
+            num_modulations * hidden_dim, dtype=self.dtype_policy, name="dense"
+        )
+        self.norm = layers.LayerNormalization(
+            epsilon=1e-6,
+            center=False,
+            scale=False,
+            dtype="float32",
+            name="norm",
+        )
+
+    def build(self, inputs_shape, embeddings_shape):
+        self.silu.build(embeddings_shape)
+        self.dense.build(embeddings_shape)
+        self.norm.build(inputs_shape)
+
+    def call(self, inputs, embeddings, training=None):
+        hidden_states = inputs
+        emb = self.dense(self.silu(embeddings), training=training)
+        if self.num_modulations == 9:
+            (
+                shift_msa,
+                scale_msa,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                shift_msa2,
+                scale_msa2,
+                gate_msa2,
+            ) = ops.split(emb, self.num_modulations, axis=1)
+        elif self.num_modulations == 6:
+            (
+                shift_msa,
+                scale_msa,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+            ) = ops.split(emb, self.num_modulations, axis=1)
+        else:
+            shift_msa, scale_msa = ops.split(emb, self.num_modulations, axis=1)
+
+        scale_msa = ops.expand_dims(scale_msa, axis=1)
+        shift_msa = ops.expand_dims(shift_msa, axis=1)
+        norm_hidden_states = ops.cast(
+            self.norm(hidden_states, training=training), scale_msa.dtype
+        )
+        hidden_states = ops.add(
+            ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa)), shift_msa
+        )
+
+        if self.num_modulations == 9:
+            scale_msa2 = ops.expand_dims(scale_msa2, axis=1)
+            shift_msa2 = ops.expand_dims(shift_msa2, axis=1)
+            hidden_states2 = ops.add(
+                ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa2)),
+                shift_msa2,
+            )
+            return (
+                hidden_states,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                hidden_states2,
+                gate_msa2,
+            )
+        elif self.num_modulations == 6:
+            return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp
+        else:
+            return hidden_states
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_modulations": self.num_modulations,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape, embeddings_shape):
+        if self.num_modulations == 9:
+            return (
+                inputs_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                inputs_shape,
+                embeddings_shape,
+            )
+        elif self.num_modulations == 6:
+            return (
+                inputs_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+            )
+        else:
+            return inputs_shape
+
+
+class MLP(layers.Layer):
+    """A MLP block with architecture.
+
+    Args:
+        hidden_dim: int. The number of units in the hidden layers.
+        output_dim: int. The number of units in the output layer.
+        activation: str of callable. Activation to use in the hidden layers.
+            Default to `None`.
+    """
+
+    def __init__(self, hidden_dim, output_dim, activation=None, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.output_dim = int(output_dim)
+        self.activation = keras.activations.get(activation)
+
+        self.dense1 = layers.Dense(
+            hidden_dim,
+            activation=self.activation,
+            dtype=self.dtype_policy,
+            name="dense1",
+        )
+        self.dense2 = layers.Dense(
+            output_dim,
+            activation=None,
+            dtype=self.dtype_policy,
+            name="dense2",
+        )
+
+    def build(self, inputs_shape):
+        self.dense1.build(inputs_shape)
+        inputs_shape = self.dense1.compute_output_shape(inputs_shape)
+        self.dense2.build(inputs_shape)
+
+    def call(self, inputs, training=None):
+        x = self.dense1(inputs, training=training)
+        return self.dense2(x, training=training)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "output_dim": self.output_dim,
+                "activation": keras.activations.serialize(self.activation),
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        outputs_shape = list(inputs_shape)
+        outputs_shape[-1] = self.output_dim
+        return outputs_shape
+
+
 class PatchEmbedding(layers.Layer):
+    """A layer that converts images into patches.
+
+    Args:
+        patch_size: int. The size of one side of each patch.
+        hidden_dim: int. The number of units in the hidden layers.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
     def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs):
         super().__init__(**kwargs)
         self.patch_size = int(patch_size)
@@ -48,6 +256,15 @@ class PatchEmbedding(layers.Layer):
 
 
 class AdjustablePositionEmbedding(PositionEmbedding):
+    """A position embedding layer with adjustable height and width.
+
+    The embedding will be cropped to match the input dimensions.
+
+    Args:
+        height: int. The maximum height of the embedding.
+        width: int. The maximum width of the embedding.
+    """
+
     def __init__(
         self,
         height,
@@ -84,11 +301,36 @@ class AdjustablePositionEmbedding(PositionEmbedding):
         position_embedding = ops.expand_dims(position_embedding, axis=0)
         return position_embedding
 
+    def get_config(self):
+        config = super().get_config()
+        del config["sequence_length"]
+        config.update(
+            {
+                "height": self.height,
+                "width": self.width,
+            }
+        )
+        return config
+
     def compute_output_shape(self, input_shape):
         return input_shape
 
 
 class TimestepEmbedding(layers.Layer):
+    """A layer which learns embedding for input timesteps.
+
+    Args:
+        embedding_dim: int. The size of the embedding.
+        frequency_dim: int. The size of the frequency.
+        max_period: int. Controls the maximum frequency of the embeddings.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    Reference:
+        - [Denoising Diffusion Probabilistic Models](
+        https://arxiv.org/abs/2006.11239).
+    """
+
     def __init__(
         self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs
     ):
@@ -96,17 +338,23 @@ class TimestepEmbedding(layers.Layer):
         self.embedding_dim = int(embedding_dim)
         self.frequency_dim = int(frequency_dim)
         self.max_period = float(max_period)
-
-
-        self.
-
-
-
-
-                layers.Dense(
-                    embedding_dim, activation=None, dtype=self.dtype_policy
+        # Precomputed `freq`.
+        half_frequency_dim = frequency_dim // 2
+        self.freq = ops.exp(
+            ops.divide(
+                ops.multiply(
+                    -math.log(max_period),
+                    ops.arange(0, half_frequency_dim, dtype="float32"),
                 ),
-
+                half_frequency_dim,
+            )
+        )
+
+        self.mlp = MLP(
+            embedding_dim,
+            embedding_dim,
+            "silu",
+            dtype=self.dtype_policy,
             name="mlp",
         )
 
@@ -118,16 +366,7 @@ class TimestepEmbedding(layers.Layer):
     def _create_timestep_embedding(self, inputs):
        compute_dtype = keras.backend.result_type(self.compute_dtype, "float32")
         x = ops.cast(inputs, compute_dtype)
-        freqs = ops.
-            ops.divide(
-                ops.multiply(
-                    -math.log(self.max_period),
-                    ops.arange(0, self.half_frequency_dim, dtype="float32"),
-                ),
-                self.half_frequency_dim,
-            )
-        )
-        freqs = ops.cast(freqs, compute_dtype)
+        freqs = ops.cast(self.freq, compute_dtype)
         x = ops.multiply(x, ops.expand_dims(freqs, axis=0))
         embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1)
         if self.frequency_dim % 2 != 0:
@@ -143,6 +382,7 @@ class TimestepEmbedding(layers.Layer):
         config.update(
             {
                 "embedding_dim": self.embedding_dim,
+                "frequency_dim": self.frequency_dim,
                 "max_period": self.max_period,
             }
         )
@@ -154,13 +394,52 @@ class TimestepEmbedding(layers.Layer):
         return output_shape
 
 
+def get_qk_norm(qk_norm=None, q_norm_name="q_norm", k_norm_name="k_norm"):
+    """Helper function to instantiate `LayerNormalization` layers."""
+    q_norm = None
+    k_norm = None
+    if qk_norm is None:
+        pass
+    elif qk_norm == "rms_norm":
+        q_norm = layers.LayerNormalization(
+            epsilon=1e-6, rms_scaling=True, dtype="float32", name=q_norm_name
+        )
+        k_norm = layers.LayerNormalization(
+            epsilon=1e-6, rms_scaling=True, dtype="float32", name=k_norm_name
+        )
+    else:
+        raise NotImplementedError(
+            "Supported `qk_norm` are `'rms_norm'` and `None`. "
+            f"Received: qk_norm={qk_norm}."
+        )
+    return q_norm, k_norm
+
+
 class DismantledBlock(layers.Layer):
+    """A dismantled block used to compute pre- and post-attention.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. The number of units in the hidden layers.
+        mlp_ratio: float. The expansion ratio of `MLP`.
+        use_projection: bool. Whether to use an attention projection layer at
+            the end of the block.
+        qk_norm: Optional str. Whether to normalize the query and key tensors.
+            Available options are `None` and `"rms_norm"`. Defaults to `None`.
+        use_dual_attention: bool. Whether to use a dual attention in the
+            block. Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
     def __init__(
         self,
         num_heads,
         hidden_dim,
         mlp_ratio=4.0,
         use_projection=True,
+        qk_norm=None,
+        use_dual_attention=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -168,33 +447,32 @@ class DismantledBlock(layers.Layer):
         self.hidden_dim = hidden_dim
         self.mlp_ratio = mlp_ratio
         self.use_projection = use_projection
+        self.qk_norm = qk_norm
+        self.use_dual_attention = use_dual_attention
 
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
         mlp_hidden_dim = int(hidden_dim * mlp_ratio)
         self.mlp_hidden_dim = mlp_hidden_dim
-        num_modulations = 6 if use_projection else 2
-        self.num_modulations = num_modulations
 
-
-
-
-
-
-
-
-
-
-
-
-            center=False,
-            scale=False,
-            dtype="float32",
-            name="norm1",
-        )
+        if use_projection:
+            self.ada_layer_norm = AdaptiveLayerNormalization(
+                hidden_dim,
+                num_modulations=9 if use_dual_attention else 6,
+                dtype=self.dtype_policy,
+                name="ada_layer_norm",
+            )
+        else:
+            self.ada_layer_norm = AdaptiveLayerNormalization(
+                hidden_dim, dtype=self.dtype_policy, name="ada_layer_norm"
+            )
         self.attention_qkv = layers.Dense(
             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
         )
+        q_norm, k_norm = get_qk_norm(qk_norm)
+        if q_norm is not None:
+            self.q_norm = q_norm
+            self.k_norm = k_norm
         if use_projection:
             self.attention_proj = layers.Dense(
                 hidden_dim, dtype=self.dtype_policy, name="attention_proj"
@@ -206,89 +484,165 @@ class DismantledBlock(layers.Layer):
             dtype="float32",
             name="norm2",
         )
-        self.mlp =
-
-
-
-
-                    dtype=self.dtype_policy,
-                ),
-                layers.Dense(
-                    hidden_dim,
-                    dtype=self.dtype_policy,
-                ),
-            ],
+        self.mlp = MLP(
+            mlp_hidden_dim,
+            hidden_dim,
+            gelu_approximate,
+            dtype=self.dtype_policy,
             name="mlp",
         )
 
+        if use_dual_attention:
+            self.attention_qkv2 = layers.Dense(
+                hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv2"
+            )
+            q_norm2, k_norm2 = get_qk_norm(qk_norm, "q_norm2", "k_norm2")
+            if q_norm is not None:
+                self.q_norm2 = q_norm2
+                self.k_norm2 = k_norm2
+            if use_projection:
+                self.attention_proj2 = layers.Dense(
+                    hidden_dim, dtype=self.dtype_policy, name="attention_proj2"
+                )
+
     def build(self, inputs_shape, timestep_embedding):
-        self.
+        self.ada_layer_norm.build(inputs_shape, timestep_embedding)
         self.attention_qkv.build(inputs_shape)
-        self.
+        if self.qk_norm is not None:
+            # [batch_size, sequence_length, num_heads, head_dim]
+            self.q_norm.build([None, None, self.num_heads, self.head_dim])
+            self.k_norm.build([None, None, self.num_heads, self.head_dim])
         if self.use_projection:
             self.attention_proj.build(inputs_shape)
            self.norm2.build(inputs_shape)
            self.mlp.build(inputs_shape)
+        if self.use_dual_attention:
+            self.attention_qkv2.build(inputs_shape)
+            if self.qk_norm is not None:
+                self.q_norm2.build([None, None, self.num_heads, self.head_dim])
+                self.k_norm2.build([None, None, self.num_heads, self.head_dim])
+            if self.use_projection:
+                self.attention_proj2.build(inputs_shape)
 
     def _modulate(self, inputs, shift, scale):
-
-
+        inputs = ops.cast(inputs, self.compute_dtype)
+        shift = ops.cast(shift, self.compute_dtype)
+        scale = ops.cast(scale, self.compute_dtype)
         return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
 
    def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
         batch_size = ops.shape(inputs)[0]
         if self.use_projection:
-
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 6, self.hidden_dim)
-            )
-            (
-                shift_msa,
-                scale_msa,
-                gate_msa,
-                shift_mlp,
-                scale_mlp,
-                gate_mlp,
-            ) = ops.unstack(modulation, 6, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
+            x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.ada_layer_norm(
+                inputs, timestep_embedding, training=training
             )
+            qkv = self.attention_qkv(x, training=training)
             qkv = ops.reshape(
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
+            if self.qk_norm is not None:
+                q = ops.cast(
+                    self.q_norm(q, training=training), self.compute_dtype
+                )
+                k = ops.cast(
+                    self.k_norm(k, training=training), self.compute_dtype
+                )
           return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         else:
-
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 2, self.hidden_dim)
-            )
-            shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
+            x = self.ada_layer_norm(
+                inputs, timestep_embedding, training=training
             )
+            qkv = self.attention_qkv(x, training=training)
             qkv = ops.reshape(
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
+            if self.qk_norm is not None:
+                q = ops.cast(
+                    self.q_norm(q, training=training), self.compute_dtype
+                )
+                k = ops.cast(
+                    self.k_norm(k, training=training), self.compute_dtype
+                )
             return (q, k, v)
 
     def _compute_post_attention(
         self, inputs, inputs_intermediates, training=None
     ):
         x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
+        gate_msa = ops.expand_dims(gate_msa, axis=1)
+        shift_mlp = ops.expand_dims(shift_mlp, axis=1)
+        scale_mlp = ops.expand_dims(scale_mlp, axis=1)
+        gate_mlp = ops.expand_dims(gate_mlp, axis=1)
         attn = self.attention_proj(inputs, training=training)
-        x = ops.add(x, ops.multiply(
+        x = ops.add(x, ops.multiply(gate_msa, attn))
         x = ops.add(
             x,
             ops.multiply(
-
+                gate_mlp,
+                self.mlp(
+                    self._modulate(self.norm2(x), shift_mlp, scale_mlp),
+                    training=training,
+                ),
+            ),
+        )
+        return x
+
+    def _compute_pre_attention_with_dual_attention(
+        self, inputs, timestep_embedding, training=None
+    ):
+        batch_size = ops.shape(inputs)[0]
+        x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2 = (
+            self.ada_layer_norm(inputs, timestep_embedding, training=training)
+        )
+        # Compute the main attention
+        qkv = self.attention_qkv(x, training=training)
+        qkv = ops.reshape(
+            qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
+        )
+        q, k, v = ops.unstack(qkv, 3, axis=2)
+        if self.qk_norm is not None:
+            q = ops.cast(self.q_norm(q, training=training), self.compute_dtype)
+            k = ops.cast(self.k_norm(k, training=training), self.compute_dtype)
+        # Compute the dual attention
+        qkv2 = self.attention_qkv2(x2, training=training)
+        qkv2 = ops.reshape(
+            qkv2, (batch_size, -1, 3, self.num_heads, self.head_dim)
+        )
+        q2, k2, v2 = ops.unstack(qkv2, 3, axis=2)
+        if self.qk_norm is not None:
+            q2 = ops.cast(
+                self.q_norm2(q2, training=training), self.compute_dtype
+            )
+            k2 = ops.cast(
+                self.k_norm2(k2, training=training), self.compute_dtype
+            )
+        return (
+            (q, k, v),
+            (q2, k2, v2),
+            (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2),
+        )
+
+    def _compute_post_attention_with_dual_attention(
+        self, inputs, inputs2, inputs_intermediates, training=None
+    ):
+        x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2 = (
+            inputs_intermediates
+        )
+        gate_msa = ops.expand_dims(gate_msa, axis=1)
+        shift_mlp = ops.expand_dims(shift_mlp, axis=1)
+        scale_mlp = ops.expand_dims(scale_mlp, axis=1)
+        gate_mlp = ops.expand_dims(gate_mlp, axis=1)
+        gate_msa2 = ops.expand_dims(gate_msa2, axis=1)
+        attn = self.attention_proj(inputs, training=training)
+        x = ops.add(x, ops.multiply(gate_msa, attn))
+        attn2 = self.attention_proj2(inputs2, training=training)
+        x = ops.add(x, ops.multiply(gate_msa2, attn2))
+        x = ops.add(
+            x,
+            ops.multiply(
+                gate_mlp,
                 self.mlp(
                     self._modulate(self.norm2(x), shift_mlp, scale_mlp),
                     training=training,
@@ -302,17 +656,28 @@ class DismantledBlock(layers.Layer):
         inputs,
         timestep_embedding=None,
         inputs_intermediates=None,
+        inputs2=None,  # For the dual attention.
         pre_attention=True,
         training=None,
     ):
         if pre_attention:
-
-
-
+            if self.use_dual_attention:
+                return self._compute_pre_attention_with_dual_attention(
+                    inputs, timestep_embedding, training=training
+                )
+            else:
+                return self._compute_pre_attention(
+                    inputs, timestep_embedding, training=training
+                )
         else:
-
-
-
+            if self.use_dual_attention:
+                return self._compute_post_attention_with_dual_attention(
+                    inputs, inputs2, inputs_intermediates, training=training
+                )
+            else:
+                return self._compute_post_attention(
+                    inputs, inputs_intermediates, training=training
+                )
 
     def get_config(self):
         config = super().get_config()
@@ -322,18 +687,47 @@ class DismantledBlock(layers.Layer):
                 "hidden_dim": self.hidden_dim,
                 "mlp_ratio": self.mlp_ratio,
                 "use_projection": self.use_projection,
+                "qk_norm": self.qk_norm,
+                "use_dual_attention": self.use_dual_attention,
             }
         )
         return config
 
 
 class MMDiTBlock(layers.Layer):
+    """A MMDiT block consisting of two `DismantledBlock` layers.
+
+    One `DismantledBlock` processes the input latents, and the other processes
+    the context embedding. This block integrates two modalities within the
+    attention operation, allowing each representation to operate in its own
+    space while considering the other.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. The number of units in the hidden layers.
+        mlp_ratio: float. The expansion ratio of `MLP`.
+        use_context_projection: bool. Whether to use an attention projection
+            layer at the end of the context block.
+        qk_norm: Optional str. Whether to normalize the query and key tensors.
+            Available options are `None` and `"rms_norm"`. Defaults to `None`.
+        use_dual_attention: bool. Whether to use a dual attention in the
+            block. Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    Reference:
+        - [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
+        https://arxiv.org/abs/2403.03206)
+    """
+
     def __init__(
         self,
         num_heads,
         hidden_dim,
         mlp_ratio=4.0,
         use_context_projection=True,
+        qk_norm=None,
+        use_dual_attention=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -341,18 +735,20 @@ class MMDiTBlock(layers.Layer):
         self.hidden_dim = hidden_dim
         self.mlp_ratio = mlp_ratio
         self.use_context_projection = use_context_projection
+        self.qk_norm = qk_norm
+        self.use_dual_attention = use_dual_attention
 
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
-        self._dot_product_equation = "aecd,abcd->acbe"
-        self._combine_equation = "acbe,aecd->abcd"
 
         self.x_block = DismantledBlock(
             num_heads=num_heads,
             hidden_dim=hidden_dim,
             mlp_ratio=mlp_ratio,
             use_projection=True,
+            qk_norm=qk_norm,
+            use_dual_attention=use_dual_attention,
             dtype=self.dtype_policy,
             name="x_block",
        )
@@ -361,6 +757,7 @@ class MMDiTBlock(layers.Layer):
             hidden_dim=hidden_dim,
             mlp_ratio=mlp_ratio,
             use_projection=use_context_projection,
+            qk_norm=qk_norm,
             dtype=self.dtype_policy,
             name="context_block",
         )
@@ -371,20 +768,35 @@ class MMDiTBlock(layers.Layer):
         self.context_block.build(context_shape, timestep_embedding_shape)
 
     def _compute_attention(self, query, key, value):
-
-
-
-
-
-
-
-
-
-
-
-
+        batch_size = ops.shape(query)[0]
+
+        # Use the fast path when `ops.dot_product_attention` and flash attention
+        # are available.
+        if hasattr(ops, "dot_product_attention") and hasattr(
+            keras.config, "is_flash_attention_enabled"
+        ):
+            encoded = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                scale=self._inverse_sqrt_key_dim,
+                flash_attention=keras.config.is_flash_attention_enabled(),
+            )
+            return ops.reshape(
+                encoded, (batch_size, -1, self.num_heads * self.head_dim)
+            )
+
+        # Ref: jax.nn.dot_product_attention
+        # https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
+        logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
+        logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
+        probs = self.softmax(logits)
+        probs = ops.cast(probs, self.compute_dtype)
+        encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value)
+        encoded = ops.reshape(
+            encoded, (batch_size, -1, self.num_heads * self.head_dim)
         )
-        return
+        return encoded
 
     def call(self, inputs, context, timestep_embedding, training=None):
         # Compute pre-attention.
@@ -402,9 +814,14 @@ class MMDiTBlock(layers.Layer):
             training=training,
         )
         context_len = ops.shape(context_qkv[0])[1]
-
-
-
+        if self.x_block.use_dual_attention:
+            x_qkv, x_qkv2, x_intermediates = self.x_block(
+                x, timestep_embedding=timestep_embedding, training=training
+            )
+        else:
+            x_qkv, x_intermediates = self.x_block(
+                x, timestep_embedding=timestep_embedding, training=training
+            )
         q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
         k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
         v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
@@ -415,12 +832,23 @@ class MMDiTBlock(layers.Layer):
         x_attention = attention[:, context_len:]
 
         # Compute post-attention.
-
-
-
-
-
-
+        if self.x_block.use_dual_attention:
+            q2, k2, v2 = x_qkv2
+            x_attention2 = self._compute_attention(q2, k2, v2)
+            x = self.x_block(
+                x_attention,
+                inputs_intermediates=x_intermediates,
+                inputs2=x_attention2,
+                pre_attention=False,
+                training=training,
+            )
+        else:
+            x = self.x_block(
+                x_attention,
+                inputs_intermediates=x_intermediates,
+                pre_attention=False,
+                training=training,
+            )
         if self.use_context_projection:
             context = self.context_block(
                 context_attention,
@@ -440,6 +868,8 @@ class MMDiTBlock(layers.Layer):
                 "hidden_dim": self.hidden_dim,
                 "mlp_ratio": self.mlp_ratio,
                 "use_context_projection": self.use_context_projection,
+                "qk_norm": self.qk_norm,
+                "use_dual_attention": self.use_dual_attention,
             }
         )
         return config
@@ -453,74 +883,16 @@ class MMDiTBlock(layers.Layer):
         return inputs_shape
 
 
-class
-
-        super().__init__(**kwargs)
-        self.hidden_dim = hidden_dim
-        self.output_dim = output_dim
-        num_modulation = 2
-
-        self.adaptive_norm_modulation = models.Sequential(
-            [
-                layers.Activation("silu", dtype=self.dtype_policy),
-                layers.Dense(
-                    num_modulation * hidden_dim, dtype=self.dtype_policy
-                ),
-            ],
-            name="adaptive_norm_modulation",
-        )
-        self.norm = layers.LayerNormalization(
-            epsilon=1e-6,
-            center=False,
-            scale=False,
-            dtype="float32",
-            name="norm",
-        )
-        self.output_dense = layers.Dense(
-            output_dim,
-            use_bias=True,
-            dtype=self.dtype_policy,
-            name="output_dense",
-        )
-
-    def build(self, inputs_shape, timestep_embedding_shape):
-        self.adaptive_norm_modulation.build(timestep_embedding_shape)
-        self.norm.build(inputs_shape)
-        self.output_dense.build(inputs_shape)
-
-    def _modulate(self, inputs, shift, scale):
-        shift = ops.expand_dims(shift, axis=1)
-        scale = ops.expand_dims(scale, axis=1)
-        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
-
-    def call(self, inputs, timestep_embedding, training=None):
-        x = inputs
-        modulation = self.adaptive_norm_modulation(
-            timestep_embedding, training=training
-        )
-        modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim))
-        shift, scale = ops.unstack(modulation, 2, axis=1)
-        x = self._modulate(self.norm(x), shift, scale)
-        x = self.output_dense(x, training=training)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "hidden_dim": self.hidden_dim,
-                "output_dim": self.output_dim,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, inputs_shape):
-        outputs_shape = list(inputs_shape)
-        outputs_shape[-1] = self.output_dim
-        return outputs_shape
+class Unpatch(layers.Layer):
+    """A layer that reconstructs the image from hidden patches.
 
+    Args:
+        patch_size: int. The size of each square patch in the input image.
+        output_dim: int. The number of units in the output layer.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
 
-class Unpatch(layers.Layer):
     def __init__(self, patch_size, output_dim, **kwargs):
         super().__init__(**kwargs)
         self.patch_size = int(patch_size)
@@ -556,7 +928,7 @@ class Unpatch(layers.Layer):
 
 
 class MMDiT(Backbone):
-    """Multimodal Diffusion Transformer (MMDiT) model
+    """A Multimodal Diffusion Transformer (MMDiT) model.
 
     MMDiT is introduced in [
     Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
@@ -574,6 +946,12 @@ class MMDiT(Backbone):
         latent_shape: tuple. The shape of the latent image.
         context_shape: tuple. The shape of the context.
         pooled_projection_shape: tuple. The shape of the pooled projection.
+        qk_norm: Optional str. Whether to normalize the query and key tensors in
+            the intermediate blocks. Available options are `None` and
+            `"rms_norm"`. Defaults to `None`.
+        dual_attention_indices: Optional tuple. Specifies the indices of
+            the blocks that serve as dual attention blocks. Typically, this is
+            for 3.5 version. Defaults to `None`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -598,6 +976,8 @@ class MMDiT(Backbone):
         latent_shape=(64, 64, 16),
         context_shape=(None, 4096),
         pooled_projection_shape=(2048,),
+        qk_norm=None,
+        dual_attention_indices=None,
         data_format=None,
         dtype=None,
         **kwargs,
@@ -611,6 +991,7 @@ class MMDiT(Backbone):
         image_width = latent_shape[1] // patch_size
         output_dim = latent_shape[-1]
         output_dim_in_final = patch_size**2 * output_dim
+        dual_attention_indices = dual_attention_indices or ()
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError(
@@ -636,12 +1017,8 @@ class MMDiT(Backbone):
             dtype=dtype,
             name="context_embedding",
         )
-        self.vector_embedding =
-
-                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
-                layers.Dense(hidden_dim, activation=None, dtype=dtype),
-            ],
-            name="vector_embedding",
+        self.vector_embedding = MLP(
+            hidden_dim, hidden_dim, "silu", dtype=dtype, name="vector_embedding"
         )
         self.vector_embedding_add = layers.Add(
             dtype=dtype, name="vector_embedding_add"
@@ -655,13 +1032,18 @@ class MMDiT(Backbone):
                 hidden_dim,
                 mlp_ratio,
                 use_context_projection=not (i == num_layers - 1),
+                qk_norm=qk_norm,
+                use_dual_attention=i in dual_attention_indices,
                 dtype=dtype,
                 name=f"joint_block_{i}",
            )
             for i in range(num_layers)
         ]
-        self.
-            hidden_dim,
+        self.output_ada_layer_norm = AdaptiveLayerNormalization(
+            hidden_dim, dtype=dtype, name="output_ada_layer_norm"
+        )
+        self.output_dense = layers.Dense(
+            output_dim_in_final, dtype=dtype, name="output_dense"
         )
         self.unpatch = Unpatch(
             patch_size, output_dim, dtype=dtype, name="unpatch"
@@ -696,7 +1078,8 @@ class MMDiT(Backbone):
             x = block(x, context, timestep_embedding)
 
         # Output layer.
-        x = self.
+        x = self.output_ada_layer_norm(x, timestep_embedding)
+        x = self.output_dense(x)
         outputs = self.unpatch(x, height=image_height, width=image_width)
 
         super().__init__(
@@ -720,6 +1103,8 @@ class MMDiT(Backbone):
         self.latent_shape = latent_shape
         self.context_shape = context_shape
         self.pooled_projection_shape = pooled_projection_shape
+        self.qk_norm = qk_norm
+        self.dual_attention_indices = dual_attention_indices
 
     def get_config(self):
         config = super().get_config()
@@ -734,6 +1119,8 @@ class MMDiT(Backbone):
                 "latent_shape": self.latent_shape,
                 "context_shape": self.context_shape,
                 "pooled_projection_shape": self.pooled_projection_shape,
+                "qk_norm": self.qk_norm,
+                "dual_attention_indices": self.dual_attention_indices,
             }
         )
         return config