keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409270338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (357) hide show
  1. keras_hub/__init__.py +0 -13
  2. keras_hub/api/__init__.py +0 -13
  3. keras_hub/api/bounding_box/__init__.py +0 -13
  4. keras_hub/api/layers/__init__.py +3 -13
  5. keras_hub/api/metrics/__init__.py +0 -13
  6. keras_hub/api/models/__init__.py +16 -13
  7. keras_hub/api/samplers/__init__.py +0 -13
  8. keras_hub/api/tokenizers/__init__.py +1 -13
  9. keras_hub/api/utils/__init__.py +0 -13
  10. keras_hub/src/__init__.py +0 -13
  11. keras_hub/src/api_export.py +0 -14
  12. keras_hub/src/bounding_box/__init__.py +0 -13
  13. keras_hub/src/bounding_box/converters.py +0 -13
  14. keras_hub/src/bounding_box/formats.py +0 -13
  15. keras_hub/src/bounding_box/iou.py +1 -13
  16. keras_hub/src/bounding_box/to_dense.py +0 -14
  17. keras_hub/src/bounding_box/to_ragged.py +0 -13
  18. keras_hub/src/bounding_box/utils.py +0 -13
  19. keras_hub/src/bounding_box/validate_format.py +0 -14
  20. keras_hub/src/layers/__init__.py +0 -13
  21. keras_hub/src/layers/modeling/__init__.py +0 -13
  22. keras_hub/src/layers/modeling/alibi_bias.py +0 -13
  23. keras_hub/src/layers/modeling/cached_multi_head_attention.py +0 -14
  24. keras_hub/src/layers/modeling/f_net_encoder.py +0 -14
  25. keras_hub/src/layers/modeling/masked_lm_head.py +0 -14
  26. keras_hub/src/layers/modeling/position_embedding.py +0 -14
  27. keras_hub/src/layers/modeling/reversible_embedding.py +0 -14
  28. keras_hub/src/layers/modeling/rotary_embedding.py +0 -14
  29. keras_hub/src/layers/modeling/sine_position_encoding.py +0 -14
  30. keras_hub/src/layers/modeling/token_and_position_embedding.py +0 -14
  31. keras_hub/src/layers/modeling/transformer_decoder.py +0 -14
  32. keras_hub/src/layers/modeling/transformer_encoder.py +0 -14
  33. keras_hub/src/layers/modeling/transformer_layer_utils.py +0 -14
  34. keras_hub/src/layers/preprocessing/__init__.py +0 -13
  35. keras_hub/src/layers/preprocessing/audio_converter.py +0 -13
  36. keras_hub/src/layers/preprocessing/image_converter.py +0 -13
  37. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +0 -15
  38. keras_hub/src/layers/preprocessing/multi_segment_packer.py +0 -14
  39. keras_hub/src/layers/preprocessing/preprocessing_layer.py +0 -14
  40. keras_hub/src/layers/preprocessing/random_deletion.py +0 -14
  41. keras_hub/src/layers/preprocessing/random_swap.py +0 -14
  42. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -13
  43. keras_hub/src/layers/preprocessing/start_end_packer.py +0 -15
  44. keras_hub/src/metrics/__init__.py +0 -13
  45. keras_hub/src/metrics/bleu.py +0 -14
  46. keras_hub/src/metrics/edit_distance.py +0 -14
  47. keras_hub/src/metrics/perplexity.py +0 -14
  48. keras_hub/src/metrics/rouge_base.py +0 -14
  49. keras_hub/src/metrics/rouge_l.py +0 -14
  50. keras_hub/src/metrics/rouge_n.py +0 -14
  51. keras_hub/src/models/__init__.py +0 -13
  52. keras_hub/src/models/albert/__init__.py +0 -14
  53. keras_hub/src/models/albert/albert_backbone.py +0 -14
  54. keras_hub/src/models/albert/albert_masked_lm.py +0 -14
  55. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +0 -14
  56. keras_hub/src/models/albert/albert_presets.py +0 -14
  57. keras_hub/src/models/albert/albert_text_classifier.py +0 -14
  58. keras_hub/src/models/albert/albert_text_classifier_preprocessor.py +0 -14
  59. keras_hub/src/models/albert/albert_tokenizer.py +0 -14
  60. keras_hub/src/models/backbone.py +0 -14
  61. keras_hub/src/models/bart/__init__.py +0 -14
  62. keras_hub/src/models/bart/bart_backbone.py +0 -14
  63. keras_hub/src/models/bart/bart_presets.py +0 -13
  64. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +0 -15
  65. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +0 -15
  66. keras_hub/src/models/bart/bart_tokenizer.py +0 -15
  67. keras_hub/src/models/bert/__init__.py +0 -14
  68. keras_hub/src/models/bert/bert_backbone.py +0 -14
  69. keras_hub/src/models/bert/bert_masked_lm.py +0 -14
  70. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +0 -14
  71. keras_hub/src/models/bert/bert_presets.py +0 -13
  72. keras_hub/src/models/bert/bert_text_classifier.py +0 -14
  73. keras_hub/src/models/bert/bert_text_classifier_preprocessor.py +0 -14
  74. keras_hub/src/models/bert/bert_tokenizer.py +0 -14
  75. keras_hub/src/models/bloom/__init__.py +0 -14
  76. keras_hub/src/models/bloom/bloom_attention.py +0 -13
  77. keras_hub/src/models/bloom/bloom_backbone.py +0 -14
  78. keras_hub/src/models/bloom/bloom_causal_lm.py +0 -15
  79. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +0 -15
  80. keras_hub/src/models/bloom/bloom_decoder.py +0 -13
  81. keras_hub/src/models/bloom/bloom_presets.py +0 -13
  82. keras_hub/src/models/bloom/bloom_tokenizer.py +0 -15
  83. keras_hub/src/models/causal_lm.py +0 -14
  84. keras_hub/src/models/causal_lm_preprocessor.py +0 -13
  85. keras_hub/src/models/clip/__init__.py +0 -0
  86. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -15
  87. keras_hub/src/models/clip/clip_preprocessor.py +134 -0
  88. keras_hub/src/models/clip/clip_text_encoder.py +139 -0
  89. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +65 -41
  90. keras_hub/src/models/csp_darknet/__init__.py +0 -13
  91. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +0 -13
  92. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -13
  93. keras_hub/src/models/deberta_v3/__init__.py +0 -14
  94. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +0 -15
  95. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +0 -15
  96. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +0 -14
  97. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -13
  98. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +0 -15
  99. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py +0 -14
  100. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +0 -15
  101. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +0 -14
  102. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +0 -14
  103. keras_hub/src/models/deberta_v3/relative_embedding.py +0 -14
  104. keras_hub/src/models/densenet/__init__.py +5 -13
  105. keras_hub/src/models/densenet/densenet_backbone.py +11 -21
  106. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -17
  107. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  108. keras_hub/src/models/{stable_diffusion_v3/__init__.py → densenet/densenet_image_converter.py} +10 -0
  109. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  110. keras_hub/src/models/distil_bert/__init__.py +0 -14
  111. keras_hub/src/models/distil_bert/distil_bert_backbone.py +0 -15
  112. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +0 -15
  113. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +0 -14
  114. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -13
  115. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +0 -15
  116. keras_hub/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py +0 -15
  117. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +0 -15
  118. keras_hub/src/models/efficientnet/__init__.py +0 -13
  119. keras_hub/src/models/efficientnet/efficientnet_backbone.py +0 -13
  120. keras_hub/src/models/efficientnet/fusedmbconv.py +0 -14
  121. keras_hub/src/models/efficientnet/mbconv.py +0 -14
  122. keras_hub/src/models/electra/__init__.py +0 -14
  123. keras_hub/src/models/electra/electra_backbone.py +0 -14
  124. keras_hub/src/models/electra/electra_presets.py +0 -13
  125. keras_hub/src/models/electra/electra_tokenizer.py +0 -14
  126. keras_hub/src/models/f_net/__init__.py +0 -14
  127. keras_hub/src/models/f_net/f_net_backbone.py +0 -15
  128. keras_hub/src/models/f_net/f_net_masked_lm.py +0 -15
  129. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +0 -14
  130. keras_hub/src/models/f_net/f_net_presets.py +0 -13
  131. keras_hub/src/models/f_net/f_net_text_classifier.py +0 -15
  132. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +0 -15
  133. keras_hub/src/models/f_net/f_net_tokenizer.py +0 -15
  134. keras_hub/src/models/falcon/__init__.py +0 -14
  135. keras_hub/src/models/falcon/falcon_attention.py +0 -13
  136. keras_hub/src/models/falcon/falcon_backbone.py +0 -13
  137. keras_hub/src/models/falcon/falcon_causal_lm.py +0 -14
  138. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +0 -14
  139. keras_hub/src/models/falcon/falcon_presets.py +0 -13
  140. keras_hub/src/models/falcon/falcon_tokenizer.py +0 -15
  141. keras_hub/src/models/falcon/falcon_transformer_decoder.py +0 -13
  142. keras_hub/src/models/feature_pyramid_backbone.py +0 -13
  143. keras_hub/src/models/gemma/__init__.py +0 -14
  144. keras_hub/src/models/gemma/gemma_attention.py +0 -13
  145. keras_hub/src/models/gemma/gemma_backbone.py +0 -15
  146. keras_hub/src/models/gemma/gemma_causal_lm.py +0 -15
  147. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +0 -14
  148. keras_hub/src/models/gemma/gemma_decoder_block.py +0 -13
  149. keras_hub/src/models/gemma/gemma_presets.py +0 -13
  150. keras_hub/src/models/gemma/gemma_tokenizer.py +0 -14
  151. keras_hub/src/models/gemma/rms_normalization.py +0 -14
  152. keras_hub/src/models/gpt2/__init__.py +0 -14
  153. keras_hub/src/models/gpt2/gpt2_backbone.py +0 -15
  154. keras_hub/src/models/gpt2/gpt2_causal_lm.py +0 -15
  155. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +0 -14
  156. keras_hub/src/models/gpt2/gpt2_preprocessor.py +0 -15
  157. keras_hub/src/models/gpt2/gpt2_presets.py +0 -13
  158. keras_hub/src/models/gpt2/gpt2_tokenizer.py +0 -15
  159. keras_hub/src/models/gpt_neo_x/__init__.py +0 -13
  160. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +0 -14
  161. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +0 -14
  162. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +0 -14
  163. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +0 -14
  164. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +0 -14
  165. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +0 -14
  166. keras_hub/src/models/image_classifier.py +0 -13
  167. keras_hub/src/models/image_classifier_preprocessor.py +0 -13
  168. keras_hub/src/models/image_segmenter.py +0 -13
  169. keras_hub/src/models/llama/__init__.py +0 -14
  170. keras_hub/src/models/llama/llama_attention.py +0 -13
  171. keras_hub/src/models/llama/llama_backbone.py +0 -13
  172. keras_hub/src/models/llama/llama_causal_lm.py +0 -13
  173. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +0 -15
  174. keras_hub/src/models/llama/llama_decoder.py +0 -13
  175. keras_hub/src/models/llama/llama_layernorm.py +0 -13
  176. keras_hub/src/models/llama/llama_presets.py +0 -13
  177. keras_hub/src/models/llama/llama_tokenizer.py +0 -14
  178. keras_hub/src/models/llama3/__init__.py +0 -14
  179. keras_hub/src/models/llama3/llama3_backbone.py +0 -14
  180. keras_hub/src/models/llama3/llama3_causal_lm.py +0 -13
  181. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +0 -14
  182. keras_hub/src/models/llama3/llama3_presets.py +0 -13
  183. keras_hub/src/models/llama3/llama3_tokenizer.py +0 -14
  184. keras_hub/src/models/masked_lm.py +0 -13
  185. keras_hub/src/models/masked_lm_preprocessor.py +0 -13
  186. keras_hub/src/models/mistral/__init__.py +0 -14
  187. keras_hub/src/models/mistral/mistral_attention.py +0 -13
  188. keras_hub/src/models/mistral/mistral_backbone.py +0 -14
  189. keras_hub/src/models/mistral/mistral_causal_lm.py +0 -14
  190. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +0 -14
  191. keras_hub/src/models/mistral/mistral_layer_norm.py +0 -13
  192. keras_hub/src/models/mistral/mistral_presets.py +0 -13
  193. keras_hub/src/models/mistral/mistral_tokenizer.py +0 -14
  194. keras_hub/src/models/mistral/mistral_transformer_decoder.py +0 -13
  195. keras_hub/src/models/mix_transformer/__init__.py +0 -13
  196. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +0 -13
  197. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -13
  198. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +0 -13
  199. keras_hub/src/models/mobilenet/__init__.py +0 -13
  200. keras_hub/src/models/mobilenet/mobilenet_backbone.py +0 -13
  201. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -13
  202. keras_hub/src/models/opt/__init__.py +0 -14
  203. keras_hub/src/models/opt/opt_backbone.py +0 -15
  204. keras_hub/src/models/opt/opt_causal_lm.py +0 -15
  205. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +0 -13
  206. keras_hub/src/models/opt/opt_presets.py +0 -13
  207. keras_hub/src/models/opt/opt_tokenizer.py +0 -15
  208. keras_hub/src/models/pali_gemma/__init__.py +0 -13
  209. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +0 -13
  210. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +0 -13
  211. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +0 -13
  212. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +0 -14
  213. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +0 -13
  214. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +0 -13
  215. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +0 -13
  216. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +0 -13
  217. keras_hub/src/models/phi3/__init__.py +0 -14
  218. keras_hub/src/models/phi3/phi3_attention.py +0 -13
  219. keras_hub/src/models/phi3/phi3_backbone.py +0 -13
  220. keras_hub/src/models/phi3/phi3_causal_lm.py +0 -13
  221. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +0 -14
  222. keras_hub/src/models/phi3/phi3_decoder.py +0 -13
  223. keras_hub/src/models/phi3/phi3_layernorm.py +0 -13
  224. keras_hub/src/models/phi3/phi3_presets.py +0 -13
  225. keras_hub/src/models/phi3/phi3_rotary_embedding.py +0 -13
  226. keras_hub/src/models/phi3/phi3_tokenizer.py +0 -13
  227. keras_hub/src/models/preprocessor.py +51 -32
  228. keras_hub/src/models/resnet/__init__.py +0 -14
  229. keras_hub/src/models/resnet/resnet_backbone.py +0 -13
  230. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -13
  231. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +0 -14
  232. keras_hub/src/models/resnet/resnet_image_converter.py +0 -13
  233. keras_hub/src/models/resnet/resnet_presets.py +0 -13
  234. keras_hub/src/models/retinanet/__init__.py +0 -13
  235. keras_hub/src/models/retinanet/anchor_generator.py +0 -14
  236. keras_hub/src/models/retinanet/box_matcher.py +0 -14
  237. keras_hub/src/models/retinanet/non_max_supression.py +0 -14
  238. keras_hub/src/models/roberta/__init__.py +0 -14
  239. keras_hub/src/models/roberta/roberta_backbone.py +0 -15
  240. keras_hub/src/models/roberta/roberta_masked_lm.py +0 -15
  241. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +0 -14
  242. keras_hub/src/models/roberta/roberta_presets.py +0 -13
  243. keras_hub/src/models/roberta/roberta_text_classifier.py +0 -15
  244. keras_hub/src/models/roberta/roberta_text_classifier_preprocessor.py +0 -14
  245. keras_hub/src/models/roberta/roberta_tokenizer.py +0 -15
  246. keras_hub/src/models/sam/__init__.py +0 -13
  247. keras_hub/src/models/sam/sam_backbone.py +0 -14
  248. keras_hub/src/models/sam/sam_image_segmenter.py +0 -14
  249. keras_hub/src/models/sam/sam_layers.py +0 -14
  250. keras_hub/src/models/sam/sam_mask_decoder.py +0 -14
  251. keras_hub/src/models/sam/sam_prompt_encoder.py +0 -14
  252. keras_hub/src/models/sam/sam_transformer.py +0 -14
  253. keras_hub/src/models/seq_2_seq_lm.py +0 -13
  254. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +0 -13
  255. keras_hub/src/models/stable_diffusion_3/__init__.py +9 -0
  256. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +80 -0
  257. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -39
  258. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +631 -0
  259. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +31 -0
  260. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +138 -0
  261. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +83 -0
  262. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -20
  263. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +320 -0
  264. keras_hub/src/models/t5/__init__.py +0 -14
  265. keras_hub/src/models/t5/t5_backbone.py +0 -14
  266. keras_hub/src/models/t5/t5_layer_norm.py +0 -14
  267. keras_hub/src/models/t5/t5_multi_head_attention.py +0 -14
  268. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -16
  269. keras_hub/src/models/t5/t5_presets.py +0 -13
  270. keras_hub/src/models/t5/t5_tokenizer.py +0 -14
  271. keras_hub/src/models/t5/t5_transformer_layer.py +0 -14
  272. keras_hub/src/models/task.py +0 -14
  273. keras_hub/src/models/text_classifier.py +0 -13
  274. keras_hub/src/models/text_classifier_preprocessor.py +0 -13
  275. keras_hub/src/models/text_to_image.py +282 -0
  276. keras_hub/src/models/vgg/__init__.py +0 -13
  277. keras_hub/src/models/vgg/vgg_backbone.py +0 -13
  278. keras_hub/src/models/vgg/vgg_image_classifier.py +0 -13
  279. keras_hub/src/models/vit_det/__init__.py +0 -13
  280. keras_hub/src/models/vit_det/vit_det_backbone.py +0 -14
  281. keras_hub/src/models/vit_det/vit_layers.py +0 -15
  282. keras_hub/src/models/whisper/__init__.py +0 -14
  283. keras_hub/src/models/whisper/whisper_audio_converter.py +0 -15
  284. keras_hub/src/models/whisper/whisper_backbone.py +0 -15
  285. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +0 -13
  286. keras_hub/src/models/whisper/whisper_decoder.py +0 -14
  287. keras_hub/src/models/whisper/whisper_encoder.py +0 -14
  288. keras_hub/src/models/whisper/whisper_presets.py +0 -14
  289. keras_hub/src/models/whisper/whisper_tokenizer.py +0 -14
  290. keras_hub/src/models/xlm_roberta/__init__.py +0 -14
  291. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +0 -15
  292. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +0 -15
  293. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +0 -14
  294. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -13
  295. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +0 -15
  296. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py +0 -15
  297. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +0 -15
  298. keras_hub/src/models/xlnet/__init__.py +0 -13
  299. keras_hub/src/models/xlnet/relative_attention.py +0 -14
  300. keras_hub/src/models/xlnet/xlnet_backbone.py +0 -14
  301. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +0 -14
  302. keras_hub/src/models/xlnet/xlnet_encoder.py +0 -14
  303. keras_hub/src/samplers/__init__.py +0 -13
  304. keras_hub/src/samplers/beam_sampler.py +0 -14
  305. keras_hub/src/samplers/contrastive_sampler.py +0 -14
  306. keras_hub/src/samplers/greedy_sampler.py +0 -14
  307. keras_hub/src/samplers/random_sampler.py +0 -14
  308. keras_hub/src/samplers/sampler.py +0 -14
  309. keras_hub/src/samplers/serialization.py +0 -14
  310. keras_hub/src/samplers/top_k_sampler.py +0 -14
  311. keras_hub/src/samplers/top_p_sampler.py +0 -14
  312. keras_hub/src/tests/__init__.py +0 -13
  313. keras_hub/src/tests/test_case.py +0 -14
  314. keras_hub/src/tokenizers/__init__.py +0 -13
  315. keras_hub/src/tokenizers/byte_pair_tokenizer.py +0 -14
  316. keras_hub/src/tokenizers/byte_tokenizer.py +0 -14
  317. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +0 -14
  318. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +0 -14
  319. keras_hub/src/tokenizers/tokenizer.py +23 -27
  320. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +0 -15
  321. keras_hub/src/tokenizers/word_piece_tokenizer.py +0 -14
  322. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +0 -15
  323. keras_hub/src/utils/__init__.py +0 -13
  324. keras_hub/src/utils/imagenet/__init__.py +0 -13
  325. keras_hub/src/utils/imagenet/imagenet_utils.py +0 -13
  326. keras_hub/src/utils/keras_utils.py +0 -14
  327. keras_hub/src/utils/pipeline_model.py +0 -14
  328. keras_hub/src/utils/preset_utils.py +32 -76
  329. keras_hub/src/utils/python_utils.py +0 -13
  330. keras_hub/src/utils/tensor_utils.py +0 -14
  331. keras_hub/src/utils/timm/__init__.py +0 -13
  332. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  333. keras_hub/src/utils/timm/convert_resnet.py +0 -13
  334. keras_hub/src/utils/timm/preset_loader.py +3 -13
  335. keras_hub/src/utils/transformers/__init__.py +0 -13
  336. keras_hub/src/utils/transformers/convert_albert.py +0 -13
  337. keras_hub/src/utils/transformers/convert_bart.py +0 -13
  338. keras_hub/src/utils/transformers/convert_bert.py +0 -13
  339. keras_hub/src/utils/transformers/convert_distilbert.py +0 -13
  340. keras_hub/src/utils/transformers/convert_gemma.py +0 -13
  341. keras_hub/src/utils/transformers/convert_gpt2.py +0 -13
  342. keras_hub/src/utils/transformers/convert_llama3.py +0 -13
  343. keras_hub/src/utils/transformers/convert_mistral.py +0 -13
  344. keras_hub/src/utils/transformers/convert_pali_gemma.py +0 -13
  345. keras_hub/src/utils/transformers/preset_loader.py +1 -15
  346. keras_hub/src/utils/transformers/safetensor_utils.py +9 -15
  347. keras_hub/src/version_utils.py +1 -15
  348. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/METADATA +30 -27
  349. keras_hub_nightly-0.16.1.dev202409270338.dist-info/RECORD +351 -0
  350. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  351. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +0 -149
  352. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  353. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  354. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  355. keras_hub_nightly-0.16.1.dev202409250340.dist-info/RECORD +0 -342
  356. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/WHEEL +0 -0
  357. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,3 @@
