keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,496 @@
1
+ import keras
2
+ from keras import layers
3
+ from keras import ops
4
+
5
+ from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization
6
+ from keras_hub.src.models.flux.flux_maths import FluxRoPEAttention
7
+ from keras_hub.src.models.flux.flux_maths import RotaryPositionalEmbedding
8
+ from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors
9
+
10
+
11
class EmbedND(keras.Model):
    """RoPE positional embedding computed independently over several axes.

    Every component along the last axis of the input id tensor is treated
    as a separate position axis. A RoPE embedding is produced for each axis
    using that axis' own dimensionality, and the per-axis embeddings are
    concatenated.

    Args:
        theta: Frequency parameter forwarded as `theta` to the
            `RotaryPositionalEmbedding` computation.
        axes_dim: Sequence of per-axis embedding dimensionalities, indexed
            by axis position (entry `i` is used for `ids[..., i]`).
    """

    def __init__(self, theta, axes_dim):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.rope = RotaryPositionalEmbedding()

    def build(self, input_shape):
        # A single RoPE sublayer is shared across axes; build it once per
        # axis with the corresponding dimensionality.
        leading_shape = input_shape[:-1]
        for axis_index in range(input_shape[-1]):
            self.rope.build(leading_shape + (self.axes_dim[axis_index],))

    def call(self, ids):
        """Compute per-axis RoPE embeddings and concatenate them.

        Args:
            ids: Tensor whose last axis indexes the position axes
                (shape `(..., num_axes)`).

        Returns:
            Tensor holding the per-axis embeddings concatenated along axis
            -3, with a new length-1 axis inserted at position 1.
        """
        per_axis_embeddings = []
        for axis_index in range(ids.shape[-1]):
            per_axis_embeddings.append(
                self.rope(
                    ids[..., axis_index],
                    dim=self.axes_dim[axis_index],
                    theta=self.theta,
                )
            )
        combined = ops.concatenate(per_axis_embeddings, axis=-3)
        return ops.expand_dims(combined, axis=1)
53
+
54
+
55
class MLPEmbedder(keras.Model):
    """Two-layer perceptron embedder: Dense -> SiLU -> Dense.

    Args:
        hidden_dim: int. Number of units in both Dense layers, and hence
            the size of the output's last axis.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_layer = layers.Dense(hidden_dim, use_bias=True)
        self.silu = layers.Activation("silu")
        self.output_layer = layers.Dense(hidden_dim, use_bias=True)

    def build(self, input_shape):
        self.input_layer.build(input_shape)
        # The second projection consumes the first projection's features.
        hidden_shape = (input_shape[0], self.input_layer.units)
        self.output_layer.build(hidden_shape)

    def call(self, x):
        """Embed `x` through Dense -> SiLU -> Dense.

        Args:
            x: Input tensor, documented as `(batch_size, in_dim)`.

        Returns:
            Tensor whose last axis has size `hidden_dim`.
        """
        activated = self.silu(self.input_layer(x))
        return self.output_layer(activated)
89
+
90
+
91
class QKNorm(keras.layers.Layer):
    """Apply separate RMS normalization to query and key tensors.

    Two independent `RMSNormalization` layers are kept: one for queries,
    one for keys.

    Args:
        input_dim: Size passed to both `RMSNormalization` layers. Presumably
            this is the size of the trailing (per-head feature) axis of q
            and k; confirm against `RMSNormalization`.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.query_norm = RMSNormalization(input_dim)
        self.key_norm = RMSNormalization(input_dim)

    def build(self, input_shape):
        # Queries and keys share the same shape, so both norms get it.
        for norm_layer in (self.query_norm, self.key_norm):
            norm_layer.build(input_shape)

    def call(self, q, k):
        """Normalize the query and key tensors.

        Args:
            q: Query tensor.
            k: Key tensor.

        Returns:
            A `(q_normalized, k_normalized)` tuple.
        """
        normalized_q = self.query_norm(q)
        normalized_k = self.key_norm(k)
        return normalized_q, normalized_k
