keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff compares the contents of two publicly available package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py

@@ -4,13 +4,10 @@ from keras import ops

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (
+from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (  # noqa: E501
     FlowMatchEulerDiscreteScheduler,
 )
 from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
-from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
-    VAEImageDecoder,
-)
 from keras_hub.src.utils.keras_utils import standardize_data_format

@@ -54,11 +51,52 @@ class CLIPProjection(layers.Layer):
         return (inputs_shape[0], self.hidden_dim)


-class ClassifierFreeGuidanceConcatenate(layers.Layer):
-    def __init__(self, axis=0, **kwargs):
-        super().__init__(**kwargs)
-        self.axis = axis
+class CLIPConcatenate(layers.Layer):
+    def call(
+        self,
+        clip_l_projection,
+        clip_g_projection,
+        clip_l_intermediate_output,
+        clip_g_intermediate_output,
+        padding,
+    ):
+        pooled_embeddings = ops.concatenate(
+            [clip_l_projection, clip_g_projection], axis=-1
+        )
+        embeddings = ops.concatenate(
+            [clip_l_intermediate_output, clip_g_intermediate_output], axis=-1
+        )
+        embeddings = ops.pad(embeddings, [[0, 0], [0, 0], [0, padding]])
+        return pooled_embeddings, embeddings
+
+
+class ImageRescaling(layers.Rescaling):
+    """Rescales inputs from image space to latent space.
+
+    The rescaling is performed using the formula: `(inputs - offset) * scale`.
+    """

+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) - offset) * scale
+
+
+class LatentRescaling(layers.Rescaling):
+    """Rescales inputs from latent space to image space.
+
+    The rescaling is performed using the formula: `inputs / scale + offset`.
+    """
+
+    def call(self, inputs):
+        dtype = self.compute_dtype
+        scale = self.backend.cast(self.scale, dtype)
+        offset = self.backend.cast(self.offset, dtype)
+        return (self.backend.cast(inputs, dtype) / scale) + offset
+
+
+class ClassifierFreeGuidanceConcatenate(layers.Layer):
     def call(
         self,
         latents,
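Note: the two rescaling layers added above are exact inverses of one another. `ImageRescaling` maps inputs from image space to latent space via `(inputs - offset) * scale`, and `LatentRescaling` maps latents back via `inputs / scale + offset`; the backbone wires both to the VAE's `scale` and `shift`. A minimal standalone sketch of the arithmetic (the numeric values below are placeholders, not taken from this diff):

    # Hypothetical scale/shift values, for illustration only.
    scale, shift = 1.5305, 0.0609

    def to_latent_space(x):
        # What ImageRescaling.call computes.
        return (x - shift) * scale

    def to_image_space(z):
        # What LatentRescaling.call computes.
        return z / scale + shift

    # Round-tripping recovers the original value.
    assert abs(to_image_space(to_latent_space(0.25)) - 0.25) < 1e-9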
@@ -69,20 +107,16 @@ class ClassifierFreeGuidanceConcatenate(layers.Layer):
         timestep,
     ):
         timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
-        latents = ops.concatenate([latents, latents], axis=self.axis)
+        latents = ops.concatenate([latents, latents], axis=0)
         contexts = ops.concatenate(
-            [positive_contexts, negative_contexts], axis=self.axis
+            [positive_contexts, negative_contexts], axis=0
         )
         pooled_projections = ops.concatenate(
-            [positive_pooled_projections, negative_pooled_projections],
-            axis=self.axis,
+            [positive_pooled_projections, negative_pooled_projections], axis=0
         )
-        timesteps = ops.concatenate([timestep, timestep], axis=self.axis)
+        timesteps = ops.concatenate([timestep, timestep], axis=0)
         return latents, contexts, pooled_projections, timesteps

-    def get_config(self):
-        return super().get_config()
-

 class ClassifierFreeGuidance(layers.Layer):
     """Perform classifier free guidance.
@@ -103,9 +137,6 @@ class ClassifierFreeGuidance(layers.Layer):
     - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     """

-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
     def call(self, inputs, guidance_scale):
         positive_noise, negative_noise = ops.split(inputs, 2, axis=0)
         return ops.add(
@@ -115,9 +146,6 @@ class ClassifierFreeGuidance(layers.Layer):
             ),
         )

-    def get_config(self):
-        return super().get_config()
-
     def compute_output_shape(self, inputs_shape):
         outputs_shape = list(inputs_shape)
         if outputs_shape[0] is not None:
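Note: the `ops.add(...)` call completed in the hunk above combines the two noise predictions split out of the batch. Its body is not fully visible in these hunks, but the combination is the standard classifier-free guidance formula, `negative + guidance_scale * (positive - negative)`. A standalone NumPy sketch (shapes and values are illustrative only):

    import numpy as np

    def classifier_free_guidance(stacked_noise, guidance_scale):
        # The two predictions arrive concatenated along the batch axis,
        # matching ops.split(inputs, 2, axis=0) in the layer above.
        positive, negative = np.split(stacked_noise, 2, axis=0)
        return negative + guidance_scale * (positive - negative)

    stacked = np.concatenate(
        [np.full((1, 4), 0.8), np.full((1, 4), 0.2)], axis=0
    )
    print(classifier_free_guidance(stacked, guidance_scale=7.0))  # 0.2 + 7 * 0.6 = 4.4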
@@ -145,58 +173,10 @@ class EulerStep(layers.Layer):
     https://arxiv.org/abs/2206.00364).
     """

-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
     def call(self, latents, noise_residual, sigma, sigma_next):
         sigma_diff = ops.subtract(sigma_next, sigma)
         return ops.add(latents, ops.multiply(sigma_diff, noise_residual))

-    def get_config(self):
-        return super().get_config()
-
-    def compute_output_shape(self, latents_shape):
-        return latents_shape
-
-
-class LatentSpaceDecoder(layers.Layer):
-    """Decoder to transform the latent space back to the original image space.
-
-    During decoding, the latents are transformed back to the original image
-    space using the equation: `latents / scale + shift`.
-
-    Args:
-        scale: float. The scaling factor.
-        shift: float. The shift factor.
-        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
-            including `name`, `dtype` etc.
-
-    Call arguments:
-        latents: The latent tensor to be transformed.
-
-    Reference:
-    - [High-Resolution Image Synthesis with Latent Diffusion Models](
-    https://arxiv.org/abs/2112.10752).
-    """
-
-    def __init__(self, scale, shift, **kwargs):
-        super().__init__(**kwargs)
-        self.scale = scale
-        self.shift = shift
-
-    def call(self, latents):
-        return ops.add(ops.divide(latents, self.scale), self.shift)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "scale": self.scale,
-                "shift": self.shift,
-            }
-        )
-        return config
-
     def compute_output_shape(self, latents_shape):
         return latents_shape

@@ -222,16 +202,18 @@ class StableDiffusion3Backbone(Backbone):
             transformer in MMDiT.
         mmdit_position_size: int. The size of the height and width for the
             position embedding in MMDiT.
-        vae_stackwise_num_filters: list of ints. The number of filters for each
-            stack in VAE.
-        vae_stackwise_num_blocks: list of ints. The number of blocks for each
-            stack in VAE.
-        clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for
-            encoding the inputs.
-        clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for
-            encoding the inputs.
-        t5: optional `keras_hub.models.T5Encoder`. The text encoder for
-            encoding the inputs.
+        mmdit_qk_norm: Optional str. Whether to normalize the query and key
+            tensors for each transformer in MMDiT. Available options are `None`
+            and `"rms_norm"`. Typically, this is set to `None` for 3.0 version
+            and to `"rms_norm"` for 3.5 version.
+        mmdit_dual_attention_indices: Optional tuple. Specifies the indices of
+            the blocks that serve as dual attention blocks. Typically, this is
+            for 3.5 version. Defaults to `None`.
+        vae: The VAE used for transformations between pixel space and latent
+            space.
+        clip_l: The CLIP text encoder for encoding the inputs.
+        clip_g: The CLIP text encoder for encoding the inputs.
+        t5: optional The T5 text encoder for encoding the inputs.
         latent_channels: int. The number of channels in the latent. Defaults to
             `16`.
         output_channels: int. The number of channels in the output. Defaults to
@@ -239,9 +221,9 @@ class StableDiffusion3Backbone(Backbone):
         num_train_timesteps: int. The number of diffusion steps to train the
             model. Defaults to `1000`.
         shift: float. The shift value for the timestep schedule. Defaults to
-            `1.0`.
-        height: optional int. The output height of the image.
-        width: optional int. The output width of the image.
+            `3.0`.
+        image_shape: tuple. The input shape without the batch size. Defaults to
+            `(1024, 1024, 3)`.
         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -264,6 +246,7 @@ class StableDiffusion3Backbone(Backbone):
     )

     # Randomly initialized Stable Diffusion 3 model with custom config.
+    vae = keras_hub.models.VAEBackbone(...)
     clip_l = keras_hub.models.CLIPTextEncoder(...)
     clip_g = keras_hub.models.CLIPTextEncoder(...)
     model = keras_hub.models.StableDiffusion3Backbone(
@@ -272,8 +255,9 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_hidden_dim=256,
         mmdit_depth=4,
         mmdit_position_size=192,
-        vae_stackwise_num_filters=[128, 128, 64, 32],
-        vae_stackwise_num_blocks=[1, 1, 1, 1],
+        mmdit_qk_norm=None,
+        mmdit_dual_attention_indices=None,
+        vae=vae,
         clip_l=clip_l,
         clip_g=clip_g,
     )
@@ -287,46 +271,48 @@ class StableDiffusion3Backbone(Backbone):
         mmdit_num_layers,
         mmdit_num_heads,
         mmdit_position_size,
-        vae_stackwise_num_filters,
-        vae_stackwise_num_blocks,
+        mmdit_qk_norm,
+        mmdit_dual_attention_indices,
+        vae,
         clip_l,
         clip_g,
         t5=None,
         latent_channels=16,
         output_channels=3,
         num_train_timesteps=1000,
-        shift=1.0,
-        height=None,
-        width=None,
+        shift=3.0,
+        image_shape=(1024, 1024, 3),
         data_format=None,
         dtype=None,
         **kwargs,
     ):
-        height = int(height or 1024)
-        width = int(width or 1024)
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(
-                "`height` and `width` must be divisible by 8. "
-                f"Received: height={height}, width={width}"
-            )
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError
-        latent_shape = (height // 8, width // 8, latent_channels)
+        height = image_shape[0]
+        width = image_shape[1]
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(
+                "height and width in `image_shape` must be divisible by 8. "
+                f"Received: image_shape={image_shape}"
+            )
+        latent_shape = (height // 8, width // 8, int(latent_channels))
         context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
         pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
+        self._latent_shape = latent_shape

         # === Layers ===
         self.clip_l = clip_l
         self.clip_l_projection = CLIPProjection(
             clip_l.hidden_dim, dtype=dtype, name="clip_l_projection"
         )
-        self.clip_l_projection.build([None, clip_l.hidden_dim], None)
         self.clip_g = clip_g
         self.clip_g_projection = CLIPProjection(
             clip_g.hidden_dim, dtype=dtype, name="clip_g_projection"
         )
-        self.clip_g_projection.build([None, clip_g.hidden_dim], None)
+        self.clip_concatenate = CLIPConcatenate(
+            dtype=dtype, name="clip_concatenate"
+        )
         self.t5 = t5
         self.diffuser = MMDiT(
             mmdit_patch_size,
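Note: with the constructor change above, the latent grid is derived from a single `image_shape` tuple instead of separate `height`/`width` arguments; each spatial dimension must be divisible by 8 and is downsampled by that factor. A quick check of the arithmetic (values illustrative):

    image_shape = (1024, 1024, 3)
    latent_channels = 16

    height, width = image_shape[0], image_shape[1]
    assert height % 8 == 0 and width % 8 == 0  # enforced by the constructor
    latent_shape = (height // 8, width // 8, int(latent_channels))
    print(latent_shape)  # (128, 128, 16)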
@@ -337,18 +323,18 @@ class StableDiffusion3Backbone(Backbone):
             latent_shape=latent_shape,
             context_shape=context_shape,
             pooled_projection_shape=pooled_projection_shape,
+            qk_norm=mmdit_qk_norm,
+            dual_attention_indices=mmdit_dual_attention_indices,
             data_format=data_format,
             dtype=dtype,
             name="diffuser",
         )
-        self.decoder = VAEImageDecoder(
-            vae_stackwise_num_filters,
-            vae_stackwise_num_blocks,
-            output_channels,
-            latent_shape=latent_shape,
-            data_format=data_format,
-            dtype=dtype,
-            name="decoder",
+        self.vae = vae
+        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
+            dtype=dtype, name="classifier_free_guidance_concat"
+        )
+        self.cfg = ClassifierFreeGuidance(
+            dtype=dtype, name="classifier_free_guidance"
         )
         # Set `dtype="float32"` to ensure the high precision for the noise
         # residual.
@@ -358,21 +344,25 @@ class StableDiffusion3Backbone(Backbone):
             dtype="float32",
             name="scheduler",
         )
-        self.cfg_concat = ClassifierFreeGuidanceConcatenate(
-            dtype="float32", name="classifier_free_guidance_concat"
-        )
-        self.cfg = ClassifierFreeGuidance(
-            dtype="float32", name="classifier_free_guidance"
-        )
         self.euler_step = EulerStep(dtype="float32", name="euler_step")
-        self.latent_space_decoder = LatentSpaceDecoder(
-            scale=self.decoder.scaling_factor,
-            shift=self.decoder.shift_factor,
-            dtype="float32",
-            name="latent_space_decoder",
+        self.image_rescaling = ImageRescaling(
+            scale=self.vae.scale,
+            offset=self.vae.shift,
+            dtype=dtype,
+            name="image_rescaling",
+        )
+        self.latent_rescaling = LatentRescaling(
+            scale=self.vae.scale,
+            offset=self.vae.shift,
+            dtype=dtype,
+            name="latent_rescaling",
         )

         # === Functional Model ===
+        image_input = keras.Input(
+            shape=image_shape,
+            name="images",
+        )
         latent_input = keras.Input(
             shape=latent_shape,
             name="latents",
@@ -428,17 +418,19 @@ class StableDiffusion3Backbone(Backbone):
             dtype="float32",
             name="guidance_scale",
         )
-        embeddings = self.encode_step(token_ids, negative_token_ids)
+        embeddings = self.encode_text_step(token_ids, negative_token_ids)
+        latents = self.encode_image_step(image_input)
         # Use `steps=0` to define the functional model.
-        latents = self.denoise_step(
+        denoised_latents = self.denoise_step(
             latent_input,
             embeddings,
             0,
             num_step_input[0],
             guidance_scale_input[0],
         )
-        outputs = self.decode_step(latents)
+        images = self.decode_step(denoised_latents)
         inputs = {
+            "images": image_input,
             "latents": latent_input,
             "clip_l_token_ids": clip_l_token_id_input,
             "clip_l_negative_token_ids": clip_l_negative_token_id_input,
@@ -447,6 +439,10 @@ class StableDiffusion3Backbone(Backbone):
             "num_steps": num_step_input,
             "guidance_scale": guidance_scale_input,
         }
+        outputs = {
+            "latents": latents,
+            "images": images,
+        }
         if self.t5 is not None:
             inputs["t5_token_ids"] = t5_token_id_input
             inputs["t5_negative_token_ids"] = t5_negative_token_id_input
@@ -463,18 +459,17 @@ class StableDiffusion3Backbone(Backbone):
         self.mmdit_num_layers = mmdit_num_layers
         self.mmdit_num_heads = mmdit_num_heads
         self.mmdit_position_size = mmdit_position_size
-        self.vae_stackwise_num_filters = vae_stackwise_num_filters
-        self.vae_stackwise_num_blocks = vae_stackwise_num_blocks
+        self.mmdit_qk_norm = mmdit_qk_norm
+        self.mmdit_dual_attention_indices = mmdit_dual_attention_indices
         self.latent_channels = latent_channels
         self.output_channels = output_channels
         self.num_train_timesteps = num_train_timesteps
         self.shift = shift
-        self.height = height
-        self.width = width
+        self.image_shape = image_shape

     @property
     def latent_shape(self):
-        return (None,) + tuple(self.diffuser.latent_shape)
+        return (None,) + self._latent_shape

     @property
     def clip_hidden_dim(self):
@@ -484,13 +479,17 @@ class StableDiffusion3Backbone(Backbone):
     def t5_hidden_dim(self):
         return 4096 if self.t5 is None else self.t5.hidden_dim

-    def encode_step(self, token_ids, negative_token_ids):
+    def encode_text_step(self, token_ids, negative_token_ids):
         clip_hidden_dim = self.clip_hidden_dim
         t5_hidden_dim = self.t5_hidden_dim

         def encode(token_ids):
-            clip_l_outputs = self.clip_l(token_ids["clip_l"], training=False)
-            clip_g_outputs = self.clip_g(token_ids["clip_g"], training=False)
+            clip_l_outputs = self.clip_l(
+                {"token_ids": token_ids["clip_l"]}, training=False
+            )
+            clip_g_outputs = self.clip_g(
+                {"token_ids": token_ids["clip_g"]}, training=False
+            )
             clip_l_projection = self.clip_l_projection(
                 clip_l_outputs["sequence_output"],
                 token_ids["clip_l"],
@@ -501,23 +500,21 @@ class StableDiffusion3Backbone(Backbone):
                 token_ids["clip_g"],
                 training=False,
             )
-            pooled_embeddings = ops.concatenate(
-                [clip_l_projection, clip_g_projection],
-                axis=-1,
-            )
-            embeddings = ops.concatenate(
-                [
-                    clip_l_outputs["intermediate_output"],
-                    clip_g_outputs["intermediate_output"],
-                ],
-                axis=-1,
-            )
-            embeddings = ops.pad(
-                embeddings,
-                [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]],
+            pooled_embeddings, embeddings = self.clip_concatenate(
+                clip_l_projection,
+                clip_g_projection,
+                clip_l_outputs["intermediate_output"],
+                clip_g_outputs["intermediate_output"],
+                padding=t5_hidden_dim - clip_hidden_dim,
             )
             if self.t5 is not None:
-                t5_outputs = self.t5(token_ids["t5"], training=False)
+                t5_outputs = self.t5(
+                    {
+                        "token_ids": token_ids["t5"],
+                        "padding_mask": ops.ones_like(token_ids["t5"]),
+                    },
+                    training=False,
+                )
                 embeddings = ops.concatenate([embeddings, t5_outputs], axis=-2)
             else:
                 padded_size = self.clip_l.max_sequence_length
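Note: inside `encode`, the concatenated CLIP-L/CLIP-G features are zero-padded along the feature axis up to the T5 width (`padding=t5_hidden_dim - clip_hidden_dim` in the `CLIPConcatenate` call above) so that the T5 sequence can later be concatenated along the sequence axis. A shape-only sketch, with hidden sizes that are assumptions rather than values read from this diff:

    from keras import ops

    clip_hidden_dim = 768 + 1280  # clip_l.hidden_dim + clip_g.hidden_dim (assumed)
    t5_hidden_dim = 4096  # assumed

    embeddings = ops.zeros((1, 77, clip_hidden_dim))
    padding = t5_hidden_dim - clip_hidden_dim
    embeddings = ops.pad(embeddings, [[0, 0], [0, 0], [0, padding]])
    print(ops.shape(embeddings))  # (1, 77, 4096)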
@@ -537,23 +534,36 @@ class StableDiffusion3Backbone(Backbone):
             negative_pooled_embeddings,
         )

+    def encode_image_step(self, images):
+        latents = self.vae.encode(images)
+        return self.image_rescaling(latents)
+
+    def add_noise_step(self, latents, noises, step, num_steps):
+        return self.scheduler.add_noise(latents, noises, step, num_steps)
+
     def denoise_step(
         self,
         latents,
         embeddings,
-        steps,
+        step,
         num_steps,
-        guidance_scale,
+        guidance_scale=None,
     ):
-        steps = ops.convert_to_tensor(steps)
-        steps_next = ops.add(steps, 1)
-        sigma, timestep = self.scheduler(steps, num_steps)
-        sigma_next, _ = self.scheduler(steps_next, num_steps)
+        step = ops.convert_to_tensor(step)
+        next_step = ops.add(step, 1)
+        sigma, timestep = self.scheduler(step, num_steps)
+        next_sigma, _ = self.scheduler(next_step, num_steps)

         # Concatenation for classifier-free guidance.
-        concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
-            latents, *embeddings, timestep
-        )
+        if guidance_scale is not None:
+            concated_latents, contexts, pooled_projs, timesteps = (
+                self.cfg_concat(latents, *embeddings, timestep)
+            )
+        else:
+            timesteps = ops.broadcast_to(timestep, ops.shape(latents)[:1])
+            concated_latents = latents
+            contexts = embeddings[0]
+            pooled_projs = embeddings[2]

         # Diffusion.
         predicted_noise = self.diffuser(
@@ -567,14 +577,15 @@ class StableDiffusion3Backbone(Backbone):
         )

         # Classifier-free guidance.
-        predicted_noise = self.cfg(predicted_noise, guidance_scale)
+        if guidance_scale is not None:
+            predicted_noise = self.cfg(predicted_noise, guidance_scale)

         # Euler step.
-        return self.euler_step(latents, predicted_noise, sigma, sigma_next)
+        return self.euler_step(latents, predicted_noise, sigma, next_sigma)

     def decode_step(self, latents):
-        latents = self.latent_space_decoder(latents)
-        return self.decoder(latents, training=False)
+        latents = self.latent_rescaling(latents)
+        return self.vae.decode(latents, training=False)

     def get_config(self):
         config = super().get_config()
@@ -585,8 +596,11 @@ class StableDiffusion3Backbone(Backbone):
                 "mmdit_num_layers": self.mmdit_num_layers,
                 "mmdit_num_heads": self.mmdit_num_heads,
                 "mmdit_position_size": self.mmdit_position_size,
-                "vae_stackwise_num_filters": self.vae_stackwise_num_filters,
-                "vae_stackwise_num_blocks": self.vae_stackwise_num_blocks,
+                "mmdit_qk_norm": self.mmdit_qk_norm,
+                "mmdit_dual_attention_indices": (
+                    self.mmdit_dual_attention_indices
+                ),
+                "vae": layers.serialize(self.vae),
                 "clip_l": layers.serialize(self.clip_l),
                 "clip_g": layers.serialize(self.clip_g),
                 "t5": layers.serialize(self.t5),
@@ -594,8 +608,7 @@ class StableDiffusion3Backbone(Backbone):
                 "output_channels": self.output_channels,
                 "num_train_timesteps": self.num_train_timesteps,
                 "shift": self.shift,
-                "height": self.height,
-                "width": self.width,
+                "image_shape": self.image_shape,
             }
         )
         return config
@@ -607,6 +620,8 @@ class StableDiffusion3Backbone(Backbone):
         # Propagate `dtype` to text encoders if needed.
         if "dtype" in config and config["dtype"] is not None:
             dtype_config = config["dtype"]
+            if "dtype" not in config["vae"]["config"]:
+                config["vae"]["config"]["dtype"] = dtype_config
             if "dtype" not in config["clip_l"]["config"]:
                 config["clip_l"]["config"]["dtype"] = dtype_config
             if "dtype" not in config["clip_g"]["config"]:
@@ -617,7 +632,10 @@ class StableDiffusion3Backbone(Backbone):
         ):
             config["t5"]["config"]["dtype"] = dtype_config

-        # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
+        # We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
+        config["vae"] = layers.deserialize(
+            config["vae"], custom_objects=custom_objects
+        )
         config["clip_l"] = layers.deserialize(
             config["clip_l"], custom_objects=custom_objects
         )
@@ -628,4 +646,12 @@ class StableDiffusion3Backbone(Backbone):
         config["t5"] = layers.deserialize(
             config["t5"], custom_objects=custom_objects
         )
+
+        # To maintain backward compatibility, we need to ensure that
+        # `mmdit_qk_norm` and `mmdit_dual_attention_indices` is included in the
+        # config.
+        if "mmdit_qk_norm" not in config:
+            config["mmdit_qk_norm"] = None
+        if "mmdit_dual_attention_indices" not in config:
+            config["mmdit_dual_attention_indices"] = None
         return cls(**config)
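Note: the backward-compatibility branch at the end of `from_config` lets configs serialized by older releases, which lack the new MMDiT keys, still deserialize: the missing keys are defaulted to `None` before `cls(**config)` is called. A minimal standalone illustration of that defaulting logic (a sketch, not the actual method):

    def patch_legacy_config(config):
        # Without these defaults, cls(**config) would fail for configs written
        # before `mmdit_qk_norm` and `mmdit_dual_attention_indices` existed.
        config.setdefault("mmdit_qk_norm", None)
        config.setdefault("mmdit_dual_attention_indices", None)
        return config

    legacy = {"mmdit_patch_size": 2}  # heavily truncated, illustrative only
    print(patch_legacy_config(legacy)["mmdit_qk_norm"])  # None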