1
- # Copyright 2024 The KerasHub Authors
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
1
  import math
15
2
 
16
3
  import keras
@@ -19,7 +6,8 @@ from keras import models
19
6
  from keras import ops
20
7
 
21
8
  from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
22
- from keras_hub.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock
9
+ from keras_hub.src.models.backbone import Backbone
10
+ from keras_hub.src.utils.keras_utils import gelu_approximate
23
11
  from keras_hub.src.utils.keras_utils import standardize_data_format
24
12
 
25
13
 
@@ -79,8 +67,8 @@ class AdjustablePositionEmbedding(PositionEmbedding):
79
67
  width = width or self.width
80
68
  shape = ops.shape(inputs)
81
69
  feature_length = shape[-1]
82
- top = ops.floor_divide(self.height - height, 2)
83
- left = ops.floor_divide(self.width - width, 2)
70
+ top = ops.cast(ops.floor_divide(self.height - height, 2), "int32")
71
+ left = ops.cast(ops.floor_divide(self.width - width, 2), "int32")
84
72
  position_embedding = ops.convert_to_tensor(self.position_embeddings)
85
73
  position_embedding = ops.reshape(
86
74
  position_embedding, (self.height, self.width, feature_length)
@@ -166,6 +154,305 @@ class TimestepEmbedding(layers.Layer):
166
154
  return output_shape
167
155
 
168
156
 
157
+ class DismantledBlock(layers.Layer):
158
+ def __init__(
159
+ self,
160
+ num_heads,
161
+ hidden_dim,
162
+ mlp_ratio=4.0,
163
+ use_projection=True,
164
+ **kwargs,
165
+ ):
166
+ super().__init__(**kwargs)
167
+ self.num_heads = num_heads
168
+ self.hidden_dim = hidden_dim
169
+ self.mlp_ratio = mlp_ratio
170
+ self.use_projection = use_projection
171
+
172
+ head_dim = hidden_dim // num_heads
173
+ self.head_dim = head_dim
174
+ mlp_hidden_dim = int(hidden_dim * mlp_ratio)
175
+ self.mlp_hidden_dim = mlp_hidden_dim
176
+ num_modulations = 6 if use_projection else 2
177
+ self.num_modulations = num_modulations
178
+
179
+ self.adaptive_norm_modulation = models.Sequential(
180
+ [
181
+ layers.Activation("silu", dtype=self.dtype_policy),
182
+ layers.Dense(
183
+ num_modulations * hidden_dim, dtype=self.dtype_policy
184
+ ),
185
+ ],
186
+ name="adaptive_norm_modulation",
187
+ )
188
+ self.norm1 = layers.LayerNormalization(
189
+ epsilon=1e-6,
190
+ center=False,
191
+ scale=False,
192
+ dtype="float32",
193
+ name="norm1",
194
+ )
195
+ self.attention_qkv = layers.Dense(
196
+ hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv"
197
+ )
198
+ if use_projection:
199
+ self.attention_proj = layers.Dense(
200
+ hidden_dim, dtype=self.dtype_policy, name="attention_proj"
201
+ )
202
+ self.norm2 = layers.LayerNormalization(
203
+ epsilon=1e-6,
204
+ center=False,
205
+ scale=False,
206
+ dtype="float32",
207
+ name="norm2",
208
+ )
209
+ self.mlp = models.Sequential(
210
+ [
211
+ layers.Dense(
212
+ mlp_hidden_dim,
213
+ activation=gelu_approximate,
214
+ dtype=self.dtype_policy,
215
+ ),
216
+ layers.Dense(
217
+ hidden_dim,
218
+ dtype=self.dtype_policy,
219
+ ),
220
+ ],
221
+ name="mlp",
222
+ )
223
+
224
+ def build(self, inputs_shape, timestep_embedding):
225
+ self.adaptive_norm_modulation.build(timestep_embedding)
226
+ self.attention_qkv.build(inputs_shape)
227
+ self.norm1.build(inputs_shape)
228
+ if self.use_projection:
229
+ self.attention_proj.build(inputs_shape)
230
+ self.norm2.build(inputs_shape)
231
+ self.mlp.build(inputs_shape)
232
+
233
+ def _modulate(self, inputs, shift, scale):
234
+ shift = ops.expand_dims(shift, axis=1)
235
+ scale = ops.expand_dims(scale, axis=1)
236
+ return ops.add(ops.multiply(inputs, ops.add(scale, 1.0)), shift)
237
+
238
+ def _compute_pre_attention(self, inputs, timestep_embedding, training=None):
239
+ batch_size = ops.shape(inputs)[0]
240
+ if self.use_projection:
241
+ modulation = self.adaptive_norm_modulation(
242
+ timestep_embedding, training=training
243
+ )
244
+ modulation = ops.reshape(
245
+ modulation, (batch_size, 6, self.hidden_dim)
246
+ )
247
+ (
248
+ shift_msa,
249
+ scale_msa,
250
+ gate_msa,
251
+ shift_mlp,
252
+ scale_mlp,
253
+ gate_mlp,
254
+ ) = ops.unstack(modulation, 6, axis=1)
255
+ qkv = self.attention_qkv(
256
+ self._modulate(self.norm1(inputs), shift_msa, scale_msa),
257
+ training=training,
258
+ )
259
+ qkv = ops.reshape(
260
+ qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
261
+ )
262
+ q, k, v = ops.unstack(qkv, 3, axis=2)
263
+ return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp)
264
+ else:
265
+ modulation = self.adaptive_norm_modulation(
266
+ timestep_embedding, training=training
267
+ )
268
+ modulation = ops.reshape(
269
+ modulation, (batch_size, 2, self.hidden_dim)
270
+ )
271
+ shift_msa, scale_msa = ops.unstack(modulation, 2, axis=1)
272
+ qkv = self.attention_qkv(
273
+ self._modulate(self.norm1(inputs), shift_msa, scale_msa),
274
+ training=training,
275
+ )
276
+ qkv = ops.reshape(
277
+ qkv, (batch_size, -1, 3, self.num_heads, self.head_dim)
278
+ )
279
+ q, k, v = ops.unstack(qkv, 3, axis=2)
280
+ return (q, k, v)
281
+
282
+ def _compute_post_attention(
283
+ self, inputs, inputs_intermediates, training=None
284
+ ):
285
+ x, gate_msa, shift_mlp, scale_mlp, gate_mlp = inputs_intermediates
286
+ attn = self.attention_proj(inputs, training=training)
287
+ x = ops.add(x, ops.multiply(ops.expand_dims(gate_msa, axis=1), attn))
288
+ x = ops.add(
289
+ x,
290
+ ops.multiply(
291
+ ops.expand_dims(gate_mlp, axis=1),
292
+ self.mlp(
293
+ self._modulate(self.norm2(x), shift_mlp, scale_mlp),
294
+ training=training,
295
+ ),
296
+ ),
297
+ )
298
+ return x
299
+
300
+ def call(
301
+ self,
302
+ inputs,
303
+ timestep_embedding=None,
304
+ inputs_intermediates=None,
305
+ pre_attention=True,
306
+ training=None,
307
+ ):
308
+ if pre_attention:
309
+ return self._compute_pre_attention(
310
+ inputs, timestep_embedding, training=training
311
+ )
312
+ else:
313
+ return self._compute_post_attention(
314
+ inputs, inputs_intermediates, training=training
315
+ )
316
+
317
+ def get_config(self):
318
+ config = super().get_config()
319
+ config.update(
320
+ {
321
+ "num_heads": self.num_heads,
322
+ "hidden_dim": self.hidden_dim,
323
+ "mlp_ratio": self.mlp_ratio,
324
+ "use_projection": self.use_projection,
325
+ }
326
+ )
327
+ return config
328
+
329
+
330
+ class MMDiTBlock(layers.Layer):
331
+ def __init__(
332
+ self,
333
+ num_heads,
334
+ hidden_dim,
335
+ mlp_ratio=4.0,
336
+ use_context_projection=True,
337
+ **kwargs,
338
+ ):
339
+ super().__init__(**kwargs)
340
+ self.num_heads = num_heads
341
+ self.hidden_dim = hidden_dim
342
+ self.mlp_ratio = mlp_ratio
343
+ self.use_context_projection = use_context_projection
344
+
345
+ head_dim = hidden_dim // num_heads
346
+ self.head_dim = head_dim
347
+ self._inverse_sqrt_key_dim = 1.0 / math.sqrt(head_dim)
348
+ self._dot_product_equation = "aecd,abcd->acbe"
349
+ self._combine_equation = "acbe,aecd->abcd"
350
+
351
+ self.x_block = DismantledBlock(
352
+ num_heads=num_heads,
353
+ hidden_dim=hidden_dim,
354
+ mlp_ratio=mlp_ratio,
355
+ use_projection=True,
356
+ dtype=self.dtype_policy,
357
+ name="x_block",
358
+ )
359
+ self.context_block = DismantledBlock(
360
+ num_heads=num_heads,
361
+ hidden_dim=hidden_dim,
362
+ mlp_ratio=mlp_ratio,
363
+ use_projection=use_context_projection,
364
+ dtype=self.dtype_policy,
365
+ name="context_block",
366
+ )
367
+ self.softmax = layers.Softmax(dtype="float32")
368
+
369
+ def build(self, inputs_shape, context_shape, timestep_embedding_shape):
370
+ self.x_block.build(inputs_shape, timestep_embedding_shape)
371
+ self.context_block.build(context_shape, timestep_embedding_shape)
372
+
373
+ def _compute_attention(self, query, key, value):
374
+ query = ops.multiply(
375
+ query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
376
+ )
377
+ attention_scores = ops.einsum(self._dot_product_equation, key, query)
378
+ attention_scores = self.softmax(attention_scores)
379
+ attention_scores = ops.cast(attention_scores, self.compute_dtype)
380
+ attention_output = ops.einsum(
381
+ self._combine_equation, attention_scores, value
382
+ )
383
+ batch_size = ops.shape(attention_output)[0]
384
+ attention_output = ops.reshape(
385
+ attention_output, (batch_size, -1, self.num_heads * self.head_dim)
386
+ )
387
+ return attention_output
388
+
389
+ def call(self, inputs, context, timestep_embedding, training=None):
390
+ # Compute pre-attention.
391
+ x = inputs
392
+ if self.use_context_projection:
393
+ context_qkv, context_intermediates = self.context_block(
394
+ context,
395
+ timestep_embedding=timestep_embedding,
396
+ training=training,
397
+ )
398
+ else:
399
+ context_qkv = self.context_block(
400
+ context,
401
+ timestep_embedding=timestep_embedding,
402
+ training=training,
403
+ )
404
+ context_len = ops.shape(context_qkv[0])[1]
405
+ x_qkv, x_intermediates = self.x_block(
406
+ x, timestep_embedding=timestep_embedding, training=training
407
+ )
408
+ q = ops.concatenate([context_qkv[0], x_qkv[0]], axis=1)
409
+ k = ops.concatenate([context_qkv[1], x_qkv[1]], axis=1)
410
+ v = ops.concatenate([context_qkv[2], x_qkv[2]], axis=1)
411
+
412
+ # Compute attention.
413
+ attention = self._compute_attention(q, k, v)
414
+ context_attention = attention[:, :context_len]
415
+ x_attention = attention[:, context_len:]
416
+
417
+ # Compute post-attention.
418
+ x = self.x_block(
419
+ x_attention,
420
+ inputs_intermediates=x_intermediates,
421
+ pre_attention=False,
422
+ training=training,
423
+ )
424
+ if self.use_context_projection:
425
+ context = self.context_block(
426
+ context_attention,
427
+ inputs_intermediates=context_intermediates,
428
+ pre_attention=False,
429
+ training=training,
430
+ )
431
+ return x, context
432
+ else:
433
+ return x
434
+
435
+ def get_config(self):
436
+ config = super().get_config()
437
+ config.update(
438
+ {
439
+ "num_heads": self.num_heads,
440
+ "hidden_dim": self.hidden_dim,
441
+ "mlp_ratio": self.mlp_ratio,
442
+ "use_context_projection": self.use_context_projection,
443
+ }
444
+ )
445
+ return config
446
+
447
+ def compute_output_shape(
448
+ self, inputs_shape, context_shape, timestep_embedding_shape
449
+ ):
450
+ if self.use_context_projection:
451
+ return inputs_shape, context_shape
452
+ else:
453
+ return inputs_shape
454
+
455
+
169
456
  class OutputLayer(layers.Layer):
