keras-hub-nightly 0.16.1.dev202409250340__py3-none-any.whl → 0.16.1.dev202409270338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (357) hide show
  1. keras_hub/__init__.py +0 -13
  2. keras_hub/api/__init__.py +0 -13
  3. keras_hub/api/bounding_box/__init__.py +0 -13
  4. keras_hub/api/layers/__init__.py +3 -13
  5. keras_hub/api/metrics/__init__.py +0 -13
  6. keras_hub/api/models/__init__.py +16 -13
  7. keras_hub/api/samplers/__init__.py +0 -13
  8. keras_hub/api/tokenizers/__init__.py +1 -13
  9. keras_hub/api/utils/__init__.py +0 -13
  10. keras_hub/src/__init__.py +0 -13
  11. keras_hub/src/api_export.py +0 -14
  12. keras_hub/src/bounding_box/__init__.py +0 -13
  13. keras_hub/src/bounding_box/converters.py +0 -13
  14. keras_hub/src/bounding_box/formats.py +0 -13
  15. keras_hub/src/bounding_box/iou.py +1 -13
  16. keras_hub/src/bounding_box/to_dense.py +0 -14
  17. keras_hub/src/bounding_box/to_ragged.py +0 -13
  18. keras_hub/src/bounding_box/utils.py +0 -13
  19. keras_hub/src/bounding_box/validate_format.py +0 -14
  20. keras_hub/src/layers/__init__.py +0 -13
  21. keras_hub/src/layers/modeling/__init__.py +0 -13
  22. keras_hub/src/layers/modeling/alibi_bias.py +0 -13
  23. keras_hub/src/layers/modeling/cached_multi_head_attention.py +0 -14
  24. keras_hub/src/layers/modeling/f_net_encoder.py +0 -14
  25. keras_hub/src/layers/modeling/masked_lm_head.py +0 -14
  26. keras_hub/src/layers/modeling/position_embedding.py +0 -14
  27. keras_hub/src/layers/modeling/reversible_embedding.py +0 -14
  28. keras_hub/src/layers/modeling/rotary_embedding.py +0 -14
  29. keras_hub/src/layers/modeling/sine_position_encoding.py +0 -14
  30. keras_hub/src/layers/modeling/token_and_position_embedding.py +0 -14
  31. keras_hub/src/layers/modeling/transformer_decoder.py +0 -14
  32. keras_hub/src/layers/modeling/transformer_encoder.py +0 -14
  33. keras_hub/src/layers/modeling/transformer_layer_utils.py +0 -14
  34. keras_hub/src/layers/preprocessing/__init__.py +0 -13
  35. keras_hub/src/layers/preprocessing/audio_converter.py +0 -13
  36. keras_hub/src/layers/preprocessing/image_converter.py +0 -13
  37. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +0 -15
  38. keras_hub/src/layers/preprocessing/multi_segment_packer.py +0 -14
  39. keras_hub/src/layers/preprocessing/preprocessing_layer.py +0 -14
  40. keras_hub/src/layers/preprocessing/random_deletion.py +0 -14
  41. keras_hub/src/layers/preprocessing/random_swap.py +0 -14
  42. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -13
  43. keras_hub/src/layers/preprocessing/start_end_packer.py +0 -15
  44. keras_hub/src/metrics/__init__.py +0 -13
  45. keras_hub/src/metrics/bleu.py +0 -14
  46. keras_hub/src/metrics/edit_distance.py +0 -14
  47. keras_hub/src/metrics/perplexity.py +0 -14
  48. keras_hub/src/metrics/rouge_base.py +0 -14
  49. keras_hub/src/metrics/rouge_l.py +0 -14
  50. keras_hub/src/metrics/rouge_n.py +0 -14
  51. keras_hub/src/models/__init__.py +0 -13
  52. keras_hub/src/models/albert/__init__.py +0 -14
  53. keras_hub/src/models/albert/albert_backbone.py +0 -14
  54. keras_hub/src/models/albert/albert_masked_lm.py +0 -14
  55. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +0 -14
  56. keras_hub/src/models/albert/albert_presets.py +0 -14
  57. keras_hub/src/models/albert/albert_text_classifier.py +0 -14
  58. keras_hub/src/models/albert/albert_text_classifier_preprocessor.py +0 -14
  59. keras_hub/src/models/albert/albert_tokenizer.py +0 -14
  60. keras_hub/src/models/backbone.py +0 -14
  61. keras_hub/src/models/bart/__init__.py +0 -14
  62. keras_hub/src/models/bart/bart_backbone.py +0 -14
  63. keras_hub/src/models/bart/bart_presets.py +0 -13
  64. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +0 -15
  65. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +0 -15
  66. keras_hub/src/models/bart/bart_tokenizer.py +0 -15
  67. keras_hub/src/models/bert/__init__.py +0 -14
  68. keras_hub/src/models/bert/bert_backbone.py +0 -14
  69. keras_hub/src/models/bert/bert_masked_lm.py +0 -14
  70. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +0 -14
  71. keras_hub/src/models/bert/bert_presets.py +0 -13
  72. keras_hub/src/models/bert/bert_text_classifier.py +0 -14
  73. keras_hub/src/models/bert/bert_text_classifier_preprocessor.py +0 -14
  74. keras_hub/src/models/bert/bert_tokenizer.py +0 -14
  75. keras_hub/src/models/bloom/__init__.py +0 -14
  76. keras_hub/src/models/bloom/bloom_attention.py +0 -13
  77. keras_hub/src/models/bloom/bloom_backbone.py +0 -14
  78. keras_hub/src/models/bloom/bloom_causal_lm.py +0 -15
  79. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +0 -15
  80. keras_hub/src/models/bloom/bloom_decoder.py +0 -13
  81. keras_hub/src/models/bloom/bloom_presets.py +0 -13
  82. keras_hub/src/models/bloom/bloom_tokenizer.py +0 -15
  83. keras_hub/src/models/causal_lm.py +0 -14
  84. keras_hub/src/models/causal_lm_preprocessor.py +0 -13
  85. keras_hub/src/models/clip/__init__.py +0 -0
  86. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -15
  87. keras_hub/src/models/clip/clip_preprocessor.py +134 -0
  88. keras_hub/src/models/clip/clip_text_encoder.py +139 -0
  89. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +65 -41
  90. keras_hub/src/models/csp_darknet/__init__.py +0 -13
  91. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +0 -13
  92. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -13
  93. keras_hub/src/models/deberta_v3/__init__.py +0 -14
  94. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +0 -15
  95. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +0 -15
  96. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +0 -14
  97. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -13
  98. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +0 -15
  99. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py +0 -14
  100. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +0 -15
  101. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +0 -14
  102. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +0 -14
  103. keras_hub/src/models/deberta_v3/relative_embedding.py +0 -14
  104. keras_hub/src/models/densenet/__init__.py +5 -13
  105. keras_hub/src/models/densenet/densenet_backbone.py +11 -21
  106. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -17
  107. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  108. keras_hub/src/models/{stable_diffusion_v3/__init__.py → densenet/densenet_image_converter.py} +10 -0
  109. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  110. keras_hub/src/models/distil_bert/__init__.py +0 -14
  111. keras_hub/src/models/distil_bert/distil_bert_backbone.py +0 -15
  112. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +0 -15
  113. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +0 -14
  114. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -13
  115. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +0 -15
  116. keras_hub/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py +0 -15
  117. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +0 -15
  118. keras_hub/src/models/efficientnet/__init__.py +0 -13
  119. keras_hub/src/models/efficientnet/efficientnet_backbone.py +0 -13
  120. keras_hub/src/models/efficientnet/fusedmbconv.py +0 -14
  121. keras_hub/src/models/efficientnet/mbconv.py +0 -14
  122. keras_hub/src/models/electra/__init__.py +0 -14
  123. keras_hub/src/models/electra/electra_backbone.py +0 -14
  124. keras_hub/src/models/electra/electra_presets.py +0 -13
  125. keras_hub/src/models/electra/electra_tokenizer.py +0 -14
  126. keras_hub/src/models/f_net/__init__.py +0 -14
  127. keras_hub/src/models/f_net/f_net_backbone.py +0 -15
  128. keras_hub/src/models/f_net/f_net_masked_lm.py +0 -15
  129. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +0 -14
  130. keras_hub/src/models/f_net/f_net_presets.py +0 -13
  131. keras_hub/src/models/f_net/f_net_text_classifier.py +0 -15
  132. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +0 -15
  133. keras_hub/src/models/f_net/f_net_tokenizer.py +0 -15
  134. keras_hub/src/models/falcon/__init__.py +0 -14
  135. keras_hub/src/models/falcon/falcon_attention.py +0 -13
  136. keras_hub/src/models/falcon/falcon_backbone.py +0 -13
  137. keras_hub/src/models/falcon/falcon_causal_lm.py +0 -14
  138. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +0 -14
  139. keras_hub/src/models/falcon/falcon_presets.py +0 -13
  140. keras_hub/src/models/falcon/falcon_tokenizer.py +0 -15
  141. keras_hub/src/models/falcon/falcon_transformer_decoder.py +0 -13
  142. keras_hub/src/models/feature_pyramid_backbone.py +0 -13
  143. keras_hub/src/models/gemma/__init__.py +0 -14
  144. keras_hub/src/models/gemma/gemma_attention.py +0 -13
  145. keras_hub/src/models/gemma/gemma_backbone.py +0 -15
  146. keras_hub/src/models/gemma/gemma_causal_lm.py +0 -15
  147. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +0 -14
  148. keras_hub/src/models/gemma/gemma_decoder_block.py +0 -13
  149. keras_hub/src/models/gemma/gemma_presets.py +0 -13
  150. keras_hub/src/models/gemma/gemma_tokenizer.py +0 -14
  151. keras_hub/src/models/gemma/rms_normalization.py +0 -14
  152. keras_hub/src/models/gpt2/__init__.py +0 -14
  153. keras_hub/src/models/gpt2/gpt2_backbone.py +0 -15
  154. keras_hub/src/models/gpt2/gpt2_causal_lm.py +0 -15
  155. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +0 -14
  156. keras_hub/src/models/gpt2/gpt2_preprocessor.py +0 -15
  157. keras_hub/src/models/gpt2/gpt2_presets.py +0 -13
  158. keras_hub/src/models/gpt2/gpt2_tokenizer.py +0 -15
  159. keras_hub/src/models/gpt_neo_x/__init__.py +0 -13
  160. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +0 -14
  161. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +0 -14
  162. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +0 -14
  163. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +0 -14
  164. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +0 -14
  165. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +0 -14
  166. keras_hub/src/models/image_classifier.py +0 -13
  167. keras_hub/src/models/image_classifier_preprocessor.py +0 -13
  168. keras_hub/src/models/image_segmenter.py +0 -13
  169. keras_hub/src/models/llama/__init__.py +0 -14
  170. keras_hub/src/models/llama/llama_attention.py +0 -13
  171. keras_hub/src/models/llama/llama_backbone.py +0 -13
  172. keras_hub/src/models/llama/llama_causal_lm.py +0 -13
  173. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +0 -15
  174. keras_hub/src/models/llama/llama_decoder.py +0 -13
  175. keras_hub/src/models/llama/llama_layernorm.py +0 -13
  176. keras_hub/src/models/llama/llama_presets.py +0 -13
  177. keras_hub/src/models/llama/llama_tokenizer.py +0 -14
  178. keras_hub/src/models/llama3/__init__.py +0 -14
  179. keras_hub/src/models/llama3/llama3_backbone.py +0 -14
  180. keras_hub/src/models/llama3/llama3_causal_lm.py +0 -13
  181. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +0 -14
  182. keras_hub/src/models/llama3/llama3_presets.py +0 -13
  183. keras_hub/src/models/llama3/llama3_tokenizer.py +0 -14
  184. keras_hub/src/models/masked_lm.py +0 -13
  185. keras_hub/src/models/masked_lm_preprocessor.py +0 -13
  186. keras_hub/src/models/mistral/__init__.py +0 -14
  187. keras_hub/src/models/mistral/mistral_attention.py +0 -13
  188. keras_hub/src/models/mistral/mistral_backbone.py +0 -14
  189. keras_hub/src/models/mistral/mistral_causal_lm.py +0 -14
  190. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +0 -14
  191. keras_hub/src/models/mistral/mistral_layer_norm.py +0 -13
  192. keras_hub/src/models/mistral/mistral_presets.py +0 -13
  193. keras_hub/src/models/mistral/mistral_tokenizer.py +0 -14
  194. keras_hub/src/models/mistral/mistral_transformer_decoder.py +0 -13
  195. keras_hub/src/models/mix_transformer/__init__.py +0 -13
  196. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +0 -13
  197. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -13
  198. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +0 -13
  199. keras_hub/src/models/mobilenet/__init__.py +0 -13
  200. keras_hub/src/models/mobilenet/mobilenet_backbone.py +0 -13
  201. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -13
  202. keras_hub/src/models/opt/__init__.py +0 -14
  203. keras_hub/src/models/opt/opt_backbone.py +0 -15
  204. keras_hub/src/models/opt/opt_causal_lm.py +0 -15
  205. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +0 -13
  206. keras_hub/src/models/opt/opt_presets.py +0 -13
  207. keras_hub/src/models/opt/opt_tokenizer.py +0 -15
  208. keras_hub/src/models/pali_gemma/__init__.py +0 -13
  209. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +0 -13
  210. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +0 -13
  211. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +0 -13
  212. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +0 -14
  213. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +0 -13
  214. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +0 -13
  215. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +0 -13
  216. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +0 -13
  217. keras_hub/src/models/phi3/__init__.py +0 -14
  218. keras_hub/src/models/phi3/phi3_attention.py +0 -13
  219. keras_hub/src/models/phi3/phi3_backbone.py +0 -13
  220. keras_hub/src/models/phi3/phi3_causal_lm.py +0 -13
  221. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +0 -14
  222. keras_hub/src/models/phi3/phi3_decoder.py +0 -13
  223. keras_hub/src/models/phi3/phi3_layernorm.py +0 -13
  224. keras_hub/src/models/phi3/phi3_presets.py +0 -13
  225. keras_hub/src/models/phi3/phi3_rotary_embedding.py +0 -13
  226. keras_hub/src/models/phi3/phi3_tokenizer.py +0 -13
  227. keras_hub/src/models/preprocessor.py +51 -32
  228. keras_hub/src/models/resnet/__init__.py +0 -14
  229. keras_hub/src/models/resnet/resnet_backbone.py +0 -13
  230. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -13
  231. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +0 -14
  232. keras_hub/src/models/resnet/resnet_image_converter.py +0 -13
  233. keras_hub/src/models/resnet/resnet_presets.py +0 -13
  234. keras_hub/src/models/retinanet/__init__.py +0 -13
  235. keras_hub/src/models/retinanet/anchor_generator.py +0 -14
  236. keras_hub/src/models/retinanet/box_matcher.py +0 -14
  237. keras_hub/src/models/retinanet/non_max_supression.py +0 -14
  238. keras_hub/src/models/roberta/__init__.py +0 -14
  239. keras_hub/src/models/roberta/roberta_backbone.py +0 -15
  240. keras_hub/src/models/roberta/roberta_masked_lm.py +0 -15
  241. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +0 -14
  242. keras_hub/src/models/roberta/roberta_presets.py +0 -13
  243. keras_hub/src/models/roberta/roberta_text_classifier.py +0 -15
  244. keras_hub/src/models/roberta/roberta_text_classifier_preprocessor.py +0 -14
  245. keras_hub/src/models/roberta/roberta_tokenizer.py +0 -15
  246. keras_hub/src/models/sam/__init__.py +0 -13
  247. keras_hub/src/models/sam/sam_backbone.py +0 -14
  248. keras_hub/src/models/sam/sam_image_segmenter.py +0 -14
  249. keras_hub/src/models/sam/sam_layers.py +0 -14
  250. keras_hub/src/models/sam/sam_mask_decoder.py +0 -14
  251. keras_hub/src/models/sam/sam_prompt_encoder.py +0 -14
  252. keras_hub/src/models/sam/sam_transformer.py +0 -14
  253. keras_hub/src/models/seq_2_seq_lm.py +0 -13
  254. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +0 -13
  255. keras_hub/src/models/stable_diffusion_3/__init__.py +9 -0
  256. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +80 -0
  257. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -39
  258. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +631 -0
  259. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +31 -0
  260. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +138 -0
  261. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +83 -0
  262. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -20
  263. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +320 -0
  264. keras_hub/src/models/t5/__init__.py +0 -14
  265. keras_hub/src/models/t5/t5_backbone.py +0 -14
  266. keras_hub/src/models/t5/t5_layer_norm.py +0 -14
  267. keras_hub/src/models/t5/t5_multi_head_attention.py +0 -14
  268. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -16
  269. keras_hub/src/models/t5/t5_presets.py +0 -13
  270. keras_hub/src/models/t5/t5_tokenizer.py +0 -14
  271. keras_hub/src/models/t5/t5_transformer_layer.py +0 -14
  272. keras_hub/src/models/task.py +0 -14
  273. keras_hub/src/models/text_classifier.py +0 -13
  274. keras_hub/src/models/text_classifier_preprocessor.py +0 -13
  275. keras_hub/src/models/text_to_image.py +282 -0
  276. keras_hub/src/models/vgg/__init__.py +0 -13
  277. keras_hub/src/models/vgg/vgg_backbone.py +0 -13
  278. keras_hub/src/models/vgg/vgg_image_classifier.py +0 -13
  279. keras_hub/src/models/vit_det/__init__.py +0 -13
  280. keras_hub/src/models/vit_det/vit_det_backbone.py +0 -14
  281. keras_hub/src/models/vit_det/vit_layers.py +0 -15
  282. keras_hub/src/models/whisper/__init__.py +0 -14
  283. keras_hub/src/models/whisper/whisper_audio_converter.py +0 -15
  284. keras_hub/src/models/whisper/whisper_backbone.py +0 -15
  285. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +0 -13
  286. keras_hub/src/models/whisper/whisper_decoder.py +0 -14
  287. keras_hub/src/models/whisper/whisper_encoder.py +0 -14
  288. keras_hub/src/models/whisper/whisper_presets.py +0 -14
  289. keras_hub/src/models/whisper/whisper_tokenizer.py +0 -14
  290. keras_hub/src/models/xlm_roberta/__init__.py +0 -14
  291. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +0 -15
  292. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +0 -15
  293. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +0 -14
  294. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -13
  295. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +0 -15
  296. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py +0 -15
  297. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +0 -15
  298. keras_hub/src/models/xlnet/__init__.py +0 -13
  299. keras_hub/src/models/xlnet/relative_attention.py +0 -14
  300. keras_hub/src/models/xlnet/xlnet_backbone.py +0 -14
  301. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +0 -14
  302. keras_hub/src/models/xlnet/xlnet_encoder.py +0 -14
  303. keras_hub/src/samplers/__init__.py +0 -13
  304. keras_hub/src/samplers/beam_sampler.py +0 -14
  305. keras_hub/src/samplers/contrastive_sampler.py +0 -14
  306. keras_hub/src/samplers/greedy_sampler.py +0 -14
  307. keras_hub/src/samplers/random_sampler.py +0 -14
  308. keras_hub/src/samplers/sampler.py +0 -14
  309. keras_hub/src/samplers/serialization.py +0 -14
  310. keras_hub/src/samplers/top_k_sampler.py +0 -14
  311. keras_hub/src/samplers/top_p_sampler.py +0 -14
  312. keras_hub/src/tests/__init__.py +0 -13
  313. keras_hub/src/tests/test_case.py +0 -14
  314. keras_hub/src/tokenizers/__init__.py +0 -13
  315. keras_hub/src/tokenizers/byte_pair_tokenizer.py +0 -14
  316. keras_hub/src/tokenizers/byte_tokenizer.py +0 -14
  317. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +0 -14
  318. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +0 -14
  319. keras_hub/src/tokenizers/tokenizer.py +23 -27
  320. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +0 -15
  321. keras_hub/src/tokenizers/word_piece_tokenizer.py +0 -14
  322. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +0 -15
  323. keras_hub/src/utils/__init__.py +0 -13
  324. keras_hub/src/utils/imagenet/__init__.py +0 -13
  325. keras_hub/src/utils/imagenet/imagenet_utils.py +0 -13
  326. keras_hub/src/utils/keras_utils.py +0 -14
  327. keras_hub/src/utils/pipeline_model.py +0 -14
  328. keras_hub/src/utils/preset_utils.py +32 -76
  329. keras_hub/src/utils/python_utils.py +0 -13
  330. keras_hub/src/utils/tensor_utils.py +0 -14
  331. keras_hub/src/utils/timm/__init__.py +0 -13
  332. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  333. keras_hub/src/utils/timm/convert_resnet.py +0 -13
  334. keras_hub/src/utils/timm/preset_loader.py +3 -13
  335. keras_hub/src/utils/transformers/__init__.py +0 -13
  336. keras_hub/src/utils/transformers/convert_albert.py +0 -13
  337. keras_hub/src/utils/transformers/convert_bart.py +0 -13
  338. keras_hub/src/utils/transformers/convert_bert.py +0 -13
  339. keras_hub/src/utils/transformers/convert_distilbert.py +0 -13
  340. keras_hub/src/utils/transformers/convert_gemma.py +0 -13
  341. keras_hub/src/utils/transformers/convert_gpt2.py +0 -13
  342. keras_hub/src/utils/transformers/convert_llama3.py +0 -13
  343. keras_hub/src/utils/transformers/convert_mistral.py +0 -13
  344. keras_hub/src/utils/transformers/convert_pali_gemma.py +0 -13
  345. keras_hub/src/utils/transformers/preset_loader.py +1 -15
  346. keras_hub/src/utils/transformers/safetensor_utils.py +9 -15
  347. keras_hub/src/version_utils.py +1 -15
  348. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/METADATA +30 -27
  349. keras_hub_nightly-0.16.1.dev202409270338.dist-info/RECORD +351 -0
  350. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  351. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +0 -149
  352. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  353. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  354. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  355. keras_hub_nightly-0.16.1.dev202409250340.dist-info/RECORD +0 -342
  356. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/WHEEL +0 -0
  357. {keras_hub_nightly-0.16.1.dev202409250340.dist-info → keras_hub_nightly-0.16.1.dev202409270338.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,138 @@
1
+ from keras import ops
2
+
3
+ from keras_hub.src.api_export import keras_hub_export
4
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
5
+ StableDiffusion3Backbone,
6
+ )
7
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
8
+ StableDiffusion3TextToImagePreprocessor,
9
+ )
10
+ from keras_hub.src.models.text_to_image import TextToImage
11
+
12
+
13
+ @keras_hub_export("keras_hub.models.StableDiffusion3TextToImage")
14
+ class StableDiffusion3TextToImage(TextToImage):
15
+ """An end-to-end Stable Diffusion 3 model for text-to-image generation.
16
+
17
+ This model has a `generate()` method, which generates image based on a
18
+ prompt.
19
+
20
+ Args:
21
+ backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
22
+ preprocessor: A
23
+ `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
24
+
25
+ Examples:
26
+
27
+ Use `generate()` to do image generation.
28
+ ```python
29
+ text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
30
+ "stable_diffusion_3_medium", height=512, width=512
31
+ )
32
+ text_to_image.generate(
33
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
34
+ )
35
+
36
+ # Generate with batched prompts.
37
+ text_to_image.generate(
38
+ ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
39
+ )
40
+
41
+ # Generate with different `num_steps` and `classifier_free_guidance_scale`.
42
+ text_to_image.generate(
43
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
44
+ num_steps=50,
45
+ classifier_free_guidance_scale=5.0,
46
+ )
47
+ ```
48
+ """
49
+
50
+ backbone_cls = StableDiffusion3Backbone
51
+ preprocessor_cls = StableDiffusion3TextToImagePreprocessor
52
+
53
+ def __init__(
54
+ self,
55
+ backbone,
56
+ preprocessor,
57
+ **kwargs,
58
+ ):
59
+ # === Layers ===
60
+ self.backbone = backbone
61
+ self.preprocessor = preprocessor
62
+
63
+ # === Functional Model ===
64
+ inputs = backbone.input
65
+ outputs = backbone.output
66
+ super().__init__(
67
+ inputs=inputs,
68
+ outputs=outputs,
69
+ **kwargs,
70
+ )
71
+
72
+ def fit(self, *args, **kwargs):
73
+ raise NotImplementedError(
74
+ "Currently, `fit` is not supported for "
75
+ "`StableDiffusion3TextToImage`."
76
+ )
77
+
78
+ def generate_step(
79
+ self,
80
+ latents,
81
+ token_ids,
82
+ negative_token_ids,
83
+ num_steps,
84
+ guidance_scale,
85
+ ):
86
+ """A compilable generation function for batched of inputs.
87
+
88
+ This function represents the inner, XLA-compilable, generation function
89
+ for batched inputs.
90
+
91
+ Args:
92
+ latents: A (batch_size, height, width, channels) tensor
93
+ containing the latents to start generation from. Typically, this
94
+ tensor is sampled from the Gaussian distribution.
95
+ token_ids: A (batch_size, num_tokens) tensor containing the
96
+ tokens based on the input prompts.
97
+ negative_token_ids: A (batch_size, num_tokens) tensor
98
+ containing the negative tokens based on the input prompts.
99
+ num_steps: int. The number of diffusion steps to take.
100
+ guidance_scale: float. The classifier free guidance scale defined in
101
+ [Classifier-Free Diffusion Guidance](
102
+ https://arxiv.org/abs/2207.12598). Higher scale encourages to
103
+ generate images that are closely linked to prompts, usually at
104
+ the expense of lower image quality.
105
+ """
106
+ # Encode inputs.
107
+ embeddings = self.backbone.encode_step(token_ids, negative_token_ids)
108
+
109
+ # Denoise.
110
+ def body_fun(step, latents):
111
+ return self.backbone.denoise_step(
112
+ latents,
113
+ embeddings,
114
+ step,
115
+ num_steps,
116
+ guidance_scale,
117
+ )
118
+
119
+ latents = ops.fori_loop(0, num_steps, body_fun, latents)
120
+
121
+ # Decode.
122
+ return self.backbone.decode_step(latents)
123
+
124
+ def generate(
125
+ self,
126
+ inputs,
127
+ negative_inputs=None,
128
+ num_steps=28,
129
+ guidance_scale=7.0,
130
+ seed=None,
131
+ ):
132
+ return super().generate(
133
+ inputs,
134
+ negative_inputs=negative_inputs,
135
+ num_steps=num_steps,
136
+ guidance_scale=guidance_scale,
137
+ seed=seed,
138
+ )
@@ -0,0 +1,83 @@
1
+ import keras
2
+ from keras import layers
3
+
4
+ from keras_hub.src.api_export import keras_hub_export
5
+ from keras_hub.src.models.preprocessor import Preprocessor
6
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
7
+ StableDiffusion3Backbone,
8
+ )
9
+
10
+
11
+ @keras_hub_export("keras_hub.models.StableDiffusion3TextToImagePreprocessor")
12
+ class StableDiffusion3TextToImagePreprocessor(Preprocessor):
13
+ """Stable Diffusion 3 text-to-image model preprocessor.
14
+
15
+ This preprocessing layer is meant for use with
16
+ `keras_hub.models.StableDiffusion3TextToImage`.
17
+
18
+ For use with generation, the layer exposes one methods
19
+ `generate_preprocess()`.
20
+
21
+ Args:
22
+ clip_l_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance.
23
+ clip_g_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance.
24
+ t5_preprocessor: A optional `keras_hub.models.T5Preprocessor` instance.
25
+ """
26
+
27
+ backbone_cls = StableDiffusion3Backbone
28
+
29
+ def __init__(
30
+ self,
31
+ clip_l_preprocessor,
32
+ clip_g_preprocessor,
33
+ t5_preprocessor=None,
34
+ **kwargs,
35
+ ):
36
+ super().__init__(**kwargs)
37
+ self.clip_l_preprocessor = clip_l_preprocessor
38
+ self.clip_g_preprocessor = clip_g_preprocessor
39
+ self.t5_preprocessor = t5_preprocessor
40
+
41
+ @property
42
+ def sequence_length(self):
43
+ """The padded length of model input sequences."""
44
+ return self.clip_l_preprocessor.sequence_length
45
+
46
+ def build(self, input_shape):
47
+ self.built = True
48
+
49
+ def generate_preprocess(self, x):
50
+ token_ids = {}
51
+ token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
52
+ token_ids["clip_g"] = self.clip_g_preprocessor(x)["token_ids"]
53
+ if self.t5_preprocessor is not None:
54
+ token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
55
+ return token_ids
56
+
57
+ def get_config(self):
58
+ config = super().get_config()
59
+ config.update(
60
+ {
61
+ "clip_l_preprocessor": layers.serialize(
62
+ self.clip_l_preprocessor
63
+ ),
64
+ "clip_g_preprocessor": layers.serialize(
65
+ self.clip_g_preprocessor
66
+ ),
67
+ "t5_preprocessor": layers.serialize(self.t5_preprocessor),
68
+ }
69
+ )
70
+ return config
71
+
72
+ @classmethod
73
+ def from_config(cls, config):
74
+ for layer_name in (
75
+ "clip_l_preprocessor",
76
+ "clip_g_preprocessor",
77
+ "t5_preprocessor",
78
+ ):
79
+ if layer_name in config and isinstance(config[layer_name], dict):
80
+ config[layer_name] = keras.layers.deserialize(
81
+ config[layer_name]
82
+ )
83
+ return cls(**config)
@@ -1,16 +1,3 @@
1
- # Copyright 2024 The KerasHub Authors
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
1
  import keras
