keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
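The remainder of this page walks through the largest single-file change, keras_hub/src/models/stable_diffusion_3/mmdit.py (entry 177 above). As an aside, a diff like this can be reproduced locally from the two wheels with nothing but the Python standard library. The sketch below is illustrative only and not part of this page; the wheel filenames are assumptions, and the wheels must first be fetched (e.g. with `pip download keras-hub-nightly==<version> --no-deps`).

    import difflib
    import zipfile

    OLD = "keras_hub_nightly-0.16.1.dev202410020340-py3-none-any.whl"
    NEW = "keras_hub_nightly-0.19.0.dev202501260345-py3-none-any.whl"
    MEMBER = "keras_hub/src/models/stable_diffusion_3/mmdit.py"

    def read_member(wheel_path, member):
        # Wheels are zip archives; pull one file's text out of the archive.
        with zipfile.ZipFile(wheel_path) as zf:
            return zf.read(member).decode("utf-8").splitlines(keepends=True)

    diff = difflib.unified_diff(
        read_member(OLD, MEMBER),
        read_member(NEW, MEMBER),
        fromfile=f"old/{MEMBER}",
        tofile=f"new/{MEMBER}",
    )
    print("".join(diff))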
keras_hub/src/models/stable_diffusion_3/mmdit.py
@@ -2,7 +2,6 @@ import math
 
 import keras
 from keras import layers
-from keras import models
 from keras import ops
 
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
@@ -11,7 +10,216 @@ from keras_hub.src.utils.keras_utils import gelu_approximate
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
+class AdaptiveLayerNormalization(layers.Layer):
+    """Adaptive layer normalization.
+
+    Args:
+        hidden_dim: int. The size of each embedding vector.
+        num_modulations: int. The number of modulation parameters. The
+            available values are `2`, `6` and `9`. Defaults to `2`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    References:
+    - [FiLM: Visual Reasoning with a General Conditioning Layer](
+    https://arxiv.org/abs/1709.07871).
+    - [Scalable Diffusion Models with Transformers](
+    https://arxiv.org/abs/2212.09748).
+    """
+
+    def __init__(self, hidden_dim, num_modulations=2, **kwargs):
+        super().__init__(**kwargs)
+        hidden_dim = int(hidden_dim)
+        num_modulations = int(num_modulations)
+        if num_modulations not in (2, 6, 9):
+            raise ValueError(
+                "`num_modulations` must be `2`, `6` or `9`. "
+                f"Received: num_modulations={num_modulations}"
+            )
+        self.hidden_dim = hidden_dim
+        self.num_modulations = num_modulations
+
+        self.silu = layers.Activation("silu", dtype=self.dtype_policy)
+        self.dense = layers.Dense(
+            num_modulations * hidden_dim, dtype=self.dtype_policy, name="dense"
+        )
+        self.norm = layers.LayerNormalization(
+            epsilon=1e-6,
+            center=False,
+            scale=False,
+            dtype="float32",
+            name="norm",
+        )
+
+    def build(self, inputs_shape, embeddings_shape):
+        self.silu.build(embeddings_shape)
+        self.dense.build(embeddings_shape)
+        self.norm.build(inputs_shape)
+
+    def call(self, inputs, embeddings, training=None):
+        hidden_states = inputs
+        emb = self.dense(self.silu(embeddings), training=training)
+        if self.num_modulations == 9:
+            (
+                shift_msa,
+                scale_msa,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                shift_msa2,
+                scale_msa2,
+                gate_msa2,
+            ) = ops.split(emb, self.num_modulations, axis=1)
+        elif self.num_modulations == 6:
+            (
+                shift_msa,
+                scale_msa,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+            ) = ops.split(emb, self.num_modulations, axis=1)
+        else:
+            shift_msa, scale_msa = ops.split(emb, self.num_modulations, axis=1)
+
+        scale_msa = ops.expand_dims(scale_msa, axis=1)
+        shift_msa = ops.expand_dims(shift_msa, axis=1)
+        norm_hidden_states = ops.cast(
+            self.norm(hidden_states, training=training), scale_msa.dtype
+        )
+        hidden_states = ops.add(
+            ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa)), shift_msa
+        )
+
+        if self.num_modulations == 9:
+            scale_msa2 = ops.expand_dims(scale_msa2, axis=1)
+            shift_msa2 = ops.expand_dims(shift_msa2, axis=1)
+            hidden_states2 = ops.add(
+                ops.multiply(norm_hidden_states, ops.add(1.0, scale_msa2)),
+                shift_msa2,
+            )
+            return (
+                hidden_states,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                hidden_states2,
+                gate_msa2,
+            )
+        elif self.num_modulations == 6:
+            return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp
+        else:
+            return hidden_states
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_modulations": self.num_modulations,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape, embeddings_shape):
+        if self.num_modulations == 9:
+            return (
+                inputs_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                inputs_shape,
+                embeddings_shape,
+            )
+        elif self.num_modulations == 6:
+            return (
+                inputs_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+                embeddings_shape,
+            )
+        else:
+            return inputs_shape
+
+
+class MLP(layers.Layer):
+    """An MLP block with two dense layers.
+
+    Args:
+        hidden_dim: int. The number of units in the hidden layers.
+        output_dim: int. The number of units in the output layer.
+        activation: str or callable. Activation to use in the hidden layers.
+            Defaults to `None`.
+    """
+
+    def __init__(self, hidden_dim, output_dim, activation=None, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.output_dim = int(output_dim)
+        self.activation = keras.activations.get(activation)
+
+        self.dense1 = layers.Dense(
+            hidden_dim,
+            activation=self.activation,
+            dtype=self.dtype_policy,
+            name="dense1",
+        )
+        self.dense2 = layers.Dense(
+            output_dim,
+            activation=None,
+            dtype=self.dtype_policy,
+            name="dense2",
+        )
+
+    def build(self, inputs_shape):
+        self.dense1.build(inputs_shape)
+        inputs_shape = self.dense1.compute_output_shape(inputs_shape)
+        self.dense2.build(inputs_shape)
+
+    def call(self, inputs, training=None):
+        x = self.dense1(inputs, training=training)
+        return self.dense2(x, training=training)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "output_dim": self.output_dim,
+                "activation": keras.activations.serialize(self.activation),
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        outputs_shape = list(inputs_shape)
+        outputs_shape[-1] = self.output_dim
+        return outputs_shape
+
+
 class PatchEmbedding(layers.Layer):
+    """A layer that converts images into patches.
+
+    Args:
+        patch_size: int. The size of one side of each patch.
+        hidden_dim: int. The number of units in the hidden layers.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
     def __init__(self, patch_size, hidden_dim, data_format=None, **kwargs):
         super().__init__(**kwargs)
         self.patch_size = int(patch_size)
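The new `AdaptiveLayerNormalization` centralizes the FiLM-style modulation that the old code rebuilt inline (via `models.Sequential`) in `DismantledBlock` and `OutputLayer`. A minimal numpy sketch of the core computation, written for illustration and not taken from the package:

    import numpy as np

    def ada_layer_norm(x, shift, scale, eps=1e-6):
        # Parameter-free layer norm (center=False, scale=False) over the last
        # axis, then FiLM-style modulation: out = norm(x) * (1 + scale) + shift.
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        normed = (x - mean) / np.sqrt(var + eps)
        # shift/scale come from a Dense projection of the conditioning embedding
        # and broadcast over the token axis, like ops.expand_dims(..., axis=1).
        return normed * (1.0 + scale[:, None, :]) + shift[:, None, :]

    x = np.random.randn(2, 16, 8).astype("float32")  # [batch, tokens, hidden]
    shift = np.random.randn(2, 8).astype("float32")  # derived from embeddings
    scale = np.random.randn(2, 8).astype("float32")
    print(ada_layer_norm(x, shift, scale).shape)  # (2, 16, 8)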
@@ -48,6 +256,15 @@ class PatchEmbedding(layers.Layer):
 
 
 class AdjustablePositionEmbedding(PositionEmbedding):
+    """A position embedding layer with adjustable height and width.
+
+    The embedding will be cropped to match the input dimensions.
+
+    Args:
+        height: int. The maximum height of the embedding.
+        width: int. The maximum width of the embedding.
+    """
+
     def __init__(
         self,
         height,
@@ -84,11 +301,36 @@ class AdjustablePositionEmbedding(PositionEmbedding):
         position_embedding = ops.expand_dims(position_embedding, axis=0)
         return position_embedding
 
+    def get_config(self):
+        config = super().get_config()
+        del config["sequence_length"]
+        config.update(
+            {
+                "height": self.height,
+                "width": self.width,
+            }
+        )
+        return config
+
     def compute_output_shape(self, input_shape):
         return input_shape
 
 
 class TimestepEmbedding(layers.Layer):
+    """A layer that learns an embedding for input timesteps.
+
+    Args:
+        embedding_dim: int. The size of the embedding.
+        frequency_dim: int. The size of the frequency embedding.
+        max_period: int. Controls the maximum frequency of the embeddings.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    Reference:
+    - [Denoising Diffusion Probabilistic Models](
+    https://arxiv.org/abs/2006.11239).
+    """
+
     def __init__(
         self, embedding_dim, frequency_dim=256, max_period=10000, **kwargs
     ):
@@ -96,17 +338,23 @@ class TimestepEmbedding(layers.Layer):
         self.embedding_dim = int(embedding_dim)
         self.frequency_dim = int(frequency_dim)
         self.max_period = float(max_period)
-        self.half_frequency_dim = self.frequency_dim // 2
-
-        self.mlp = models.Sequential(
-            [
-                layers.Dense(
-                    embedding_dim, activation="silu", dtype=self.dtype_policy
-                ),
-                layers.Dense(
-                    embedding_dim, activation=None, dtype=self.dtype_policy
+        # Precomputed `freq`.
+        half_frequency_dim = frequency_dim // 2
+        self.freq = ops.exp(
+            ops.divide(
+                ops.multiply(
+                    -math.log(max_period),
+                    ops.arange(0, half_frequency_dim, dtype="float32"),
                 ),
-            ],
+                half_frequency_dim,
+            )
+        )
+
+        self.mlp = MLP(
+            embedding_dim,
+            embedding_dim,
+            "silu",
+            dtype=self.dtype_policy,
             name="mlp",
         )
 
@@ -118,16 +366,7 @@ class TimestepEmbedding(layers.Layer):
     def _create_timestep_embedding(self, inputs):
         compute_dtype = keras.backend.result_type(self.compute_dtype, "float32")
         x = ops.cast(inputs, compute_dtype)
-        freqs = ops.exp(
-            ops.divide(
-                ops.multiply(
-                    -math.log(self.max_period),
-                    ops.arange(0, self.half_frequency_dim, dtype="float32"),
-                ),
-                self.half_frequency_dim,
-            )
-        )
-        freqs = ops.cast(freqs, compute_dtype)
+        freqs = ops.cast(self.freq, compute_dtype)
         x = ops.multiply(x, ops.expand_dims(freqs, axis=0))
         embedding = ops.concatenate([ops.cos(x), ops.sin(x)], axis=-1)
         if self.frequency_dim % 2 != 0:
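`TimestepEmbedding` now precomputes the geometric frequency table once in `__init__` instead of rebuilding it on every call; `_create_timestep_embedding` only casts and applies it. A self-contained sketch of the underlying sinusoidal embedding, illustrative and assuming the defaults shown above:

    import math
    import numpy as np

    def timestep_embedding(t, frequency_dim=256, max_period=10000.0):
        # Geometric frequency table, matching the precomputed `self.freq`:
        # freq[i] = exp(-log(max_period) * i / half) for i in [0, half).
        half = frequency_dim // 2
        freqs = np.exp(
            -math.log(max_period) * np.arange(half, dtype="float32") / half
        )
        args = np.asarray(t, dtype="float32")[:, None] * freqs[None, :]
        # Concatenate cos and sin halves, as in _create_timestep_embedding.
        return np.concatenate([np.cos(args), np.sin(args)], axis=-1)

    print(timestep_embedding([0.0, 250.0, 999.0]).shape)  # (3, 256)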
@@ -143,6 +382,7 @@ class TimestepEmbedding(layers.Layer):
         config.update(
             {
                 "embedding_dim": self.embedding_dim,
+                "frequency_dim": self.frequency_dim,
                 "max_period": self.max_period,
             }
         )
@@ -154,13 +394,52 @@ class TimestepEmbedding(layers.Layer):
         return output_shape
 
 
+def get_qk_norm(qk_norm=None, q_norm_name="q_norm", k_norm_name="k_norm"):
+    """Helper function to instantiate `LayerNormalization` layers."""
+    q_norm = None
+    k_norm = None
+    if qk_norm is None:
+        pass
+    elif qk_norm == "rms_norm":
+        q_norm = layers.LayerNormalization(
+            epsilon=1e-6, rms_scaling=True, dtype="float32", name=q_norm_name
+        )
+        k_norm = layers.LayerNormalization(
+            epsilon=1e-6, rms_scaling=True, dtype="float32", name=k_norm_name
+        )
+    else:
+        raise NotImplementedError(
+            "Supported `qk_norm` are `'rms_norm'` and `None`. "
+            f"Received: qk_norm={qk_norm}."
+        )
+    return q_norm, k_norm
+
+
 class DismantledBlock(layers.Layer):
+    """A dismantled block used to compute pre- and post-attention.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. The number of units in the hidden layers.
+        mlp_ratio: float. The expansion ratio of `MLP`.
+        use_projection: bool. Whether to use an attention projection layer at
+            the end of the block.
+        qk_norm: Optional str. Whether to normalize the query and key tensors.
+            Available options are `None` and `"rms_norm"`. Defaults to `None`.
+        use_dual_attention: bool. Whether to use dual attention in the
+            block. Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
+
     def __init__(
         self,
         num_heads,
         hidden_dim,
         mlp_ratio=4.0,
         use_projection=True,
+        qk_norm=None,
+        use_dual_attention=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
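`get_qk_norm` maps `qk_norm="rms_norm"` onto `layers.LayerNormalization(rms_scaling=True)`, applied per attention head in float32. For reference, a plain-numpy sketch of the RMS normalization this computes (up to the layer's learned per-feature scale, omitted here); the sketch is illustrative, not the library implementation:

    import numpy as np

    def rms_norm(x, eps=1e-6):
        # Divide by the root mean square of the last axis, with no mean
        # centering, which keeps query/key magnitudes (and hence attention
        # logits) well-scaled.
        rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
        return x / rms

    # Applied to queries/keys of shape [batch, tokens, num_heads, head_dim]
    # before the attention dot product.
    q = np.random.randn(2, 16, 4, 32).astype("float32")
    print(rms_norm(q).shape)  # (2, 16, 4, 32)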
@@ -168,33 +447,32 @@ class DismantledBlock(layers.Layer):
         self.hidden_dim = hidden_dim
         self.mlp_ratio = mlp_ratio
         self.use_projection = use_projection
+        self.qk_norm = qk_norm
+        self.use_dual_attention = use_dual_attention
 
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
         mlp_hidden_dim = int(hidden_dim * mlp_ratio)
         self.mlp_hidden_dim = mlp_hidden_dim
-        num_modulations = 6 if use_projection else 2
-        self.num_modulations = num_modulations
 
-        self.adaptive_norm_modulation = models.Sequential(
-            [
-                layers.Activation("silu", dtype=self.dtype_policy),
-                layers.Dense(
-                    num_modulations * hidden_dim, dtype=self.dtype_policy
-                ),
-            ],
-            name="adaptive_norm_modulation",
-        )
-        self.norm1 = layers.LayerNormalization(
-            epsilon=1e-6,
-            center=False,
-            scale=False,
-            dtype="float32",
-            name="norm1",
-        )
+        if use_projection:
+            self.ada_layer_norm = AdaptiveLayerNormalization(
+                hidden_dim,
+                num_modulations=9 if use_dual_attention else 6,
+                dtype=self.dtype_policy,
+                name="ada_layer_norm",
+            )
+        else:
+            self.ada_layer_norm = AdaptiveLayerNormalization(
+                hidden_dim, dtype=self.dtype_policy, name="ada_layer_norm"
+            )
         self.attention_qkv = layers.Dense(
             hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
         )
+        q_norm, k_norm = get_qk_norm(qk_norm)
+        if q_norm is not None:
+            self.q_norm = q_norm
+            self.k_norm = k_norm
         if use_projection:
             self.attention_proj = layers.Dense(
                 hidden_dim, dtype=self.dtype_policy, name="attention_proj"
@@ -206,89 +484,165 @@ class DismantledBlock(layers.Layer):
             dtype="float32",
             name="norm2",
         )
-        self.mlp = models.Sequential(
-            [
-                layers.Dense(
-                    mlp_hidden_dim,
-                    activation=gelu_approximate,
-                    dtype=self.dtype_policy,
-                ),
-                layers.Dense(
-                    hidden_dim,
-                    dtype=self.dtype_policy,
-                ),
-            ],
+        self.mlp = MLP(
+            mlp_hidden_dim,
+            hidden_dim,
+            gelu_approximate,
+            dtype=self.dtype_policy,
             name="mlp",
         )
 
+        if use_dual_attention:
+            self.attention_qkv2 = layers.Dense(
+                hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv2"
+            )
+            q_norm2, k_norm2 = get_qk_norm(qk_norm, "q_norm2", "k_norm2")
+            if q_norm is not None:
+                self.q_norm2 = q_norm2
+                self.k_norm2 = k_norm2
+            if use_projection:
+                self.attention_proj2 = layers.Dense(
+                    hidden_dim, dtype=self.dtype_policy, name="attention_proj2"
+                )
+
     def build(self, inputs_shape, timestep_embedding):
-        self.adaptive_norm_modulation.build(timestep_embedding)
+        self.ada_layer_norm.build(inputs_shape, timestep_embedding)
         self.attention_qkv.build(inputs_shape)
-        self.norm1.build(inputs_shape)
+        if self.qk_norm is not None:
+            # [batch_size, sequence_length, num_heads, head_dim]
+            self.q_norm.build([None, None, self.num_heads, self.head_dim])
+            self.k_norm.build([None, None, self.num_heads, self.head_dim])
         if self.use_projection:
             self.attention_proj.build(inputs_shape)
         self.norm2.build(inputs_shape)
         self.mlp.build(inputs_shape)
+        if self.use_dual_attention:
+            self.attention_qkv2.build(inputs_shape)
+            if self.qk_norm is not None:
+                self.q_norm2.build([None, None, self.num_heads, self.head_dim])
+                self.k_norm2.build([None, None, self.num_heads, self.head_dim])
+            if self.use_projection:
+                self.attention_proj2.build(inputs_shape)
 
     def _modulate(self, inputs, shift, scale):
-        shift = ops.expand_dims(shift, axis=1)
-        scale = ops.expand_dims(scale, axis=1)
+        inputs = ops.cast(inputs, self.compute_dtype)
+        shift = ops.cast(shift, self.compute_dtype)
+        scale = ops.cast(scale, self.compute_dtype)
         return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
 
     def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
         batch_size = ops.shape(inputs)[0]
         if self.use_projection:
-            modulation = self.adaptive_norm_modulation(
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 6, self.hidden_dim)
-            )
-            (
-                shift_msa,
-                scale_msa,
-                gate_msa,
-                shift_mlp,
-                scale_mlp,
-                gate_mlp,
-            ) = ops.unstack(modulation, 6, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
+            x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.ada_layer_norm(
+                inputs, timestep_embedding, training=training
             )
+            qkv = self.attention_qkv(x, training=training)
             qkv = ops.reshape(
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
+            if self.qk_norm is not None:
+                q = ops.cast(
+                    self.q_norm(q, training=training), self.compute_dtype
+                )
+                k = ops.cast(
+                    self.k_norm(k, training=training), self.compute_dtype
+                )
             return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
         else:
-            modulation = self.adaptive_norm_modulation(
-                timestep_embedding, training=training
-            )
-            modulation = ops.reshape(
-                modulation, (batch_size, 2, self.hidden_dim)
-            )
-            shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1)
-            qkv = self.attention_qkv(
-                self._modulate(self.norm1(inputs), shift_msa, scale_msa),
-                training=training,
+            x = self.ada_layer_norm(
+                inputs, timestep_embedding, training=training
             )
+            qkv = self.attention_qkv(x, training=training)
             qkv = ops.reshape(
                 qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
             )
             q, k, v = ops.unstack(qkv, 3, axis=2)
+            if self.qk_norm is not None:
+                q = ops.cast(
+                    self.q_norm(q, training=training), self.compute_dtype
+                )
+                k = ops.cast(
+                    self.k_norm(k, training=training), self.compute_dtype
+                )
             return (q, k, v)
 
     def _compute_post_attention(
         self, inputs, inputs_intermediates, training=None
     ):
         x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
+        gate_msa = ops.expand_dims(gate_msa, axis=1)
+        shift_mlp = ops.expand_dims(shift_mlp, axis=1)
+        scale_mlp = ops.expand_dims(scale_mlp, axis=1)
+        gate_mlp = ops.expand_dims(gate_mlp, axis=1)
         attn = self.attention_proj(inputs, training=training)
-        x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
+        x = ops.add(x, ops.multiply(gate_msa, attn))
         x = ops.add(
             x,
             ops.multiply(
-                ops.expand_dims(gate_mlp, axis=1),
+                gate_mlp,
+                self.mlp(
+                    self._modulate(self.norm2(x), shift_mlp, scale_mlp),
+                    training=training,
+                ),
+            ),
+        )
+        return x
+
+    def _compute_pre_attention_with_dual_attention(
+        self, inputs, timestep_embedding, training=None
+    ):
+        batch_size = ops.shape(inputs)[0]
+        x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2 = (
+            self.ada_layer_norm(inputs, timestep_embedding, training=training)
+        )
+        # Compute the main attention
+        qkv = self.attention_qkv(x, training=training)
+        qkv = ops.reshape(
+            qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
+        )
+        q, k, v = ops.unstack(qkv, 3, axis=2)
+        if self.qk_norm is not None:
+            q = ops.cast(self.q_norm(q, training=training), self.compute_dtype)
+            k = ops.cast(self.k_norm(k, training=training), self.compute_dtype)
+        # Compute the dual attention
+        qkv2 = self.attention_qkv2(x2, training=training)
+        qkv2 = ops.reshape(
+            qkv2, (batch_size, -1, 3, self.num_heads, self.head_dim)
+        )
+        q2, k2, v2 = ops.unstack(qkv2, 3, axis=2)
+        if self.qk_norm is not None:
+            q2 = ops.cast(
+                self.q_norm2(q2, training=training), self.compute_dtype
+            )
+            k2 = ops.cast(
+                self.k_norm2(k2, training=training), self.compute_dtype
+            )
+        return (
+            (q, k, v),
+            (q2, k2, v2),
+            (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2),
+        )
+
+    def _compute_post_attention_with_dual_attention(
+        self, inputs, inputs2, inputs_intermediates, training=None
+    ):
+        x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2 = (
+            inputs_intermediates
+        )
+        gate_msa = ops.expand_dims(gate_msa, axis=1)
+        shift_mlp = ops.expand_dims(shift_mlp, axis=1)
+        scale_mlp = ops.expand_dims(scale_mlp, axis=1)
+        gate_mlp = ops.expand_dims(gate_mlp, axis=1)
+        gate_msa2 = ops.expand_dims(gate_msa2, axis=1)
+        attn = self.attention_proj(inputs, training=training)
+        x = ops.add(x, ops.multiply(gate_msa, attn))
+        attn2 = self.attention_proj2(inputs2, training=training)
+        x = ops.add(x, ops.multiply(gate_msa2, attn2))
+        x = ops.add(
+            x,
+            ops.multiply(
+                gate_mlp,
                 self.mlp(
                     self._modulate(self.norm2(x), shift_mlp, scale_mlp),
                     training=training,
@@ -302,17 +656,28 @@ class DismantledBlock(layers.Layer):
         inputs,
         timestep_embedding=None,
         inputs_intermediates=None,
+        inputs2=None,  # For the dual attention.
         pre_attention=True,
         training=None,
     ):
         if pre_attention:
-            return self._compute_pre_attention(
-                inputs, timestep_embedding, training=training
-            )
+            if self.use_dual_attention:
+                return self._compute_pre_attention_with_dual_attention(
+                    inputs, timestep_embedding, training=training
+                )
+            else:
+                return self._compute_pre_attention(
+                    inputs, timestep_embedding, training=training
+                )
         else:
-            return self._compute_post_attention(
-                inputs, inputs_intermediates, training=training
-            )
+            if self.use_dual_attention:
+                return self._compute_post_attention_with_dual_attention(
+                    inputs, inputs2, inputs_intermediates, training=training
+                )
+            else:
+                return self._compute_post_attention(
+                    inputs, inputs_intermediates, training=training
+                )
 
     def get_config(self):
         config = super().get_config()
@@ -322,18 +687,47 @@ class DismantledBlock(layers.Layer):
                 "hidden_dim": self.hidden_dim,
                 "mlp_ratio": self.mlp_ratio,
                 "use_projection": self.use_projection,
+                "qk_norm": self.qk_norm,
+                "use_dual_attention": self.use_dual_attention,
             }
         )
         return config
 
 
 class MMDiTBlock(layers.Layer):
+    """An MMDiT block consisting of two `DismantledBlock` layers.
+
+    One `DismantledBlock` processes the input latents, and the other processes
+    the context embedding. This block integrates two modalities within the
+    attention operation, allowing each representation to operate in its own
+    space while considering the other.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. The number of units in the hidden layers.
+        mlp_ratio: float. The expansion ratio of `MLP`.
+        use_context_projection: bool. Whether to use an attention projection
+            layer at the end of the context block.
+        qk_norm: Optional str. Whether to normalize the query and key tensors.
+            Available options are `None` and `"rms_norm"`. Defaults to `None`.
+        use_dual_attention: bool. Whether to use dual attention in the
+            block. Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+
+    Reference:
+    - [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
+    https://arxiv.org/abs/2403.03206)
+    """
+
     def __init__(
         self,
         num_heads,
         hidden_dim,
         mlp_ratio=4.0,
         use_context_projection=True,
+        qk_norm=None,
+        use_dual_attention=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -341,18 +735,20 @@ class MMDiTBlock(layers.Layer):
         self.hidden_dim = hidden_dim
         self.mlp_ratio = mlp_ratio
         self.use_context_projection = use_context_projection
+        self.qk_norm = qk_norm
+        self.use_dual_attention = use_dual_attention
 
         head_dim = hidden_dim // num_heads
         self.head_dim = head_dim
         self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
-        self._dot_product_equation = "aecd,abcd->acbe"
-        self._combine_equation = "acbe,aecd->abcd"
 
         self.x_block = DismantledBlock(
             num_heads=num_heads,
             hidden_dim=hidden_dim,
             mlp_ratio=mlp_ratio,
             use_projection=True,
+            qk_norm=qk_norm,
+            use_dual_attention=use_dual_attention,
             dtype=self.dtype_policy,
             name="x_block",
         )
@@ -361,6 +757,7 @@ class MMDiTBlock(layers.Layer):
             hidden_dim=hidden_dim,
             mlp_ratio=mlp_ratio,
             use_projection=use_context_projection,
+            qk_norm=qk_norm,
             dtype=self.dtype_policy,
             name="context_block",
         )
@@ -371,20 +768,35 @@ class MMDiTBlock(layers.Layer):
         self.context_block.build(context_shape, timestep_embedding_shape)
 
     def _compute_attention(self, query, key, value):
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
-        )
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
-        attention_scores = self.softmax(attention_scores)
-        attention_scores = ops.cast(attention_scores, self.compute_dtype)
-        attention_output = ops.einsum(
-            self._combine_equation, attention_scores, value
-        )
-        batch_size = ops.shape(attention_output)[0]
-        attention_output = ops.reshape(
-            attention_output, (batch_size, -1, self.num_heads * self.head_dim)
+        batch_size = ops.shape(query)[0]
+
+        # Use the fast path when `ops.dot_product_attention` and flash attention
+        # are available.
+        if hasattr(ops, "dot_product_attention") and hasattr(
+            keras.config, "is_flash_attention_enabled"
+        ):
+            encoded = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                scale=self._inverse_sqrt_key_dim,
+                flash_attention=keras.config.is_flash_attention_enabled(),
+            )
+            return ops.reshape(
+                encoded, (batch_size, -1, self.num_heads * self.head_dim)
+            )
+
+        # Ref: jax.nn.dot_product_attention
+        # https://github.com/jax-ml/jax/blob/db89c245ac66911c98f265a05956fdfa4bc79d83/jax/_src/nn/functions.py#L846
+        logits = ops.einsum("BTNH,BSNH->BNTS", query, key)
+        logits = ops.multiply(logits, self._inverse_sqrt_key_dim)
+        probs = self.softmax(logits)
+        probs = ops.cast(probs, self.compute_dtype)
+        encoded = ops.einsum("BNTS,BSNH->BTNH", probs, value)
+        encoded = ops.reshape(
+            encoded, (batch_size, -1, self.num_heads * self.head_dim)
         )
-        return attention_output
+        return encoded
 
     def call(self, inputs, context, timestep_embedding, training=None):
         # Compute pre-attention.
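The rewritten `_compute_attention` prefers `ops.dot_product_attention` (with flash attention when Keras enables it) and keeps an explicit einsum fallback. A plain-numpy reference for the fallback's two einsum equations, illustrative only; the max-subtraction is a standard softmax stabilization added here for safety, not a line from the layer:

    import numpy as np

    def attention_reference(q, k, v, scale):
        # Fallback path from the hunk: "BTNH,BSNH->BNTS" logits, softmax over
        # the source axis S, then "BNTS,BSNH->BTNH" to combine with values.
        logits = np.einsum("BTNH,BSNH->BNTS", q, k) * scale
        logits = logits - logits.max(axis=-1, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=-1, keepdims=True)
        return np.einsum("BNTS,BSNH->BTNH", probs, v)

    B, T, N, H = 2, 5, 4, 8
    q, k, v = (np.random.randn(B, T, N, H).astype("float32") for _ in range(3))
    out = attention_reference(q, k, v, scale=1.0 / np.sqrt(H))
    print(out.reshape(B, -1, N * H).shape)  # (2, 5, 32), as the layer returns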
@@ -402,9 +814,14 @@ class MMDiTBlock(layers.Layer):
             training=training,
         )
         context_len = ops.shape(context_qkv[0])[1]
-        x_qkv, x_intermediates = self.x_block(
-            x, timestep_embedding=timestep_embedding, training=training
-        )
+        if self.x_block.use_dual_attention:
+            x_qkv, x_qkv2, x_intermediates = self.x_block(
+                x, timestep_embedding=timestep_embedding, training=training
+            )
+        else:
+            x_qkv, x_intermediates = self.x_block(
+                x, timestep_embedding=timestep_embedding, training=training
+            )
         q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
         k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
         v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
@@ -415,12 +832,23 @@ class MMDiTBlock(layers.Layer):
         x_attention = attention[:, context_len:]
 
         # Compute post-attention.
-        x = self.x_block(
-            x_attention,
-            inputs_intermediates=x_intermediates,
-            pre_attention=False,
-            training=training,
-        )
+        if self.x_block.use_dual_attention:
+            q2, k2, v2 = x_qkv2
+            x_attention2 = self._compute_attention(q2, k2, v2)
+            x = self.x_block(
+                x_attention,
+                inputs_intermediates=x_intermediates,
+                inputs2=x_attention2,
+                pre_attention=False,
+                training=training,
+            )
+        else:
+            x = self.x_block(
+                x_attention,
+                inputs_intermediates=x_intermediates,
+                pre_attention=False,
+                training=training,
+            )
         if self.use_context_projection:
             context = self.context_block(
                 context_attention,
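For orientation: `MMDiTBlock.call` runs a single attention pass over the concatenation of context and latent tokens, then splits the output back apart at `context_len`, as the two hunks above show. A toy sketch of that token bookkeeping (the sequence lengths and width are assumptions, not from the diff):

    import numpy as np

    context_len, latent_len, width = 77, 1024, 64  # assumed sizes
    context_tokens = np.zeros((1, context_len, width), dtype="float32")
    latent_tokens = np.zeros((1, latent_len, width), dtype="float32")
    joint = np.concatenate([context_tokens, latent_tokens], axis=1)
    attention = joint  # stand-in for the joint attention output
    context_attention = attention[:, :context_len]  # routed to context_block
    x_attention = attention[:, context_len:]  # routed to x_block
    print(context_attention.shape, x_attention.shape)  # (1, 77, 64) (1, 1024, 64)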
@@ -440,6 +868,8 @@ class MMDiTBlock(layers.Layer):
                 "hidden_dim": self.hidden_dim,
                 "mlp_ratio": self.mlp_ratio,
                 "use_context_projection": self.use_context_projection,
+                "qk_norm": self.qk_norm,
+                "use_dual_attention": self.use_dual_attention,
             }
         )
         return config
@@ -453,74 +883,16 @@ class MMDiTBlock(layers.Layer):
         return inputs_shape
 
 
-class OutputLayer(layers.Layer):
-    def __init__(self, hidden_dim, output_dim, **kwargs):
-        super().__init__(**kwargs)
-        self.hidden_dim = hidden_dim
-        self.output_dim = output_dim
-        num_modulation = 2
-
-        self.adaptive_norm_modulation = models.Sequential(
-            [
-                layers.Activation("silu", dtype=self.dtype_policy),
-                layers.Dense(
-                    num_modulation * hidden_dim, dtype=self.dtype_policy
-                ),
-            ],
-            name="adaptive_norm_modulation",
-        )
-        self.norm = layers.LayerNormalization(
-            epsilon=1e-6,
-            center=False,
-            scale=False,
-            dtype="float32",
-            name="norm",
-        )
-        self.output_dense = layers.Dense(
-            output_dim,
-            use_bias=True,
-            dtype=self.dtype_policy,
-            name="output_dense",
-        )
-
-    def build(self, inputs_shape, timestep_embedding_shape):
-        self.adaptive_norm_modulation.build(timestep_embedding_shape)
-        self.norm.build(inputs_shape)
-        self.output_dense.build(inputs_shape)
-
-    def _modulate(self, inputs, shift, scale):
-        shift = ops.expand_dims(shift, axis=1)
-        scale = ops.expand_dims(scale, axis=1)
-        return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
-
-    def call(self, inputs, timestep_embedding, training=None):
-        x = inputs
-        modulation = self.adaptive_norm_modulation(
-            timestep_embedding, training=training
-        )
-        modulation = ops.reshape(modulation, (-1, 2, self.hidden_dim))
-        shift, scale = ops.unstack(modulation, 2, axis=1)
-        x = self._modulate(self.norm(x), shift, scale)
-        x = self.output_dense(x, training=training)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "hidden_dim": self.hidden_dim,
-                "output_dim": self.output_dim,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, inputs_shape):
-        outputs_shape = list(inputs_shape)
-        outputs_shape[-1] = self.output_dim
-        return outputs_shape
+class Unpatch(layers.Layer):
+    """A layer that reconstructs the image from hidden patches.
 
+    Args:
+        patch_size: int. The size of each square patch in the input image.
+        output_dim: int. The number of units in the output layer.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `dtype` etc.
+    """
 
-class Unpatch(layers.Layer):
     def __init__(self, patch_size, output_dim, **kwargs):
         super().__init__(**kwargs)
         self.patch_size = int(patch_size)
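`Unpatch` inverts the patch embedding, folding `[batch, h*w, patch*patch*channels]` tokens back into an image. Its body is unchanged and not shown in this hunk; the numpy sketch below shows the standard unpatchify reshape/transpose for illustration only (the layer's exact implementation may differ):

    import numpy as np

    def unpatchify(x, h, w, patch_size, output_dim):
        # [batch, h*w, patch*patch*C] -> [batch, h*patch, w*patch, C].
        b = x.shape[0]
        x = x.reshape(b, h, w, patch_size, patch_size, output_dim)
        x = x.transpose(0, 1, 3, 2, 4, 5)  # interleave patch rows and columns
        return x.reshape(b, h * patch_size, w * patch_size, output_dim)

    tokens = np.zeros((1, 32 * 32, 2 * 2 * 16), dtype="float32")
    print(unpatchify(tokens, 32, 32, 2, 16).shape)  # (1, 64, 64, 16)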
@@ -556,7 +928,7 @@ class Unpatch(layers.Layer):
 
 
 class MMDiT(Backbone):
-    """Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3.
+    """A Multimodal Diffusion Transformer (MMDiT) model.
 
     MMDiT is introduced in [
     Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
@@ -574,6 +946,12 @@ class MMDiT(Backbone):
         latent_shape: tuple. The shape of the latent image.
         context_shape: tuple. The shape of the context.
         pooled_projection_shape: tuple. The shape of the pooled projection.
+        qk_norm: Optional str. Whether to normalize the query and key tensors
+            in the intermediate blocks. Available options are `None` and
+            `"rms_norm"`. Defaults to `None`.
+        dual_attention_indices: Optional tuple. Specifies the indices of
+            the blocks that serve as dual attention blocks. Typically, this
+            is used for the 3.5 version of the model. Defaults to `None`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -598,6 +976,8 @@ class MMDiT(Backbone):
         latent_shape=(64, 64, 16),
         context_shape=(None, 4096),
         pooled_projection_shape=(2048,),
+        qk_norm=None,
+        dual_attention_indices=None,
         data_format=None,
         dtype=None,
         **kwargs,
@@ -611,6 +991,7 @@ class MMDiT(Backbone):
         image_width = latent_shape[1] // patch_size
         output_dim = latent_shape[-1]
         output_dim_in_final = patch_size**2 * output_dim
+        dual_attention_indices = dual_attention_indices or ()
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError(
@@ -636,12 +1017,8 @@ class MMDiT(Backbone):
             dtype=dtype,
             name="context_embedding",
         )
-        self.vector_embedding = models.Sequential(
-            [
-                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
-                layers.Dense(hidden_dim, activation=None, dtype=dtype),
-            ],
-            name="vector_embedding",
+        self.vector_embedding = MLP(
+            hidden_dim, hidden_dim, "silu", dtype=dtype, name="vector_embedding"
         )
         self.vector_embedding_add = layers.Add(
             dtype=dtype, name="vector_embedding_add"
@@ -655,13 +1032,18 @@ class MMDiT(Backbone):
                 hidden_dim,
                 mlp_ratio,
                 use_context_projection=not (i == num_layers - 1),
+                qk_norm=qk_norm,
+                use_dual_attention=i in dual_attention_indices,
                 dtype=dtype,
                 name=f"joint_block_{i}",
             )
             for i in range(num_layers)
         ]
-        self.output_layer = OutputLayer(
-            hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
+        self.output_ada_layer_norm = AdaptiveLayerNormalization(
+            hidden_dim, dtype=dtype, name="output_ada_layer_norm"
+        )
+        self.output_dense = layers.Dense(
+            output_dim_in_final, dtype=dtype, name="output_dense"
         )
         self.unpatch = Unpatch(
             patch_size, output_dim, dtype=dtype, name="unpatch"
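`dual_attention_indices` simply selects which `joint_block_{i}` instances are built with `use_dual_attention=True`, via the comprehension above. An illustrative check; the depth and index values here are hypothetical, since the diff does not pin down the 3.5 preset configuration:

    num_layers = 24  # hypothetical depth
    dual_attention_indices = tuple(range(13))  # hypothetical 3.5-style indices
    flags = [i in dual_attention_indices for i in range(num_layers)]
    print(sum(flags), "of", num_layers, "blocks use dual attention")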
@@ -696,7 +1078,8 @@ class MMDiT(Backbone):
             x = block(x, context, timestep_embedding)
 
         # Output layer.
-        x = self.output_layer(x, timestep_embedding)
+        x = self.output_ada_layer_norm(x, timestep_embedding)
+        x = self.output_dense(x)
         outputs = self.unpatch(x, height=image_height, width=image_width)
 
         super().__init__(
@@ -720,6 +1103,8 @@ class MMDiT(Backbone):
         self.latent_shape = latent_shape
         self.context_shape = context_shape
         self.pooled_projection_shape = pooled_projection_shape
+        self.qk_norm = qk_norm
+        self.dual_attention_indices = dual_attention_indices
 
     def get_config(self):
         config = super().get_config()
@@ -734,6 +1119,8 @@ class MMDiT(Backbone):
                 "latent_shape": self.latent_shape,
                 "context_shape": self.context_shape,
                 "pooled_projection_shape": self.pooled_projection_shape,
+                "qk_norm": self.qk_norm,
+                "dual_attention_indices": self.dual_attention_indices,
             }
         )
         return config