170
457
  def __init__(self, hidden_dim, output_dim, **kwargs):
171
458
  super().__init__(**kwargs)
@@ -186,11 +473,11 @@ class OutputLayer(layers.Layer):
186
473
  epsilon=1e-6,
187
474
  center=False,
188
475
  scale=False,
189
- dtype=self.dtype_policy,
476
+ dtype="float32",
190
477
  name="norm",
191
478
  )
192
479
  self.output_dense = layers.Dense(
193
- output_dim, # patch_size ** 2 * input_channels
480
+ output_dim,
194
481
  use_bias=True,
195
482
  dtype=self.dtype_policy,
196
483
  name="output_dense",
@@ -227,6 +514,11 @@ class OutputLayer(layers.Layer):
227
514
  )
228
515
  return config
229
516
 
517
def compute_output_shape(self, inputs_shape):
    """Return `inputs_shape` with its last dimension set to `output_dim`.

    Args:
        inputs_shape: sequence describing the input shape. It is copied,
            so the caller's object is never modified.

    Returns:
        A new list equal to `inputs_shape` except that the final entry is
        `self.output_dim`.
    """
    dims = list(inputs_shape)
    # Swap the trailing (feature) dimension for the configured output size.
    dims.pop()
    dims.append(self.output_dim)
    return dims