15
2
 
16
3
  from keras_hub.src.layers.modeling.reversible_embedding import (
@@ -20,7 +7,7 @@ from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
20
7
  from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
21
8
 
22
9
 
23
- class T5XXLTextEncoder(keras.Model):
10
+ class T5Encoder(keras.Model):
24
11
  def __init__(
25
12
  self,
26
13
  vocabulary_size,
@@ -81,10 +68,10 @@ class T5XXLTextEncoder(keras.Model):
81
68
 
82
69
  # === Functional Model ===
83
70
  encoder_token_id_input = keras.Input(
84
- shape=(None,), dtype="int32", name="encoder_token_ids"
71
+ shape=(None,), dtype="int32", name="token_ids"
85
72
  )
86
73
  encoder_padding_mask_input = keras.Input(
87
- shape=(None,), dtype="int32", name="encoder_padding_mask"
74
+ shape=(None,), dtype="int32", name="padding_mask"
88
75
  )
89
76
  # Encoder.
90
77
  x = self.token_embedding(encoder_token_id_input)
@@ -102,14 +89,14 @@ class T5XXLTextEncoder(keras.Model):
102
89
  x, position_bias = output
103
90
  x = self.encoder_layer_norm(x)
104
91
  x = self.encoder_dropout(x)
105
- encoder_output = x
92
+ sequence_output = x
106
93
 
107
94
  super().__init__(
108
95
  {
109
- "encoder_token_ids": encoder_token_id_input,
110
- "encoder_padding_mask": encoder_padding_mask_input,
96
+ "token_ids": encoder_token_id_input,
97
+ "padding_mask": encoder_padding_mask_input,
111
98
  },
112
- outputs=encoder_output,
99
+ outputs=sequence_output,
113
100
  **kwargs,
114
101
  )
115
102
 
@@ -0,0 +1,320 @@
1
+ import math
2
+
3
+ from keras import layers
4
+ from keras import ops
5
+
6
+ from keras_hub.src.models.backbone import Backbone
7
+ from keras_hub.src.utils.keras_utils import standardize_data_format
8
+
9
+
10
class VAEAttention(layers.Layer):
    """Single-head self-attention over spatial positions for the VAE.

    Group-normalizes the input, projects it to query/key/value with 1x1
    convolutions, computes scaled dot-product attention across all spatial
    locations (flattened into one sequence), projects the result with a
    final 1x1 convolution, and adds it back to the input as a residual.

    Args:
        filters: int. Number of channels for the query/key/value and output
            projections. Must equal the channel count of `inputs`, since the
            attention output is added to `inputs`.
        groups: int. Number of groups for the group normalization.
        data_format: `None` or str. Either `"channels_last"` or
            `"channels_first"`; defaults to the Keras image data format.
    """

    def __init__(self, filters, groups=32, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.data_format = standardize_data_format(data_format)
        gn_axis = -1 if self.data_format == "channels_last" else 1

        # Normalization and softmax run in float32 for numerical stability
        # under mixed precision; the conv projections follow the layer's
        # dtype policy.
        self.group_norm = layers.GroupNormalization(
            groups=groups,
            axis=gn_axis,
            epsilon=1e-6,
            dtype="float32",
            name="group_norm",
        )
        self.query_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="query_conv2d",
        )
        self.key_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="key_conv2d",
        )
        self.value_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="value_conv2d",
        )
        self.softmax = layers.Softmax(dtype="float32")
        self.output_conv2d = layers.Conv2D(
            filters,
            1,
            1,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="output_conv2d",
        )

        self.groups = groups
        # Precomputed 1/sqrt(C) attention scale.
        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))

    def build(self, input_shape):
        # All sublayers are 1x1 convs (or norm) that preserve the spatial
        # shape, so they can all be built from the layer's input shape.
        self.group_norm.build(input_shape)
        self.query_conv2d.build(input_shape)
        self.key_conv2d.build(input_shape)
        self.value_conv2d.build(input_shape)
        self.output_conv2d.build(input_shape)

    def call(self, inputs, training=None):
        x = self.group_norm(inputs)
        query = self.query_conv2d(x)
        key = self.key_conv2d(x)
        value = self.value_conv2d(x)

        if self.data_format == "channels_first":
            # Work in channels-last so spatial positions are contiguous
            # when flattened below.
            query = ops.transpose(query, (0, 2, 3, 1))
            key = ops.transpose(key, (0, 2, 3, 1))
            value = ops.transpose(value, (0, 2, 3, 1))
        shape = ops.shape(inputs)
        b = shape[0]
        query = ops.reshape(query, (b, -1, self.filters))
        key = ops.reshape(key, (b, -1, self.filters))
        value = ops.reshape(value, (b, -1, self.filters))

        # Compute attention (scaled dot-product; softmax in float32).
        query = ops.multiply(
            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
        )
        # [B, S, C], [B, S, C] -> [B, S, S] where S = H * W.
        attention_scores = ops.einsum("abc,adc->abd", query, key)
        attention_scores = ops.cast(
            self.softmax(attention_scores), self.compute_dtype
        )
        # [B, S, C], [B, S, S] -> [B, S, C].
        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)

        if self.data_format == "channels_first":
            # FIX: `attention_output` is a channels-last flat sequence
            # (B, H*W, C). The previous code reshaped it directly into the
            # NCHW input shape (scrambling the spatial/channel layout) and
            # then transposed with (0, 3, 1, 2) *after* the output conv,
            # which yields (B, W, C, H) and breaks the residual add for
            # non-square inputs. Restore NCHW properly instead:
            # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W).
            x = ops.reshape(
                attention_output, (b, shape[2], shape[3], self.filters)
            )
            x = ops.transpose(x, (0, 3, 1, 2))
        else:
            x = ops.reshape(attention_output, shape)

        x = self.output_conv2d(x)
        # Residual connection; requires input channels == `filters`.
        x = ops.add(x, inputs)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "groups": self.groups,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        # The residual structure preserves the input shape exactly.
        return input_shape