125
+
126
+
127
class SelfAttention(keras.Model):
    """Multi-head self-attention with RoPE and query/key RMS normalization.

    A fused Dense layer produces queries, keys and values. Queries and keys
    are RMS-normalized per head, attention is computed by
    `FluxRoPEAttention` using the supplied positional encoding, and the
    result is projected back to `dim` features.

    Args:
        dim: int. Feature size of the input and of the output projection.
        num_heads: int. Number of attention heads. Defaults to 8.
        use_bias: bool. Whether the fused QKV projection uses a bias.
            Defaults to False.
    """

    def __init__(self, dim, num_heads=8, use_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim

        # Fused projection: 3 * dim features = queries, keys and values.
        self.qkv = layers.Dense(dim * 3, use_bias=use_bias)
        self.norm = QKNorm(dim // num_heads)
        self.proj = layers.Dense(dim)
        self.attention = FluxRoPEAttention()

    def build(self, input_shape):
        seq_len = input_shape[1]
        features = input_shape[-1]
        self.qkv.build(input_shape)
        self.norm.build((None, seq_len, features // self.num_heads))
        self.proj.build((None, seq_len, features))

    def call(self, x, positional_encoding):
        """Run RoPE self-attention over `x`.

        Args:
            x: Input tensor, documented as `(batch_size, seq_len, dim)`.
            positional_encoding: RoPE positional encoding tensor forwarded
                to `FluxRoPEAttention`.

        Returns:
            Attention output after the final `proj` Dense layer.
        """
        q, k, v = rearrange_symbolic_tensors(
            self.qkv(x), K=3, H=self.num_heads
        )
        q, k = self.norm(q, k)
        attended = self.attention(
            q=q, k=k, v=v, positional_encoding=positional_encoding
        )
        return self.proj(attended)
177
+
178
+
179
class Modulation(keras.Model):
    """Produce shift, scale and gate modulation parameters.

    The input is passed through SiLU and a Dense projection. The projection
    output is split into three tensors (shift, scale, gate), or into six
    tensors (two such sets) when `double` is True.

    Args:
        dim: int. Size of each individual modulation tensor.
        double: bool. Whether to generate two sets of modulation parameters.
    """

    def __init__(self, dim, double):
        super().__init__()
        self.dim = dim
        self.is_double = double
        # One set is (shift, scale, gate); a double modulation has two sets.
        self.multiplier = 6 if double else 3
        self.linear_projection = keras.layers.Dense(
            self.multiplier * dim, use_bias=True
        )

    def build(self, input_shape):
        self.linear_projection.build(input_shape)

    def call(self, x):
        """Generate modulation parameters from `x`.

        Args:
            x: Input tensor, indexed below as `(batch, features)`.

        Returns:
            A `(first, second)` tuple. Each element is a dict with keys
            `"shift"`, `"scale"` and `"gate"`. `second` is `None` unless
            `double` is True.
        """
        # Apply the activation functionally instead of instantiating a new
        # `Activation` layer on every forward pass; the math is identical.
        x = keras.activations.silu(x)
        out = self.linear_projection(x)
        # Insert a length-1 axis at position 1 before splitting. Presumably
        # this lets the parameters broadcast over the sequence axis of
        # (batch, seq, dim) activations -- confirm against callers.
        out = ops.split(
            out[:, None, :], indices_or_sections=self.multiplier, axis=-1
        )

        first_output = {"shift": out[0], "scale": out[1], "gate": out[2]}
        second_output = (
            {"shift": out[3], "scale": out[4], "gate": out[5]}
            if self.is_double
            else None
        )

        return first_output, second_output
229
+
230
+
231
def _gelu_tanh(x):
    """GELU using the tanh approximation.

    Matches the activation `SingleStreamBlock` applies to its MLP stream
    (`keras.activations.gelu(..., approximate=True)`).
    """
    return keras.activations.gelu(x, approximate=True)


class DoubleStreamBlock(keras.Model):
    """Block that processes image and text tokens in two modulated streams.

    Each stream (image and text) has its own modulation, layer norms, QKV
    projection, QK normalization and MLP. The two streams' queries, keys and
    values are concatenated (text first) and attended over jointly. The
    attention output is then split back per stream, and each stream
    receives gated attention and gated MLP residual updates.

    Args:
        hidden_size: int. The hidden dimension size for the model.
        num_heads: int. The number of attention heads.
        mlp_ratio: float. The ratio of the MLP hidden dimension to the
            hidden size.
        use_bias: bool, optional. Whether to include bias in the QKV
            projections. Defaults to False.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio,
        use_bias=False,
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        self.image_mod = Modulation(hidden_size, double=True)
        self.image_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.image_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, use_bias=use_bias
        )

        self.image_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.image_mlp = keras.Sequential(
            [
                keras.layers.Dense(mlp_hidden_dim, use_bias=True),
                # tanh-approximate GELU, consistent with SingleStreamBlock.
                keras.layers.Activation(_gelu_tanh),
                keras.layers.Dense(hidden_size, use_bias=True),
            ]
        )

        self.text_mod = Modulation(hidden_size, double=True)
        self.text_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.text_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, use_bias=use_bias
        )

        self.text_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.text_mlp = keras.Sequential(
            [
                keras.layers.Dense(mlp_hidden_dim, use_bias=True),
                # tanh-approximate GELU, consistent with SingleStreamBlock.
                keras.layers.Activation(_gelu_tanh),
                keras.layers.Dense(hidden_size, use_bias=True),
            ]
        )
        self.attention = FluxRoPEAttention()

    def _modulated_qkv(self, x, norm, modulation, attn):
        """Normalize and modulate `x`, then return QK-normalized q, k, v."""
        x = (1 + modulation["scale"]) * norm(x) + modulation["shift"]
        q, k, v = rearrange_symbolic_tensors(
            attn.qkv(x), K=3, H=self.num_heads
        )
        q, k = attn.norm(q, k)
        return q, k, v

    @staticmethod
    def _residual_update(x, attn_out, attn, mod1, mod2, norm2, mlp):
        """Apply the gated attention and gated MLP residual updates to `x`."""
        x = x + mod1["gate"] * attn.proj(attn_out)
        return x + mod2["gate"] * mlp(
            (1 + mod2["scale"]) * norm2(x) + mod2["shift"]
        )

    def call(self, image, text, modulation_encoding, positional_encoding):
        """Forward pass for the DoubleStreamBlock.

        Args:
            image: Input image token tensor.
            text: Input text token tensor; its axis-1 length is used to
                split the joint attention output.
            modulation_encoding: Modulation vector fed to both streams'
                `Modulation` layers.
            positional_encoding: Positional encoding tensor for RoPE.

        Returns:
            A `(image, text)` tuple of updated image and text tensors.
        """
        image_mod1, image_mod2 = self.image_mod(modulation_encoding)
        text_mod1, text_mod2 = self.text_mod(modulation_encoding)

        image_q, image_k, image_v = self._modulated_qkv(
            image, self.image_norm1, image_mod1, self.image_attn
        )
        text_q, text_k, text_v = self._modulated_qkv(
            text, self.text_norm1, text_mod1, self.text_attn
        )

        # Joint attention: text tokens first, then image tokens, along
        # axis 2 of the rearranged q/k/v tensors.
        q = ops.concatenate((text_q, image_q), axis=2)
        k = ops.concatenate((text_k, image_k), axis=2)
        v = ops.concatenate((text_v, image_v), axis=2)

        attn = self.attention(
            q=q, k=k, v=v, positional_encoding=positional_encoding
        )
        # Split the joint output back into streams along axis 1.
        text_len = text.shape[1]
        text_attn, image_attn = attn[:, :text_len], attn[:, text_len:]

        image = self._residual_update(
            image,
            image_attn,
            self.image_attn,
            image_mod1,
            image_mod2,
            self.image_norm2,
            self.image_mlp,
        )
        text = self._residual_update(
            text,
            text_attn,
            self.text_attn,
            text_mod1,
            text_mod2,
            self.text_norm2,
            self.text_mlp,
        )
        return image, text
357
+
358
+
359
class SingleStreamBlock(keras.Model):
    """DiT block with parallel attention and MLP branches.

    As described in https://arxiv.org/abs/2302.05442 and adapted for the
    modulation interface. One Dense layer (`linear1`) produces both the
    fused QKV features and the MLP input. A second Dense layer (`linear2`)
    consumes the concatenation of the attention output and the
    GELU-activated MLP stream. The result is added back to the input
    through a modulation gate.

    Args:
        hidden_size: int. The hidden dimension size for the model.
        num_heads: int. The number of attention heads.
        mlp_ratio: float, optional. The ratio of the MLP hidden dimension to
            the hidden size. Defaults to 4.0.
        qk_scale: float, optional. Stored as `self.scale`. When not given
            (or falsy) it defaults to `head_dim ** -0.5`. Note: this
            block's forward pass does not use `self.scale`.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        qk_scale=None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = keras.layers.Dense(hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = keras.layers.Dense(hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = keras.layers.LayerNormalization(epsilon=1e-6)
        self.modulation = Modulation(hidden_size, double=False)
        self.attention = FluxRoPEAttention()

    def build(
        self, x_shape, modulation_encoding_shape, positional_encoding_shape
    ):
        # `positional_encoding_shape` mirrors the call() signature; no
        # sublayer needs it to build.
        # Build `pre_norm` eagerly like the other weighted sublayers rather
        # than leaving it to be built lazily on the first call.
        self.pre_norm.build(x_shape)
        self.linear1.build(x_shape)
        # linear2 consumes [attention output (hidden_size) | MLP stream].
        self.linear2.build(
            (x_shape[0], x_shape[1], self.hidden_size + self.mlp_hidden_dim)
        )

        self.modulation.build(modulation_encoding_shape)

        # QK normalization operates on per-head q/k tensors.
        self.norm.build(
            (
                x_shape[0],
                self.num_heads,
                x_shape[1],
                x_shape[-1] // self.num_heads,
            )
        )

    def call(self, x, modulation_encoding, positional_encoding):
        """Forward pass for the SingleStreamBlock.

        Args:
            x: Input tensor.
            modulation_encoding: Modulation vector.
            positional_encoding: Positional encoding tensor for RoPE.

        Returns:
            `x` plus the gated output of `linear2`.
        """
        mod, _ = self.modulation(modulation_encoding)
        x_mod = (1 + mod["scale"]) * self.pre_norm(x) + mod["shift"]
        # First 3 * hidden_size features are fused QKV; the rest feed the
        # MLP stream.
        qkv, mlp = ops.split(
            self.linear1(x_mod), [3 * self.hidden_size], axis=-1
        )

        q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads)
        q, k = self.norm(q, k)

        # compute attention
        attn = self.attention(
            q, k=k, v=v, positional_encoding=positional_encoding
        )
        # compute activation in mlp stream, cat again and run second linear
        # layer
        output = self.linear2(
            ops.concatenate(
                (attn, keras.activations.gelu(mlp, approximate=True)), 2
            )
        )
        return x + mod["gate"] * output
455
+
456
+
457
class LastLayer(keras.Model):
    """Final output layer with adaptive (modulated) layer normalization.

    The modulation vector is mapped through SiLU and a Dense layer to a
    shift and a scale. These modulate the layer-normalized input before a
    final Dense projection to `patch_size * patch_size * output_channels`
    features.

    Args:
        hidden_size: int. The hidden dimension size for the model.
        patch_size: int. The size of each patch.
        output_channels: int. The number of output channels.
    """

    def __init__(self, hidden_size, patch_size, output_channels):
        super().__init__()
        self.norm_final = keras.layers.LayerNormalization(epsilon=1e-6)
        self.linear = keras.layers.Dense(
            patch_size * patch_size * output_channels, use_bias=True
        )
        # Produces [shift | scale], each hidden_size wide.
        self.adaLN_modulation = keras.Sequential(
            [
                keras.layers.Activation("silu"),
                keras.layers.Dense(2 * hidden_size, use_bias=True),
            ]
        )

    def call(self, x, modulation_encoding):
        """Apply adaptive normalization and the output projection.

        Args:
            x: Input tensor.
            modulation_encoding: Modulation vector; its projection is split
                into shift and scale along axis 1.

        Returns:
            Output of the final Dense projection.
        """
        modulation = self.adaLN_modulation(modulation_encoding)
        shift, scale = ops.split(modulation, 2, axis=1)
        # Insert a length-1 axis at position 1 so shift/scale broadcast
        # against the normalized input.
        normalized = self.norm_final(x)
        modulated = (1 + scale[:, None, :]) * normalized + shift[:, None, :]
        return self.linear(modulated)