521
+
230
522
 
231
523
  class Unpatch(layers.Layer):
232
524
  def __init__(self, patch_size, output_dim, **kwargs):
@@ -263,18 +555,48 @@ class Unpatch(layers.Layer):
263
555
  return [inputs_shape[0], None, None, self.output_dim]
264
556
 
265
557
 
266
- class MMDiT(keras.Model):
558
+ class MMDiT(Backbone):
559
+ """Multimodal Diffusion Transformer (MMDiT) model for Stable Diffusion 3.
560
+
561
+ MMDiT is introduced in [
562
+ Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](
563
+ https://arxiv.org/abs/2403.03206).
564
+
565
+ Args:
566
+ patch_size: int. The size of each square patch in the input image.
567
+ hidden_dim: int. The size of the transformer hidden state at the end
568
+ of each transformer layer.
569
+ num_layers: int. The number of transformer layers.
570
+ num_heads: int. The number of attention heads for each transformer.
571
+ position_size: int. The size of the height and width for the position
572
+ embedding.
573
+ mlp_ratio: float. The ratio of the mlp hidden dim to the transformer hidden dim.
574
+ latent_shape: tuple. The shape of the latent image.
575
+ context_shape: tuple. The shape of the context.
576
+ pooled_projection_shape: tuple. The shape of the pooled projection.
577
+ data_format: `None` or str. If specified, either `"channels_last"` or
578
+ `"channels_first"`. The ordering of the dimensions in the
579
+ inputs. `"channels_last"` corresponds to inputs with shape
580
+ `(batch_size, height, width, channels)`
581
+ while `"channels_first"` corresponds to inputs with shape
582
+ `(batch_size, channels, height, width)`. It defaults to the
583
+ `image_data_format` value found in your Keras config file at
584
+ `~/.keras/keras.json`. If you never set it, then it will be
585
+ `"channels_last"`.
586
+ dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
587
+ to use for the model's computations and weights.
588
+ """
589
+
267
590
  def __init__(
268
591
  self,
269
592
  patch_size,
270
- num_heads,
271
593
  hidden_dim,
272
- depth,
594
+ num_layers,
595
+ num_heads,
273
596
  position_size,
274
- output_dim,
275
597
  mlp_ratio=4.0,
276
598
  latent_shape=(64, 64, 16),
277
- context_shape=(1024, 4096),
599
+ context_shape=(None, 4096),
278
600
  pooled_projection_shape=(2048,),
279
601
  data_format=None,
280
602
  dtype=None,
@@ -287,6 +609,7 @@ class MMDiT(keras.Model):
287
609
  )