115
+
116
+
117
def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
    """Apply a pre-activation ResNet block (norm -> swish -> conv, twice).

    The input is passed through two GroupNorm/swish/3x3-conv stages and
    added back through a shortcut. When the channel count changes, the
    shortcut is projected with a 1x1 convolution.

    Args:
        x: the input tensor.
        filters: int. Number of output channels for the block.
        data_format: `None` or str. `"channels_last"` or `"channels_first"`.
        dtype: dtype for the conv/activation layers (the norms always run
            in float32 for numerical stability).
        name: str. Prefix used for all layer names in the block.

    Returns:
        The output tensor of the residual addition.
    """
    data_format = standardize_data_format(data_format)
    channel_axis = -1 if data_format == "channels_last" else 1
    input_filters = x.shape[channel_axis]
    shortcut = x

    def norm_act_conv(h, suffix):
        # One pre-activation stage: float32 GroupNorm, swish, 3x3 conv.
        h = layers.GroupNormalization(
            groups=32,
            axis=channel_axis,
            epsilon=1e-6,
            dtype="float32",
            name=f"{name}_norm{suffix}",
        )(h)
        h = layers.Activation("swish", dtype=dtype)(h)
        return layers.Conv2D(
            filters,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_conv{suffix}",
        )(h)

    x = norm_act_conv(x, 1)
    x = norm_act_conv(x, 2)

    if input_filters != filters:
        # Project the shortcut so its channel count matches the main path.
        shortcut = layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_residual_projection",
        )(shortcut)
    return layers.Add(dtype=dtype)([shortcut, x])
