keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409270338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (357)
  1. keras_hub/__init__.py +0 -13
  2. keras_hub/api/__init__.py +0 -13
  3. keras_hub/api/bounding_box/__init__.py +0 -13
  4. keras_hub/api/layers/__init__.py +3 -13
  5. keras_hub/api/metrics/__init__.py +0 -13
  6. keras_hub/api/models/__init__.py +16 -13
  7. keras_hub/api/samplers/__init__.py +0 -13
  8. keras_hub/api/tokenizers/__init__.py +1 -13
  9. keras_hub/api/utils/__init__.py +0 -13
  10. keras_hub/src/__init__.py +0 -13
  11. keras_hub/src/api_export.py +0 -14
  12. keras_hub/src/bounding_box/__init__.py +0 -13
  13. keras_hub/src/bounding_box/converters.py +0 -13
  14. keras_hub/src/bounding_box/formats.py +0 -13
  15. keras_hub/src/bounding_box/iou.py +1 -13
  16. keras_hub/src/bounding_box/to_dense.py +0 -14
  17. keras_hub/src/bounding_box/to_ragged.py +0 -13
  18. keras_hub/src/bounding_box/utils.py +0 -13
  19. keras_hub/src/bounding_box/validate_format.py +0 -14
  20. keras_hub/src/layers/__init__.py +0 -13
  21. keras_hub/src/layers/modeling/__init__.py +0 -13
  22. keras_hub/src/layers/modeling/alibi_bias.py +0 -13
  23. keras_hub/src/layers/modeling/cached_multi_head_attention.py +0 -14
  24. keras_hub/src/layers/modeling/f_net_encoder.py +0 -14
  25. keras_hub/src/layers/modeling/masked_lm_head.py +0 -14
  26. keras_hub/src/layers/modeling/position_embedding.py +0 -14
  27. keras_hub/src/layers/modeling/reversible_embedding.py +0 -14
  28. keras_hub/src/layers/modeling/rotary_embedding.py +0 -14
  29. keras_hub/src/layers/modeling/sine_position_encoding.py +0 -14
  30. keras_hub/src/layers/modeling/token_and_position_embedding.py +0 -14
  31. keras_hub/src/layers/modeling/transformer_decoder.py +0 -14
  32. keras_hub/src/layers/modeling/transformer_encoder.py +0 -14
  33. keras_hub/src/layers/modeling/transformer_layer_utils.py +0 -14
  34. keras_hub/src/layers/preprocessing/__init__.py +0 -13
  35. keras_hub/src/layers/preprocessing/audio_converter.py +0 -13
  36. keras_hub/src/layers/preprocessing/image_converter.py +0 -13
  37. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +0 -15
  38. keras_hub/src/layers/preprocessing/multi_segment_packer.py +0 -14
  39. keras_hub/src/layers/preprocessing/preprocessing_layer.py +0 -14
  40. keras_hub/src/layers/preprocessing/random_deletion.py +0 -14
  41. keras_hub/src/layers/preprocessing/random_swap.py +0 -14
  42. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -13
  43. keras_hub/src/layers/preprocessing/start_end_packer.py +0 -15
  44. keras_hub/src/metrics/__init__.py +0 -13
  45. keras_hub/src/metrics/bleu.py +0 -14
  46. keras_hub/src/metrics/edit_distance.py +0 -14
  47. keras_hub/src/metrics/perplexity.py +0 -14
  48. keras_hub/src/metrics/rouge_base.py +0 -14
  49. keras_hub/src/metrics/rouge_l.py +0 -14
  50. keras_hub/src/metrics/rouge_n.py +0 -14
  51. keras_hub/src/models/__init__.py +0 -13
  52. keras_hub/src/models/albert/__init__.py +0 -14
  53. keras_hub/src/models/albert/albert_backbone.py +0 -14
  54. keras_hub/src/models/albert/albert_masked_lm.py +0 -14
  55. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +0 -14
  56. keras_hub/src/models/albert/albert_presets.py +0 -14
  57. keras_hub/src/models/albert/albert_text_classifier.py +0 -14
  58. keras_hub/src/models/albert/albert_text_classifier_preprocessor.py +0 -14
  59. keras_hub/src/models/albert/albert_tokenizer.py +0 -14
  60. keras_hub/src/models/backbone.py +0 -14
  61. keras_hub/src/models/bart/__init__.py +0 -14
  62. keras_hub/src/models/bart/bart_backbone.py +0 -14
  63. keras_hub/src/models/bart/bart_presets.py +0 -13
  64. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +0 -15
  65. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +0 -15
  66. keras_hub/src/models/bart/bart_tokenizer.py +0 -15
  67. keras_hub/src/models/bert/__init__.py +0 -14
  68. keras_hub/src/models/bert/bert_backbone.py +0 -14
  69. keras_hub/src/models/bert/bert_masked_lm.py +0 -14
  70. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +0 -14
  71. keras_hub/src/models/bert/bert_presets.py +0 -13
  72. keras_hub/src/models/bert/bert_text_classifier.py +0 -14
  73. keras_hub/src/models/bert/bert_text_classifier_preprocessor.py +0 -14
  74. keras_hub/src/models/bert/bert_tokenizer.py +0 -14
  75. keras_hub/src/models/bloom/__init__.py +0 -14
  76. keras_hub/src/models/bloom/bloom_attention.py +0 -13
  77. keras_hub/src/models/bloom/bloom_backbone.py +0 -14
  78. keras_hub/src/models/bloom/bloom_causal_lm.py +0 -15
  79. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +0 -15
  80. keras_hub/src/models/bloom/bloom_decoder.py +0 -13
  81. keras_hub/src/models/bloom/bloom_presets.py +0 -13
  82. keras_hub/src/models/bloom/bloom_tokenizer.py +0 -15
  83. keras_hub/src/models/causal_lm.py +0 -14
  84. keras_hub/src/models/causal_lm_preprocessor.py +0 -13
  85. keras_hub/src/models/clip/__init__.py +0 -0
  86. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -15
  87. keras_hub/src/models/clip/clip_preprocessor.py +134 -0
  88. keras_hub/src/models/clip/clip_text_encoder.py +139 -0
  89. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +65 -41
  90. keras_hub/src/models/csp_darknet/__init__.py +0 -13
  91. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +0 -13
  92. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -13
  93. keras_hub/src/models/deberta_v3/__init__.py +0 -14
  94. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +0 -15
  95. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +0 -15
  96. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +0 -14
  97. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -13
  98. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +0 -15
  99. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py +0 -14
  100. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +0 -15
  101. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +0 -14
  102. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +0 -14
  103. keras_hub/src/models/deberta_v3/relative_embedding.py +0 -14
  104. keras_hub/src/models/densenet/__init__.py +5 -13
  105. keras_hub/src/models/densenet/densenet_backbone.py +11 -21
  106. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -17
  107. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  108. keras_hub/src/models/{stable_diffusion_v3/__init__.py → densenet/densenet_image_converter.py} +10 -0
  109. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  110. keras_hub/src/models/distil_bert/__init__.py +0 -14
  111. keras_hub/src/models/distil_bert/distil_bert_backbone.py +0 -15
  112. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +0 -15
  113. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +0 -14
  114. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -13
  115. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +0 -15
  116. keras_hub/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py +0 -15
  117. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +0 -15
  118. keras_hub/src/models/efficientnet/__init__.py +0 -13
  119. keras_hub/src/models/efficientnet/efficientnet_backbone.py +0 -13
  120. keras_hub/src/models/efficientnet/fusedmbconv.py +0 -14
  121. keras_hub/src/models/efficientnet/mbconv.py +0 -14
  122. keras_hub/src/models/electra/__init__.py +0 -14
  123. keras_hub/src/models/electra/electra_backbone.py +0 -14
  124. keras_hub/src/models/electra/electra_presets.py +0 -13
  125. keras_hub/src/models/electra/electra_tokenizer.py +0 -14
  126. keras_hub/src/models/f_net/__init__.py +0 -14
  127. keras_hub/src/models/f_net/f_net_backbone.py +0 -15
  128. keras_hub/src/models/f_net/f_net_masked_lm.py +0 -15
  129. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +0 -14
  130. keras_hub/src/models/f_net/f_net_presets.py +0 -13
  131. keras_hub/src/models/f_net/f_net_text_classifier.py +0 -15
  132. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +0 -15
  133. keras_hub/src/models/f_net/f_net_tokenizer.py +0 -15
  134. keras_hub/src/models/falcon/__init__.py +0 -14
  135. keras_hub/src/models/falcon/falcon_attention.py +0 -13
  136. keras_hub/src/models/falcon/falcon_backbone.py +0 -13
  137. keras_hub/src/models/falcon/falcon_causal_lm.py +0 -14
  138. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +0 -14
  139. keras_hub/src/models/falcon/falcon_presets.py +0 -13
  140. keras_hub/src/models/falcon/falcon_tokenizer.py +0 -15
  141. keras_hub/src/models/falcon/falcon_transformer_decoder.py +0 -13
  142. keras_hub/src/models/feature_pyramid_backbone.py +0 -13
  143. keras_hub/src/models/gemma/__init__.py +0 -14
  144. keras_hub/src/models/gemma/gemma_attention.py +0 -13
  145. keras_hub/src/models/gemma/gemma_backbone.py +0 -15
  146. keras_hub/src/models/gemma/gemma_causal_lm.py +0 -15
  147. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +0 -14
  148. keras_hub/src/models/gemma/gemma_decoder_block.py +0 -13
  149. keras_hub/src/models/gemma/gemma_presets.py +0 -13
  150. keras_hub/src/models/gemma/gemma_tokenizer.py +0 -14
  151. keras_hub/src/models/gemma/rms_normalization.py +0 -14
  152. keras_hub/src/models/gpt2/__init__.py +0 -14
  153. keras_hub/src/models/gpt2/gpt2_backbone.py +0 -15
  154. keras_hub/src/models/gpt2/gpt2_causal_lm.py +0 -15
  155. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +0 -14
  156. keras_hub/src/models/gpt2/gpt2_preprocessor.py +0 -15
  157. keras_hub/src/models/gpt2/gpt2_presets.py +0 -13
  158. keras_hub/src/models/gpt2/gpt2_tokenizer.py +0 -15
  159. keras_hub/src/models/gpt_neo_x/__init__.py +0 -13
  160. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +0 -14
  161. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +0 -14
  162. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +0 -14
  163. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +0 -14
  164. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +0 -14
  165. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +0 -14
  166. keras_hub/src/models/image_classifier.py +0 -13
  167. keras_hub/src/models/image_classifier_preprocessor.py +0 -13
  168. keras_hub/src/models/image_segmenter.py +0 -13
  169. keras_hub/src/models/llama/__init__.py +0 -14
  170. keras_hub/src/models/llama/llama_attention.py +0 -13
  171. keras_hub/src/models/llama/llama_backbone.py +0 -13
  172. keras_hub/src/models/llama/llama_causal_lm.py +0 -13
  173. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +0 -15
  174. keras_hub/src/models/llama/llama_decoder.py +0 -13
  175. keras_hub/src/models/llama/llama_layernorm.py +0 -13
  176. keras_hub/src/models/llama/llama_presets.py +0 -13
  177. keras_hub/src/models/llama/llama_tokenizer.py +0 -14
  178. keras_hub/src/models/llama3/__init__.py +0 -14
  179. keras_hub/src/models/llama3/llama3_backbone.py +0 -14
  180. keras_hub/src/models/llama3/llama3_causal_lm.py +0 -13
  181. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +0 -14
  182. keras_hub/src/models/llama3/llama3_presets.py +0 -13
  183. keras_hub/src/models/llama3/llama3_tokenizer.py +0 -14
  184. keras_hub/src/models/masked_lm.py +0 -13
  185. keras_hub/src/models/masked_lm_preprocessor.py +0 -13
  186. keras_hub/src/models/mistral/__init__.py +0 -14
  187. keras_hub/src/models/mistral/mistral_attention.py +0 -13
  188. keras_hub/src/models/mistral/mistral_backbone.py +0 -14
  189. keras_hub/src/models/mistral/mistral_causal_lm.py +0 -14
  190. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +0 -14
  191. keras_hub/src/models/mistral/mistral_layer_norm.py +0 -13
  192. keras_hub/src/models/mistral/mistral_presets.py +0 -13
  193. keras_hub/src/models/mistral/mistral_tokenizer.py +0 -14
  194. keras_hub/src/models/mistral/mistral_transformer_decoder.py +0 -13
  195. keras_hub/src/models/mix_transformer/__init__.py +0 -13
  196. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +0 -13
  197. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -13
  198. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +0 -13
  199. keras_hub/src/models/mobilenet/__init__.py +0 -13
  200. keras_hub/src/models/mobilenet/mobilenet_backbone.py +0 -13
  201. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -13
  202. keras_hub/src/models/opt/__init__.py +0 -14
  203. keras_hub/src/models/opt/opt_backbone.py +0 -15
  204. keras_hub/src/models/opt/opt_causal_lm.py +0 -15
  205. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +0 -13
  206. keras_hub/src/models/opt/opt_presets.py +0 -13
  207. keras_hub/src/models/opt/opt_tokenizer.py +0 -15
  208. keras_hub/src/models/pali_gemma/__init__.py +0 -13
  209. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +0 -13
  210. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +0 -13
  211. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +0 -13
  212. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +0 -14
  213. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +0 -13
  214. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +0 -13
  215. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +0 -13
  216. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +0 -13
  217. keras_hub/src/models/phi3/__init__.py +0 -14
  218. keras_hub/src/models/phi3/phi3_attention.py +0 -13
  219. keras_hub/src/models/phi3/phi3_backbone.py +0 -13
  220. keras_hub/src/models/phi3/phi3_causal_lm.py +0 -13
  221. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +0 -14
  222. keras_hub/src/models/phi3/phi3_decoder.py +0 -13
  223. keras_hub/src/models/phi3/phi3_layernorm.py +0 -13
  224. keras_hub/src/models/phi3/phi3_presets.py +0 -13
  225. keras_hub/src/models/phi3/phi3_rotary_embedding.py +0 -13
  226. keras_hub/src/models/phi3/phi3_tokenizer.py +0 -13
  227. keras_hub/src/models/preprocessor.py +51 -32
  228. keras_hub/src/models/resnet/__init__.py +0 -14
  229. keras_hub/src/models/resnet/resnet_backbone.py +0 -13
  230. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -13
  231. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +0 -14
  232. keras_hub/src/models/resnet/resnet_image_converter.py +0 -13
  233. keras_hub/src/models/resnet/resnet_presets.py +0 -13
  234. keras_hub/src/models/retinanet/__init__.py +0 -13
  235. keras_hub/src/models/retinanet/anchor_generator.py +0 -14
  236. keras_hub/src/models/retinanet/box_matcher.py +0 -14
  237. keras_hub/src/models/retinanet/non_max_supression.py +0 -14
  238. keras_hub/src/models/roberta/__init__.py +0 -14
  239. keras_hub/src/models/roberta/roberta_backbone.py +0 -15
  240. keras_hub/src/models/roberta/roberta_masked_lm.py +0 -15
  241. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +0 -14
  242. keras_hub/src/models/roberta/roberta_presets.py +0 -13
  243. keras_hub/src/models/roberta/roberta_text_classifier.py +0 -15
  244. keras_hub/src/models/roberta/roberta_text_classifier_preprocessor.py +0 -14
  245. keras_hub/src/models/roberta/roberta_tokenizer.py +0 -15
  246. keras_hub/src/models/sam/__init__.py +0 -13
  247. keras_hub/src/models/sam/sam_backbone.py +0 -14
  248. keras_hub/src/models/sam/sam_image_segmenter.py +0 -14
  249. keras_hub/src/models/sam/sam_layers.py +0 -14
  250. keras_hub/src/models/sam/sam_mask_decoder.py +0 -14
  251. keras_hub/src/models/sam/sam_prompt_encoder.py +0 -14
  252. keras_hub/src/models/sam/sam_transformer.py +0 -14
  253. keras_hub/src/models/seq_2_seq_lm.py +0 -13
  254. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +0 -13
  255. keras_hub/src/models/stable_diffusion_3/__init__.py +9 -0
  256. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +80 -0
  257. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -39
  258. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +631 -0
  259. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +31 -0
  260. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +138 -0
  261. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +83 -0
  262. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -20
  263. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +320 -0
  264. keras_hub/src/models/t5/__init__.py +0 -14
  265. keras_hub/src/models/t5/t5_backbone.py +0 -14
  266. keras_hub/src/models/t5/t5_layer_norm.py +0 -14
  267. keras_hub/src/models/t5/t5_multi_head_attention.py +0 -14
  268. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -16
  269. keras_hub/src/models/t5/t5_presets.py +0 -13
  270. keras_hub/src/models/t5/t5_tokenizer.py +0 -14
  271. keras_hub/src/models/t5/t5_transformer_layer.py +0 -14
  272. keras_hub/src/models/task.py +0 -14
  273. keras_hub/src/models/text_classifier.py +0 -13
  274. keras_hub/src/models/text_classifier_preprocessor.py +0 -13
  275. keras_hub/src/models/text_to_image.py +282 -0
  276. keras_hub/src/models/vgg/__init__.py +0 -13
  277. keras_hub/src/models/vgg/vgg_backbone.py +0 -13
  278. keras_hub/src/models/vgg/vgg_image_classifier.py +0 -13
  279. keras_hub/src/models/vit_det/__init__.py +0 -13
  280. keras_hub/src/models/vit_det/vit_det_backbone.py +0 -14
  281. keras_hub/src/models/vit_det/vit_layers.py +0 -15
  282. keras_hub/src/models/whisper/__init__.py +0 -14
  283. keras_hub/src/models/whisper/whisper_audio_converter.py +0 -15
  284. keras_hub/src/models/whisper/whisper_backbone.py +0 -15
  285. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +0 -13
  286. keras_hub/src/models/whisper/whisper_decoder.py +0 -14
  287. keras_hub/src/models/whisper/whisper_encoder.py +0 -14
  288. keras_hub/src/models/whisper/whisper_presets.py +0 -14
  289. keras_hub/src/models/whisper/whisper_tokenizer.py +0 -14
  290. keras_hub/src/models/xlm_roberta/__init__.py +0 -14
  291. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +0 -15
  292. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +0 -15
  293. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +0 -14
  294. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -13
  295. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +0 -15
  296. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py +0 -15
  297. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +0 -15
  298. keras_hub/src/models/xlnet/__init__.py +0 -13
  299. keras_hub/src/models/xlnet/relative_attention.py +0 -14
  300. keras_hub/src/models/xlnet/xlnet_backbone.py +0 -14
  301. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +0 -14
  302. keras_hub/src/models/xlnet/xlnet_encoder.py +0 -14
  303. keras_hub/src/samplers/__init__.py +0 -13
  304. keras_hub/src/samplers/beam_sampler.py +0 -14
  305. keras_hub/src/samplers/contrastive_sampler.py +0 -14
  306. keras_hub/src/samplers/greedy_sampler.py +0 -14
  307. keras_hub/src/samplers/random_sampler.py +0 -14
  308. keras_hub/src/samplers/sampler.py +0 -14
  309. keras_hub/src/samplers/serialization.py +0 -14
  310. keras_hub/src/samplers/top_k_sampler.py +0 -14
  311. keras_hub/src/samplers/top_p_sampler.py +0 -14
  312. keras_hub/src/tests/__init__.py +0 -13
  313. keras_hub/src/tests/test_case.py +0 -14
  314. keras_hub/src/tokenizers/__init__.py +0 -13
  315. keras_hub/src/tokenizers/byte_pair_tokenizer.py +0 -14
  316. keras_hub/src/tokenizers/byte_tokenizer.py +0 -14
  317. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +0 -14
  318. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +0 -14
  319. keras_hub/src/tokenizers/tokenizer.py +23 -27
  320. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +0 -15
  321. keras_hub/src/tokenizers/word_piece_tokenizer.py +0 -14
  322. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +0 -15
  323. keras_hub/src/utils/__init__.py +0 -13
  324. keras_hub/src/utils/imagenet/__init__.py +0 -13
  325. keras_hub/src/utils/imagenet/imagenet_utils.py +0 -13
  326. keras_hub/src/utils/keras_utils.py +0 -14
  327. keras_hub/src/utils/pipeline_model.py +0 -14
  328. keras_hub/src/utils/preset_utils.py +32 -76
  329. keras_hub/src/utils/python_utils.py +0 -13
  330. keras_hub/src/utils/tensor_utils.py +0 -14
  331. keras_hub/src/utils/timm/__init__.py +0 -13
  332. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  333. keras_hub/src/utils/timm/convert_resnet.py +0 -13
  334. keras_hub/src/utils/timm/preset_loader.py +3 -13
  335. keras_hub/src/utils/transformers/__init__.py +0 -13
  336. keras_hub/src/utils/transformers/convert_albert.py +0 -13
  337. keras_hub/src/utils/transformers/convert_bart.py +0 -13
  338. keras_hub/src/utils/transformers/convert_bert.py +0 -13
  339. keras_hub/src/utils/transformers/convert_distilbert.py +0 -13
  340. keras_hub/src/utils/transformers/convert_gemma.py +0 -13
  341. keras_hub/src/utils/transformers/convert_gpt2.py +0 -13
  342. keras_hub/src/utils/transformers/convert_llama3.py +0 -13
  343. keras_hub/src/utils/transformers/convert_mistral.py +0 -13
  344. keras_hub/src/utils/transformers/convert_pali_gemma.py +0 -13
  345. keras_hub/src/utils/transformers/preset_loader.py +1 -15
  346. keras_hub/src/utils/transformers/safetensor_utils.py +9 -15
  347. keras_hub/src/version_utils.py +1 -15
  348. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/METADATA +30 -27
  349. keras_hub_nightly-0.16.1.dev202409270338.dist-info/RECORD +351 -0
  350. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  351. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +0 -149
  352. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  353. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  354. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  355. keras_hub_nightly-0.16.1.dev202409250340.dist-info/RECORD +0 -342
  356. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/WHEEL +0 -0
  357. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py (new file)
@@ -0,0 +1,631 @@
+ import keras
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.stable_diffusion_3.flow_match_euler_discrete_scheduler import (
+     FlowMatchEulerDiscreteScheduler,
+ )
+ from keras_hub.src.models.stable_diffusion_3.mmdit import MMDiT
+ from keras_hub.src.models.stable_diffusion_3.vae_image_decoder import (
+     VAEImageDecoder,
+ )
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class CLIPProjection(layers.Layer):
+     def __init__(self, hidden_dim, **kwargs):
+         super().__init__(**kwargs)
+         self.hidden_dim = int(hidden_dim)
+
+         self.dense = layers.Dense(
+             hidden_dim,
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="dense",
+         )
+
+     def build(self, inputs_shape, token_ids_shape):
+         inputs_shape = list(inputs_shape)
+         self.dense.build([None, inputs_shape[-1]])
+
+         # Assign identity matrix to the kernel as default.
+         self.dense._kernel.assign(ops.eye(self.hidden_dim))
+
+     def call(self, inputs, token_ids):
+         indices = ops.expand_dims(
+             ops.cast(ops.argmax(token_ids, axis=-1), "int32"), axis=-1
+         )
+         pooled_output = ops.take_along_axis(inputs, indices[:, :, None], axis=1)
+         pooled_output = ops.squeeze(pooled_output, axis=1)
+         return self.dense(pooled_output)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, inputs_shape):
+         return (inputs_shape[0], self.hidden_dim)
+
+
+ class ClassifierFreeGuidanceConcatenate(layers.Layer):
+     def __init__(self, axis=0, **kwargs):
+         super().__init__(**kwargs)
+         self.axis = axis
+
+     def call(
+         self,
+         latents,
+         positive_contexts,
+         negative_contexts,
+         positive_pooled_projections,
+         negative_pooled_projections,
+         timestep,
+     ):
+         timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
+         latents = ops.concatenate([latents, latents], axis=self.axis)
+         contexts = ops.concatenate(
+             [positive_contexts, negative_contexts], axis=self.axis
+         )
+         pooled_projections = ops.concatenate(
+             [positive_pooled_projections, negative_pooled_projections],
+             axis=self.axis,
+         )
+         timesteps = ops.concatenate([timestep, timestep], axis=self.axis)
+         return latents, contexts, pooled_projections, timesteps
+
+     def get_config(self):
+         return super().get_config()
+
+
+ class ClassifierFreeGuidance(layers.Layer):
+     """Perform classifier-free guidance.
+
+     This layer expects the inputs to be a concatenation of positive and negative
+     (or empty) noise. The computation applies the classifier-free guidance
+     scale.
+
+     Args:
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+
+     Call arguments:
+         inputs: A concatenation of positive and negative (or empty) noises.
+         guidance_scale: The scale factor for classifier-free guidance.
+
+     Reference:
+     - [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+     """
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def call(self, inputs, guidance_scale):
+         positive_noise, negative_noise = ops.split(inputs, 2, axis=0)
+         return ops.add(
+             negative_noise,
+             ops.multiply(
+                 guidance_scale, ops.subtract(positive_noise, negative_noise)
+             ),
+         )
+
+     def get_config(self):
+         return super().get_config()
+
+     def compute_output_shape(self, inputs_shape):
+         outputs_shape = list(inputs_shape)
+         if outputs_shape[0] is not None:
+             outputs_shape[0] = outputs_shape[0] // 2
+         return outputs_shape
+
+
+ class EulerStep(layers.Layer):
+     """A layer that predicts the next sample from the predicted noise.
+
+     Args:
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+
+     Call arguments:
+         latents: A current sample created by the diffusion process.
+         noise_residual: The direct output from the diffusion model.
+         sigma: The amount of noise added at the current timestep.
+         sigma_next: The amount of noise added at the next timestep.
+
+     References:
+     - [Common Diffusion Noise Schedules and Sample Steps are Flawed](
+     https://arxiv.org/abs/2305.08891).
+     - [Elucidating the Design Space of Diffusion-Based Generative Models](
+     https://arxiv.org/abs/2206.00364).
+     """
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def call(self, latents, noise_residual, sigma, sigma_next):
+         sigma_diff = ops.subtract(sigma_next, sigma)
+         return ops.add(latents, ops.multiply(sigma_diff, noise_residual))
+
+     def get_config(self):
+         return super().get_config()
+
+     def compute_output_shape(self, latents_shape):
+         return latents_shape
+
+
+ class LatentSpaceDecoder(layers.Layer):
+     """Decoder to transform the latent space back to the original image space.
+
+     During decoding, the latents are transformed back to the original image
+     space using the equation: `latents / scale + shift`.
+
+     Args:
+         scale: float. The scaling factor.
+         shift: float. The shift factor.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+
+     Call arguments:
+         latents: The latent tensor to be transformed.
+
+     Reference:
+     - [High-Resolution Image Synthesis with Latent Diffusion Models](
+     https://arxiv.org/abs/2112.10752).
+     """
+
+     def __init__(self, scale, shift, **kwargs):
+         super().__init__(**kwargs)
+         self.scale = scale
+         self.shift = shift
+
+     def call(self, latents):
+         return ops.add(ops.divide(latents, self.scale), self.shift)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "scale": self.scale,
+                 "shift": self.shift,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, latents_shape):
+         return latents_shape
+
+
+ @keras_hub_export("keras_hub.models.StableDiffusion3Backbone")
+ class StableDiffusion3Backbone(Backbone):
+     """Stable Diffusion 3 core network with hyperparameters.
+
+     This backbone imports CLIP and T5 models as text encoders and implements the
+     base MMDiT and VAE networks for the Stable Diffusion 3 model.
+
+     The default constructor gives fully customizable, randomly initialized
+     MMDiT and VAE models with any hyperparameters. To load preset architectures
+     and weights, use the `from_preset` constructor.
+
+     Args:
+         mmdit_patch_size: int. The size of each square patch in the input image
+             in MMDiT.
+         mmdit_hidden_dim: int. The size of the transformer hidden state at the
+             end of each transformer layer in MMDiT.
+         mmdit_num_layers: int. The number of transformer layers in MMDiT.
+         mmdit_num_heads: int. The number of attention heads for each
+             transformer in MMDiT.
+         mmdit_position_size: int. The size of the height and width for the
+             position embedding in MMDiT.
+         vae_stackwise_num_filters: list of ints. The number of filters for each
+             stack in VAE.
+         vae_stackwise_num_blocks: list of ints. The number of blocks for each
+             stack in VAE.
+         clip_l: `keras_hub.models.CLIPTextEncoder`. The text encoder for
+             encoding the inputs.
+         clip_g: `keras_hub.models.CLIPTextEncoder`. The text encoder for
+             encoding the inputs.
+         t5: optional `keras_hub.models.T5Encoder`. The text encoder for
+             encoding the inputs.
+         latent_channels: int. The number of channels in the latent. Defaults to
+             `16`.
+         output_channels: int. The number of channels in the output. Defaults to
+             `3`.
+         num_train_timesteps: int. The number of diffusion steps to train the
+             model. Defaults to `1000`.
+         shift: float. The shift value for the timestep schedule. Defaults to
+             `1.0`.
+         height: optional int. The output height of the image.
+         width: optional int. The output width of the image.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for the model's computations and weights. Note that some
+             computations, such as softmax and layer normalization, will always
+             be done in float32 precision regardless of dtype.
+
+     Example:
+     ```python
+     # Pretrained Stable Diffusion 3 model.
+     model = keras_hub.models.StableDiffusion3Backbone.from_preset(
+         "stable_diffusion_3_medium"
+     )
+
+     # Randomly initialized Stable Diffusion 3 model with custom config.
+     clip_l = keras_hub.models.CLIPTextEncoder(...)
+     clip_g = keras_hub.models.CLIPTextEncoder(...)
+     model = keras_hub.models.StableDiffusion3Backbone(
+         mmdit_patch_size=2,
+         mmdit_num_heads=4,
+         mmdit_hidden_dim=256,
+         mmdit_num_layers=4,
+         mmdit_position_size=192,
+         vae_stackwise_num_filters=[128, 128, 64, 32],
+         vae_stackwise_num_blocks=[1, 1, 1, 1],
+         clip_l=clip_l,
+         clip_g=clip_g,
+     )
+     ```
+     """
+
+     def __init__(
+         self,
+         mmdit_patch_size,
+         mmdit_hidden_dim,
+         mmdit_num_layers,
+         mmdit_num_heads,
+         mmdit_position_size,
+         vae_stackwise_num_filters,
+         vae_stackwise_num_blocks,
+         clip_l,
+         clip_g,
+         t5=None,
+         latent_channels=16,
+         output_channels=3,
+         num_train_timesteps=1000,
+         shift=1.0,
+         height=None,
+         width=None,
+         data_format=None,
+         dtype=None,
+         **kwargs,
+     ):
+         height = int(height or 1024)
+         width = int(width or 1024)
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(
+                 "`height` and `width` must be divisible by 8. "
+                 f"Received: height={height}, width={width}"
+             )
+         data_format = standardize_data_format(data_format)
+         if data_format != "channels_last":
+             raise NotImplementedError
+         latent_shape = (height // 8, width // 8, latent_channels)
+         context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
+         pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
+
+         # === Layers ===
+         self.clip_l = clip_l
+         self.clip_l_projection = CLIPProjection(
+             clip_l.hidden_dim, dtype=dtype, name="clip_l_projection"
+         )
+         self.clip_l_projection.build([None, clip_l.hidden_dim], None)
+         self.clip_g = clip_g
+         self.clip_g_projection = CLIPProjection(
+             clip_g.hidden_dim, dtype=dtype, name="clip_g_projection"
+         )
+         self.clip_g_projection.build([None, clip_g.hidden_dim], None)
+         self.t5 = t5
+         self.diffuser = MMDiT(
+             mmdit_patch_size,
+             mmdit_hidden_dim,
+             mmdit_num_layers,
+             mmdit_num_heads,
+             mmdit_position_size,
+             latent_shape=latent_shape,
+             context_shape=context_shape,
+             pooled_projection_shape=pooled_projection_shape,
+             data_format=data_format,
+             dtype=dtype,
+             name="diffuser",
+         )
+         self.decoder = VAEImageDecoder(
+             vae_stackwise_num_filters,
+             vae_stackwise_num_blocks,
+             output_channels,
+             latent_shape=latent_shape,
+             data_format=data_format,
+             dtype=dtype,
+             name="decoder",
+         )
+         # Set `dtype="float32"` to ensure the high precision for the noise
+         # residual.
+         self.scheduler = FlowMatchEulerDiscreteScheduler(
+             num_train_timesteps=num_train_timesteps,
+             shift=shift,
+             dtype="float32",
+             name="scheduler",
+         )
+         self.cfg_concat = ClassifierFreeGuidanceConcatenate(
+             dtype="float32", name="classifier_free_guidance_concat"
+         )
+         self.cfg = ClassifierFreeGuidance(
+             dtype="float32", name="classifier_free_guidance"
+         )
+         self.euler_step = EulerStep(dtype="float32", name="euler_step")
+         self.latent_space_decoder = LatentSpaceDecoder(
+             scale=self.decoder.scaling_factor,
+             shift=self.decoder.shift_factor,
+             dtype="float32",
+             name="latent_space_decoder",
+         )
+
+         # === Functional Model ===
+         latent_input = keras.Input(
+             shape=latent_shape,
+             name="latents",
+         )
+         clip_l_token_id_input = keras.Input(
+             shape=(None,),
+             dtype="int32",
+             name="clip_l_token_ids",
+         )
+         clip_l_negative_token_id_input = keras.Input(
+             shape=(None,),
+             dtype="int32",
+             name="clip_l_negative_token_ids",
+         )
+         clip_g_token_id_input = keras.Input(
+             shape=(None,),
+             dtype="int32",
+             name="clip_g_token_ids",
+         )
+         clip_g_negative_token_id_input = keras.Input(
+             shape=(None,),
+             dtype="int32",
+             name="clip_g_negative_token_ids",
+         )
+         token_ids = {
+             "clip_l": clip_l_token_id_input,
+             "clip_g": clip_g_token_id_input,
+         }
+         negative_token_ids = {
+             "clip_l": clip_l_negative_token_id_input,
+             "clip_g": clip_g_negative_token_id_input,
+         }
+         if self.t5 is not None:
+             t5_token_id_input = keras.Input(
+                 shape=(None,),
+                 dtype="int32",
+                 name="t5_token_ids",
+             )
+             t5_negative_token_id_input = keras.Input(
+                 shape=(None,),
+                 dtype="int32",
+                 name="t5_negative_token_ids",
+             )
+             token_ids["t5"] = t5_token_id_input
+             negative_token_ids["t5"] = t5_negative_token_id_input
+         num_step_input = keras.Input(
+             shape=(),
+             dtype="int32",
+             name="num_steps",
+         )
+         guidance_scale_input = keras.Input(
+             shape=(),
+             dtype="float32",
+             name="guidance_scale",
+         )
+         embeddings = self.encode_step(token_ids, negative_token_ids)
+         # Use `steps=0` to define the functional model.
+         latents = self.denoise_step(
+             latent_input,
+             embeddings,
+             0,
+             num_step_input[0],
+             guidance_scale_input[0],
+         )
+         outputs = self.decode_step(latents)
+         inputs = {
+             "latents": latent_input,
+             "clip_l_token_ids": clip_l_token_id_input,
+             "clip_l_negative_token_ids": clip_l_negative_token_id_input,
+             "clip_g_token_ids": clip_g_token_id_input,
+             "clip_g_negative_token_ids": clip_g_negative_token_id_input,
+             "num_steps": num_step_input,
+             "guidance_scale": guidance_scale_input,
+         }
+         if self.t5 is not None:
+             inputs["t5_token_ids"] = t5_token_id_input
+             inputs["t5_negative_token_ids"] = t5_negative_token_id_input
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.mmdit_patch_size = mmdit_patch_size
+         self.mmdit_hidden_dim = mmdit_hidden_dim
+         self.mmdit_num_layers = mmdit_num_layers
+         self.mmdit_num_heads = mmdit_num_heads
+         self.mmdit_position_size = mmdit_position_size
+         self.vae_stackwise_num_filters = vae_stackwise_num_filters
+         self.vae_stackwise_num_blocks = vae_stackwise_num_blocks
+         self.latent_channels = latent_channels
+         self.output_channels = output_channels
+         self.num_train_timesteps = num_train_timesteps
+         self.shift = shift
+         self.height = height
+         self.width = width
+
+     @property
+     def latent_shape(self):
+         return (None,) + tuple(self.diffuser.latent_shape)
+
+     @property
+     def clip_hidden_dim(self):
+         return self.clip_l.hidden_dim + self.clip_g.hidden_dim
+
+     @property
+     def t5_hidden_dim(self):
+         return 4096 if self.t5 is None else self.t5.hidden_dim
+
+     def encode_step(self, token_ids, negative_token_ids):
+         clip_hidden_dim = self.clip_hidden_dim
+         t5_hidden_dim = self.t5_hidden_dim
+
+         def encode(token_ids):
+             clip_l_outputs = self.clip_l(token_ids["clip_l"], training=False)
+             clip_g_outputs = self.clip_g(token_ids["clip_g"], training=False)
+             clip_l_projection = self.clip_l_projection(
+                 clip_l_outputs["sequence_output"],
+                 token_ids["clip_l"],
+                 training=False,
+             )
+             clip_g_projection = self.clip_g_projection(
+                 clip_g_outputs["sequence_output"],
+                 token_ids["clip_g"],
+                 training=False,
+             )
+             pooled_embeddings = ops.concatenate(
+                 [clip_l_projection, clip_g_projection],
+                 axis=-1,
+             )
+             embeddings = ops.concatenate(
+                 [
+                     clip_l_outputs["intermediate_output"],
+                     clip_g_outputs["intermediate_output"],
+                 ],
+                 axis=-1,
+             )
+             embeddings = ops.pad(
+                 embeddings,
+                 [[0, 0], [0, 0], [0, t5_hidden_dim - clip_hidden_dim]],
+             )
+             if self.t5 is not None:
+                 t5_outputs = self.t5(token_ids["t5"], training=False)
+                 embeddings = ops.concatenate([embeddings, t5_outputs], axis=-2)
+             else:
+                 padded_size = self.clip_l.max_sequence_length
+                 embeddings = ops.pad(
+                     embeddings, [[0, 0], [0, padded_size], [0, 0]]
+                 )
+             return embeddings, pooled_embeddings
+
+         positive_embeddings, positive_pooled_embeddings = encode(token_ids)
+         negative_embeddings, negative_pooled_embeddings = encode(
+             negative_token_ids
+         )
+         return (
+             positive_embeddings,
+             negative_embeddings,
+             positive_pooled_embeddings,
+             negative_pooled_embeddings,
+         )
+
+     def denoise_step(
+         self,
+         latents,
+         embeddings,
+         steps,
+         num_steps,
+         guidance_scale,
+     ):
+         steps = ops.convert_to_tensor(steps)
+         steps_next = ops.add(steps, 1)
+         sigma, timestep = self.scheduler(steps, num_steps)
+         sigma_next, _ = self.scheduler(steps_next, num_steps)
+
+         # Concatenation for classifier-free guidance.
+         concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat(
+             latents, *embeddings, timestep
+         )
+
+         # Diffusion.
+         predicted_noise = self.diffuser(
+             {
+                 "latent": concated_latents,
+                 "context": contexts,
+                 "pooled_projection": pooled_projs,
+                 "timestep": timesteps,
+             },
+             training=False,
+         )
+
+         # Classifier-free guidance.
+         predicted_noise = self.cfg(predicted_noise, guidance_scale)
+
+         # Euler step.
+         return self.euler_step(latents, predicted_noise, sigma, sigma_next)
+
+     def decode_step(self, latents):
+         latents = self.latent_space_decoder(latents)
+         return self.decoder(latents, training=False)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "mmdit_patch_size": self.mmdit_patch_size,
+                 "mmdit_hidden_dim": self.mmdit_hidden_dim,
+                 "mmdit_num_layers": self.mmdit_num_layers,
+                 "mmdit_num_heads": self.mmdit_num_heads,
+                 "mmdit_position_size": self.mmdit_position_size,
+                 "vae_stackwise_num_filters": self.vae_stackwise_num_filters,
+                 "vae_stackwise_num_blocks": self.vae_stackwise_num_blocks,
+                 "clip_l": layers.serialize(self.clip_l),
+                 "clip_g": layers.serialize(self.clip_g),
+                 "t5": layers.serialize(self.t5),
+                 "latent_channels": self.latent_channels,
+                 "output_channels": self.output_channels,
+                 "num_train_timesteps": self.num_train_timesteps,
+                 "shift": self.shift,
+                 "height": self.height,
+                 "width": self.width,
+             }
+         )
+         return config
+
+     @classmethod
+     def from_config(cls, config, custom_objects=None):
+         config = config.copy()
+
+         # Propagate `dtype` to text encoders if needed.
+         if "dtype" in config and config["dtype"] is not None:
+             dtype_config = config["dtype"]
+             if "dtype" not in config["clip_l"]["config"]:
+                 config["clip_l"]["config"]["dtype"] = dtype_config
+             if "dtype" not in config["clip_g"]["config"]:
+                 config["clip_g"]["config"]["dtype"] = dtype_config
+             if (
+                 config["t5"] is not None
+                 and "dtype" not in config["t5"]["config"]
+             ):
+                 config["t5"]["config"]["dtype"] = dtype_config
+
+         # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
+         config["clip_l"] = layers.deserialize(
+             config["clip_l"], custom_objects=custom_objects
+         )
+         config["clip_g"] = layers.deserialize(
+             config["clip_g"], custom_objects=custom_objects
+         )
+         if config["t5"] is not None:
+             config["t5"] = layers.deserialize(
+                 config["t5"], custom_objects=custom_objects
+             )
+         return cls(**config)
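
The three helper layers above (`ClassifierFreeGuidance`, `EulerStep`, `LatentSpaceDecoder`) each wrap a one-line formula. A minimal NumPy sketch of those formulas, using made-up illustrative values rather than real model outputs:

```python
import numpy as np

# Classifier-free guidance (`ClassifierFreeGuidance.call`):
# guided = negative + guidance_scale * (positive - negative)
positive = np.array([1.0, 2.0])  # noise predicted with the prompt
negative = np.array([0.5, 0.5])  # noise predicted with the empty prompt
guided = negative + 7.0 * (positive - negative)  # -> [4.0, 11.0]

# Euler step (`EulerStep.call`):
# next_latents = latents + (sigma_next - sigma) * noise_residual
latents = np.zeros(2)
sigma, sigma_next = 1.0, 0.8
next_latents = latents + (sigma_next - sigma) * guided  # -> [-0.8, -2.2]

# Latent-to-image-space transform (`LatentSpaceDecoder.call`):
# decoded = latents / scale + shift
scale, shift = 1.5, 0.1  # placeholders; the real factors come from the VAE
decoded = next_latents / scale + shift
```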
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py (new file)
@@ -0,0 +1,31 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """StableDiffusion3 preset configurations."""
+
+ backbone_presets = {
+     "stable_diffusion_3_medium": {
+         "metadata": {
+             "description": (
+                 "3 billion parameter model, including CLIP L and CLIP G text "
+                 "encoders, MMDiT generative model, and VAE decoder. "
+                 "Developed by Stability AI."
+             ),
+             "params": 2952806723,
+             "official_name": "StableDiffusion3",
+             "path": "stablediffusion3",
+             "model_card": "https://arxiv.org/abs/2110.00476",
+         },
+         "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/1",
+     }
+ }
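
As a quick smoke test of the preset registered above, the backbone can be loaded by name. This mirrors the `from_preset` example in the backbone docstring (a sketch; it assumes the Kaggle handle above is reachable from your environment):

```python
import keras_hub

# Load the pretrained Stable Diffusion 3 backbone by preset name.
backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium"
)
# Should roughly match the "params" metadata above (2,952,806,723).
print(backbone.count_params())
```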