288
610
  image_height = latent_shape[0] // patch_size
289
611
  image_width = latent_shape[1] // patch_size
612
+ output_dim = latent_shape[-1]
290
613
  output_dim_in_final = patch_size**2 * output_dim
291
614
  data_format = standardize_data_format(data_format)
292
615
  if data_format != "channels_last":
@@ -331,11 +654,11 @@ class MMDiT(keras.Model):
331
654
  num_heads,
332
655
  hidden_dim,
333
656
  mlp_ratio,
334
- use_context_projection=not (i == depth - 1),
657
+ use_context_projection=not (i == num_layers - 1),
335
658
  dtype=dtype,
336
659
  name=f"joint_block_{i}",
337
660
  )
338
- for i in range(depth)
661
+ for i in range(num_layers)
339
662
  ]
340
663
  self.output_layer = OutputLayer(
341
664
  hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
@@ -391,33 +714,22 @@ class MMDiT(keras.Model):
391
714
  self.patch_size = patch_size
392
715
  self.num_heads = num_heads
393
716
  self.hidden_dim = hidden_dim
394
- self.depth = depth
717
+ self.num_layers = num_layers
395
718
  self.position_size = position_size
396
- self.output_dim = output_dim
397
719
  self.mlp_ratio = mlp_ratio
398
720
  self.latent_shape = latent_shape
399
721
  self.context_shape = context_shape
400
722
  self.pooled_projection_shape = pooled_projection_shape
401
723
 
402
- if dtype is not None:
403
- try:
404
- self.dtype_policy = keras.dtype_policies.get(dtype)
405
- # Before Keras 3.2, there is no `keras.dtype_policies.get`.
406
- except AttributeError:
407
- if isinstance(dtype, keras.DTypePolicy):
408
- dtype = dtype.name
409
- self.dtype_policy = keras.DTypePolicy(dtype)
410
-
411
724
  def get_config(self):
412
725
  config = super().get_config()
413
726
  config.update(
414
727
  {
415
728
  "patch_size": self.patch_size,
416
- "num_heads": self.num_heads,
417
729
  "hidden_dim": self.hidden_dim,
418
- "depth": self.depth,
730
+ "num_layers": self.num_layers,
731
+ "num_heads": self.num_heads,
419
732
  "position_size": self.position_size,
420
- "output_dim": self.output_dim,
421
733
  "mlp_ratio": self.mlp_ratio,
422
734
  "latent_shape": self.latent_shape,
423
735
  "context_shape": self.context_shape,