168
+
169
+
170
class VAEImageDecoder(Backbone):
    """Decoder for the VAE model used in Stable Diffusion 3.

    Maps a latent image through an input projection, a resnet/attention
    bottleneck, a sequence of upsampling resnet stacks, and an output
    projection back to image space.

    Args:
        stackwise_num_filters: list of ints. The number of filters for each
            stack.
        stackwise_num_blocks: list of ints. The number of blocks for each stack.
        output_channels: int. The number of channels in the output.
        latent_shape: tuple. The shape of the latent image.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=3,
        latent_shape=(None, None, 16),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
        gn_axis = -1 if data_format == "channels_last" else 1

        # === Functional Model ===
        latent_inputs = layers.Input(shape=latent_shape)

        # Input projection + resnet/attention/resnet bottleneck.
        x = layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="input_projection",
        )(latent_inputs)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block0",
        )
        x = VAEAttention(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_attention",
        )(x)
        x = apply_resnet_block(
            x,
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=dtype,
            name="input_block1",
        )

        # Stacks.
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                x = apply_resnet_block(
                    x,
                    filters,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"block{i}_{j}",
                )
            if i != len(stackwise_num_filters) - 1:
                # No upsampling in the last block.
                x = layers.UpSampling2D(
                    2,
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}",
                )(x)
                x = layers.Conv2D(
                    filters,
                    3,
                    1,
                    padding="same",
                    data_format=data_format,
                    dtype=dtype,
                    name=f"upsample_{i}_conv",
                )(x)

        # Output block.
        x = layers.GroupNormalization(
            groups=32,
            axis=gn_axis,
            epsilon=1e-6,
            dtype="float32",
            name="output_norm",
        )(x)
        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
        image_outputs = layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=dtype,
            name="output_projection",
        )(x)
        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)

        # === Config ===
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = output_channels
        self.latent_shape = latent_shape

    @property
    def scaling_factor(self):
        """The scaling factor for the latent space.

        This is used to scale the latent space to have unit variance when
        training the diffusion model.
        """
        return 1.5305

    @property
    def shift_factor(self):
        """The shift factor for the latent space.

        This is used to shift the latent space to have zero mean when
        training the diffusion model.
        """
        return 0.0609

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
                # FIX: this was previously serialized as "image_shape",
                # which `__init__` does not accept, so `from_config`
                # (config round-trip) raised a TypeError. The key must
                # match the constructor argument name.
                "latent_shape": self.latent_shape,
            }
        )
        return config
@@ -1,17 +1,3 @@
1
- # Copyright 2024 The KerasHub Authors
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
1
  from keras_hub.src.models.t5.t5_backbone import T5Backbone
16
2
  from keras_hub.src.models.t5.t5_presets import backbone_presets
17
3
  from keras_hub.src.utils.preset_utils import register_presets
@@ -1,17 +1,3 @@
1
- # Copyright 2024 The KerasHub Authors
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
1
  import keras
16
2
 
17
3
  from keras_hub.src.api_export import keras_hub_export
@@ -1,17 +1,3 @@
1
- # Copyright 2024 The KerasHub Authors
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
1
  import keras
16
2
  from keras import ops
17
3
 
@@ -1,17 +1,3 @@
1
- # Copyright 2024 The KerasHub Authors
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # https://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
1
  import keras
16
2
  import numpy as np
17
3
  from keras import ops