keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
@@ -0,0 +1,178 @@
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_to_image import ImageToImage
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
+     StableDiffusion3Backbone,
+ )
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
+     StableDiffusion3TextToImagePreprocessor,
+ )
+
+
+ @keras_hub_export("keras_hub.models.StableDiffusion3ImageToImage")
+ class StableDiffusion3ImageToImage(ImageToImage):
+     """An end-to-end Stable Diffusion 3 model for image-to-image generation.
+
+     This model has a `generate()` method, which generates images based
+     on a combination of a reference image and a text prompt.
+
+     Args:
+         backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+         preprocessor: A
+             `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+     Examples:
+
+     Use `generate()` to do image generation.
+     ```python
+     prompt = (
+         "Astronaut in a jungle, cold color palette, muted colors, "
+         "detailed, 8k"
+     )
+     image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
+         "stable_diffusion_3_medium", image_shape=(512, 512, 3)
+     )
+     image_to_image.generate(
+         {
+             "images": np.ones((512, 512, 3), dtype="float32"),
+             "prompts": prompt,
+         }
+     )
+
+     # Generate with batched prompts.
+     image_to_image.generate(
+         {
+             "images": np.ones((2, 512, 512, 3), dtype="float32"),
+             "prompts": [
+                 "cute wallpaper art of a cat",
+                 "cute wallpaper art of a dog",
+             ],
+         }
+     )
+
+     # Generate with different `num_steps`, `guidance_scale` and `strength`.
+     image_to_image.generate(
+         {
+             "images": np.ones((512, 512, 3), dtype="float32"),
+             "prompts": prompt,
+         },
+         num_steps=50,
+         guidance_scale=5.0,
+         strength=0.6,
+     )
+
+     # Generate with `negative_prompts`.
+     image_to_image.generate(
+         {
+             "images": np.ones((512, 512, 3), dtype="float32"),
+             "prompts": prompt,
+             "negative_prompts": "green color",
+         }
+     )
+     ```
+     """
+
+     backbone_cls = StableDiffusion3Backbone
+     preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+     def __init__(
+         self,
+         backbone,
+         preprocessor,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         # === Functional Model ===
+         inputs = backbone.input
+         outputs = backbone.output
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+     def fit(self, *args, **kwargs):
+         raise NotImplementedError(
+             "Currently, `fit` is not supported for "
+             "`StableDiffusion3ImageToImage`."
+         )
+
+     def generate_step(
+         self,
+         images,
+         noises,
+         token_ids,
+         starting_step,
+         num_steps,
+         guidance_scale,
+     ):
+         """A compilable generation function for a batch of inputs.
+
+         This function represents the inner, XLA-compilable, generation function
+         for batched inputs.
+
+         Args:
+             images: A (batch_size, image_height, image_width, 3) tensor
+                 containing the reference images.
+             noises: A (batch_size, latent_height, latent_width, channels) tensor
+                 containing the noises to be added to the latents. Typically,
+                 this tensor is sampled from the Gaussian distribution.
+             token_ids: A pair of (batch_size, num_tokens) tensors containing
+                 the tokens based on the input prompts and negative prompts.
+             starting_step: int. The index of the starting diffusion step.
+             num_steps: int. The number of diffusion steps to take.
+             guidance_scale: float. The classifier free guidance scale defined
+                 in [Classifier-Free Diffusion Guidance](
+                 https://arxiv.org/abs/2207.12598). Higher scale encourages
+                 generating images that are closely linked to prompts, usually
+                 at the expense of lower image quality.
+         """
+         token_ids, negative_token_ids = token_ids
+
+         # Encode images.
+         latents = self.backbone.encode_image_step(images)
+
+         # Add noises to latents.
+         latents = self.backbone.add_noise_step(
+             latents, noises, starting_step, num_steps
+         )
+
+         # Encode prompts.
+         embeddings = self.backbone.encode_text_step(
+             token_ids, negative_token_ids
+         )
+
+         # Denoise.
+         def body_fun(step, latents):
+             return self.backbone.denoise_step(
+                 latents,
+                 embeddings,
+                 step,
+                 num_steps,
+                 guidance_scale,
+             )
+
+         latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+         # Decode.
+         return self.backbone.decode_step(latents)
+
+     def generate(
+         self,
+         inputs,
+         num_steps=50,
+         strength=0.8,
+         guidance_scale=7.0,
+         seed=None,
+     ):
+         return super().generate(
+             inputs,
+             num_steps=num_steps,
+             strength=strength,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
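Note that the `strength` argument above never reaches `generate_step` directly: the base `ImageToImage` task converts it into the `starting_step` that `generate_step` receives. A minimal sketch of that relationship, assuming the standard diffusion convention (the arithmetic below is illustrative, not copied from the keras-hub helper):

```python
# Illustrative only: how `strength` in (0, 1] typically selects the
# starting diffusion step. Higher strength skips fewer steps, so more
# noise is applied and the output strays further from the reference.
num_steps = 50
strength = 0.8  # the class default above
starting_step = int(num_steps * (1.0 - strength))
print(starting_step)  # 10 -> denoising runs for steps 10..49
```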
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
@@ -0,0 +1,193 @@
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.inpaint import Inpaint
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
+     StableDiffusion3Backbone,
+ )
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
+     StableDiffusion3TextToImagePreprocessor,
+ )
+
+
+ @keras_hub_export("keras_hub.models.StableDiffusion3Inpaint")
+ class StableDiffusion3Inpaint(Inpaint):
+     """An end-to-end Stable Diffusion 3 model for inpaint generation.
+
+     This model has a `generate()` method, which generates images based
+     on a combination of a reference image, a mask, and a text prompt.
+
+     Args:
+         backbone: A `keras_hub.models.StableDiffusion3Backbone` instance.
+         preprocessor: A
+             `keras_hub.models.StableDiffusion3TextToImagePreprocessor` instance.
+
+     Examples:
+
+     Use `generate()` to do image generation.
+     ```python
+     reference_image = np.ones((1024, 1024, 3), dtype="float32")
+     reference_mask = np.ones((1024, 1024), dtype="float32")
+     inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
+         "stable_diffusion_3_medium", image_shape=(512, 512, 3)
+     )
+     inpaint.generate(
+         reference_image,
+         reference_mask,
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+     )
+
+     # Generate with batched prompts.
+     reference_images = np.ones((2, 512, 512, 3), dtype="float32")
+     reference_mask = np.ones((2, 512, 512), dtype="float32")
+     inpaint.generate(
+         reference_images,
+         reference_mask,
+         ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+     )
+
+     # Generate with different `num_steps`, `guidance_scale` and `strength`.
+     inpaint.generate(
+         reference_image,
+         reference_mask,
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+         num_steps=50,
+         guidance_scale=5.0,
+         strength=0.6,
+     )
+     ```
+     """
+
+     backbone_cls = StableDiffusion3Backbone
+     preprocessor_cls = StableDiffusion3TextToImagePreprocessor
+
+     def __init__(
+         self,
+         backbone,
+         preprocessor,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         # === Functional Model ===
+         inputs = backbone.input
+         outputs = backbone.output
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+     def fit(self, *args, **kwargs):
+         raise NotImplementedError(
+             "Currently, `fit` is not supported for `StableDiffusion3Inpaint`."
+         )
+
+     def generate_step(
+         self,
+         images,
+         masks,
+         noises,
+         token_ids,
+         starting_step,
+         num_steps,
+         guidance_scale,
+     ):
+         """A compilable generation function for a batch of inputs.
+
+         This function represents the inner, XLA-compilable, generation function
+         for batched inputs.
+
+         Args:
+             images: A (batch_size, image_height, image_width, 3) tensor
+                 containing the reference images.
+             masks: A (batch_size, image_height, image_width) tensor
+                 containing the reference masks.
+             noises: A (batch_size, latent_height, latent_width, channels) tensor
+                 containing the noises to be added to the latents. Typically,
+                 this tensor is sampled from the Gaussian distribution.
+             token_ids: A pair of (batch_size, num_tokens) tensors containing
+                 the tokens based on the input prompts and negative prompts.
+             starting_step: int. The index of the starting diffusion step.
+             num_steps: int. The number of diffusion steps to take.
+             guidance_scale: float. The classifier free guidance scale defined
+                 in [Classifier-Free Diffusion Guidance](
+                 https://arxiv.org/abs/2207.12598). Higher scale encourages
+                 generating images that are closely linked to prompts, usually
+                 at the expense of lower image quality.
+         """
+         token_ids, negative_token_ids = token_ids
+
+         # Binarize the masks and resize them to the latent resolution.
+         masks = ops.cast(ops.expand_dims(masks, axis=-1) > 0.5, images.dtype)
+         masks_latent_size = ops.image.resize(
+             masks,
+             (self.backbone.latent_shape[1], self.backbone.latent_shape[2]),
+             interpolation="nearest",
+         )
+
+         # Encode images.
+         image_latents = self.backbone.encode_image_step(images)
+
+         # Add noises to latents.
+         latents = self.backbone.add_noise_step(
+             image_latents, noises, starting_step, num_steps
+         )
+
+         # Encode prompts.
+         embeddings = self.backbone.encode_text_step(
+             token_ids, negative_token_ids
+         )
+
+         # Denoise.
+         def body_fun(step, latents):
+             latents = self.backbone.denoise_step(
+                 latents,
+                 embeddings,
+                 step,
+                 num_steps,
+                 guidance_scale,
+             )
+
+             # Compute the previous latents x_t -> x_t-1.
+             def true_fn():
+                 next_step = ops.add(step, 1)
+                 return self.backbone.add_noise_step(
+                     image_latents, noises, next_step, num_steps
+                 )
+
+             init_latents = ops.cond(
+                 step < ops.subtract(num_steps, 1),
+                 true_fn,
+                 lambda: ops.cast(image_latents, noises.dtype),
+             )
+             latents = ops.add(
+                 ops.multiply(
+                     ops.subtract(1.0, masks_latent_size), init_latents
+                 ),
+                 ops.multiply(masks_latent_size, latents),
+             )
+             return latents
+
+         latents = ops.fori_loop(starting_step, num_steps, body_fun, latents)
+
+         # Decode.
+         return self.backbone.decode_step(latents)
+
+     def generate(
+         self,
+         inputs,
+         num_steps=50,
+         strength=0.6,
+         guidance_scale=7.0,
+         seed=None,
+     ):
+         return super().generate(
+             inputs,
+             num_steps=num_steps,
+             strength=strength,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
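The inpaint-specific piece above is the per-step compositing inside `body_fun`: outside the mask the latents are pinned to a freshly re-noised copy of the reference image, while inside the mask they follow the denoiser. The same blend in plain NumPy, with toy shapes (purely illustrative):

```python
import numpy as np

# Toy shapes: batch of 1, a 4x4 latent grid, 1 channel.
mask = np.zeros((1, 4, 4, 1), dtype="float32")
mask[:, 1:3, 1:3, :] = 1.0  # inpaint only the 2x2 center region

init_latents = np.full((1, 4, 4, 1), 5.0, dtype="float32")  # re-noised reference
denoised = np.zeros((1, 4, 4, 1), dtype="float32")  # denoiser output

# Same blend as in `generate_step`: keep the reference outside the mask,
# keep the denoised prediction inside it.
latents = (1.0 - mask) * init_latents + mask * denoised
assert latents[0, 0, 0, 0] == 5.0  # border pixel stays pinned
assert latents[0, 1, 1, 0] == 0.0  # center pixel is inpainted
```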
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
@@ -5,14 +5,50 @@ backbone_presets = {
          "metadata": {
              "description": (
                  "3 billion parameter, including CLIP L and CLIP G text "
-                 "encoders, MMDiT generative model, and VAE decoder. "
+                 "encoders, MMDiT generative model, and VAE autoencoder. "
                  "Developed by Stability AI."
              ),
-             "params": 2952806723,
-             "official_name": "StableDiffusion3",
-             "path": "stablediffusion3",
-             "model_card": "https://arxiv.org/abs/2110.00476",
+             "params": 2987080931,
+             "path": "stable_diffusion_3",
          },
-         "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/1",
-     }
+         "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/4",
+     },
+     "stable_diffusion_3.5_medium": {
+         "metadata": {
+             "description": (
+                 "3 billion parameter, including CLIP L and CLIP G text "
+                 "encoders, MMDiT-X generative model, and VAE autoencoder. "
+                 "Developed by Stability AI."
+             ),
+             "params": 3371793763,
+             "path": "stable_diffusion_3",
+         },
+         "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3.5_medium/1",
+     },
+     "stable_diffusion_3.5_large": {
+         "metadata": {
+             "description": (
+                 "9 billion parameter, including CLIP L and CLIP G text "
+                 "encoders, MMDiT generative model, and VAE autoencoder. "
+                 "Developed by Stability AI."
+             ),
+             "params": 9048410595,
+             "path": "stable_diffusion_3",
+         },
+         "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/2",
+     },
+     "stable_diffusion_3.5_large_turbo": {
+         "metadata": {
+             "description": (
+                 "9 billion parameter, including CLIP L and CLIP G text "
+                 "encoders, MMDiT generative model, and VAE autoencoder. "
+                 "A timestep-distilled version that eliminates classifier-free "
+                 "guidance and uses fewer steps for generation. "
+                 "Developed by Stability AI."
+             ),
+             "params": 9048410595,
+             "path": "stable_diffusion_3",
+         },
+         "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/2",
+     },
  }
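The preset entries above plug into the existing task classes unchanged; for instance, loading one of the new 3.5 checkpoints is a one-liner (a sketch — the weights download through the Kaggle handles listed above):

```python
import keras_hub

# Sketch: load a newly added Stable Diffusion 3.5 preset.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3.5_medium", image_shape=(512, 512, 3)
)
image = text_to_image.generate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
)
```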
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
@@ -1,10 +1,10 @@
  from keras import ops

  from keras_hub.src.api_export import keras_hub_export
- from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
      StableDiffusion3Backbone,
  )
- from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (  # noqa: E501
      StableDiffusion3TextToImagePreprocessor,
  )
  from keras_hub.src.models.text_to_image import TextToImage
@@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage):
      Use `generate()` to do image generation.
      ```python
      text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
-         "stable_diffusion_3_medium", height=512, width=512
+         "stable_diffusion_3_medium", image_shape=(512, 512, 3)
      )
      text_to_image.generate(
          "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
@@ -38,11 +38,23 @@ class StableDiffusion3TextToImage(TextToImage):
          ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
      )

-     # Generate with different `num_steps` and `classifier_free_guidance_scale`.
+     # Generate with different `num_steps` and `guidance_scale`.
      text_to_image.generate(
          "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
          num_steps=50,
-         classifier_free_guidance_scale=5.0,
+         guidance_scale=5.0,
+     )
+
+     # Generate with `negative_prompts`.
+     prompt = (
+         "Astronaut in a jungle, cold color palette, muted colors, "
+         "detailed, 8k"
+     )
+     text_to_image.generate(
+         {
+             "prompts": prompt,
+             "negative_prompts": "green color",
+         }
      )
      ```
      """
@@ -79,7 +91,6 @@ class StableDiffusion3TextToImage(TextToImage):
          self,
          latents,
          token_ids,
-         negative_token_ids,
          num_steps,
          guidance_scale,
      ):
@@ -92,10 +103,8 @@ class StableDiffusion3TextToImage(TextToImage):
              latents: A (batch_size, height, width, channels) tensor
                  containing the latents to start generation from. Typically, this
                  tensor is sampled from the Gaussian distribution.
-             token_ids: A (batch_size, num_tokens) tensor containing the
-                 tokens based on the input prompts.
-             negative_token_ids: A (batch_size, num_tokens) tensor
-                 containing the negative tokens based on the input prompts.
+             token_ids: A pair of (batch_size, num_tokens) tensors containing
+                 the tokens based on the input prompts and negative prompts.
              num_steps: int. The number of diffusion steps to take.
              guidance_scale: float. The classifier free guidance scale defined in
                  [Classifier-Free Diffusion Guidance](
@@ -103,8 +112,12 @@ class StableDiffusion3TextToImage(TextToImage):
                  generate images that are closely linked to prompts, usually at
                  the expense of lower image quality.
          """
-         # Encode inputs.
-         embeddings = self.backbone.encode_step(token_ids, negative_token_ids)
+         token_ids, negative_token_ids = token_ids
+
+         # Encode prompts.
+         embeddings = self.backbone.encode_text_step(
+             token_ids, negative_token_ids
+         )

          # Denoise.
          def body_fun(step, latents):
@@ -124,14 +137,12 @@ class StableDiffusion3TextToImage(TextToImage):
      def generate(
          self,
          inputs,
-         negative_inputs=None,
          num_steps=28,
          guidance_scale=7.0,
          seed=None,
      ):
          return super().generate(
              inputs,
-             negative_inputs=negative_inputs,
              num_steps=num_steps,
              guidance_scale=guidance_scale,
              seed=seed,
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
@@ -3,7 +3,7 @@ from keras import layers

  from keras_hub.src.api_export import keras_hub_export
  from keras_hub.src.models.preprocessor import Preprocessor
- from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+ from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (  # noqa: E501
      StableDiffusion3Backbone,
  )

keras_hub/src/models/t5/t5_backbone.py
@@ -42,11 +42,12 @@ class T5Backbone(Backbone):
              projections in the multi-head attention layers. Defaults to
              hidden_dim / num_heads.
          dropout: float. Dropout probability for the Transformer layers.
-         activation: activation function (or activation string name). The
-             activation to be used in the inner dense blocks of the
-             Transformer layers. Defaults to `"relu"`.
+         activation: string. The activation function to use in the dense blocks
+             of the Transformer layers.
          use_gated_activation: boolean. Whether to use activation gating in
-             the inner dense blocks of the Transformer layers.
+             the inner dense blocks of the Transformer layers. When used with
+             the GELU activation function, this is referred to as GEGLU
+             (gated GLU) from https://arxiv.org/pdf/2002.05202.
              The original T5 architecture didn't use gating, but more
              recent versions do. Defaults to `True`.
          layer_norm_epsilon: float. Epsilon factor to be used in the
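For readers who don't follow the GEGLU reference, here is a minimal generic-Keras sketch of a gated dense block (illustrative names, not keras-hub's internal implementation):

```python
import numpy as np
from keras import layers, ops


def geglu_block(x, hidden_dim):
    # GEGLU (https://arxiv.org/pdf/2002.05202): a GELU-activated branch
    # gates a parallel linear branch, elementwise.
    gate = layers.Dense(hidden_dim, use_bias=False)(x)
    value = layers.Dense(hidden_dim, use_bias=False)(x)
    return ops.gelu(gate) * value


x = np.random.rand(2, 16).astype("float32")
y = geglu_block(x, hidden_dim=64)  # shape: (2, 64)
```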
keras_hub/src/models/t5/t5_presets.py
@@ -1,4 +1,4 @@
- """XLM-RoBERTa model preset configurations."""
+ """T5 model preset configurations."""

  backbone_presets = {
      "t5_small_multi": {
@@ -8,11 +8,17 @@ backbone_presets = {
                  "Corpus (C4)."
              ),
              "params": 0,
-             "official_name": "T5",
              "path": "t5",
-             "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
          },
-         "kaggle_handle": "kaggle://keras/t5/keras/t5_small_multi/2",
+         "kaggle_handle": "kaggle://keras/t5/keras/t5_small_multi/3",
+     },
+     "t5_1.1_small": {
+         "metadata": {
+             "description": (""),
+             "params": 60511616,
+             "path": "t5",
+         },
+         "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_small/2",
      },
      "t5_base_multi": {
          "metadata": {
@@ -21,11 +27,17 @@ backbone_presets = {
                  "Corpus (C4)."
              ),
              "params": 0,
-             "official_name": "T5",
              "path": "t5",
-             "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
          },
-         "kaggle_handle": "kaggle://keras/t5/keras/t5_base_multi/2",
+         "kaggle_handle": "kaggle://keras/t5/keras/t5_base_multi/3",
+     },
+     "t5_1.1_base": {
+         "metadata": {
+             "description": (""),
+             "params": 247577856,
+             "path": "t5",
+         },
+         "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_base/2",
      },
      "t5_large_multi": {
          "metadata": {
@@ -34,11 +46,33 @@ backbone_presets = {
                  "Corpus (C4)."
              ),
              "params": 0,
-             "official_name": "T5",
              "path": "t5",
-             "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
          },
-         "kaggle_handle": "kaggle://keras/t5/keras/t5_large_multi/2",
+         "kaggle_handle": "kaggle://keras/t5/keras/t5_large_multi/3",
+     },
+     "t5_1.1_large": {
+         "metadata": {
+             "description": (""),
+             "params": 750251008,
+             "path": "t5",
+         },
+         "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_large/2",
+     },
+     "t5_1.1_xl": {
+         "metadata": {
+             "description": (""),
+             "params": 2849757184,
+             "path": "t5",
+         },
+         "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_xl/2",
+     },
+     "t5_1.1_xxl": {
+         "metadata": {
+             "description": (""),
+             "params": 11135332352,
+             "path": "t5",
+         },
+         "kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_xxl/2",
      },
      "flan_small_multi": {
          "metadata": {
@@ -47,11 +81,9 @@ backbone_presets = {
                  "Corpus (C4)."
              ),
              "params": 0,
-             "official_name": "T5",
              "path": "t5",
-             "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
          },
-         "kaggle_handle": "kaggle://keras/t5/keras/flan_small_multi/2",
+         "kaggle_handle": "kaggle://keras/t5/keras/flan_small_multi/3",
      },
      "flan_base_multi": {
          "metadata": {
@@ -60,11 +92,9 @@ backbone_presets = {
                  "Corpus (C4)."
              ),
              "params": 0,
-             "official_name": "T5",
              "path": "t5",
-             "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
          },
-         "kaggle_handle": "kaggle://keras/t5/keras/flan_base_multi/2",
+         "kaggle_handle": "kaggle://keras/t5/keras/flan_base_multi/3",
      },
      "flan_large_multi": {
          "metadata": {
@@ -73,10 +103,8 @@ backbone_presets = {
                  "Corpus (C4)."
              ),
              "params": 0,
-             "official_name": "T5",
              "path": "t5",
-             "model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
          },
-         "kaggle_handle": "kaggle://keras/t5/keras/flan_large_multi/2",
+         "kaggle_handle": "kaggle://keras/t5/keras/flan_large_multi/3",
      },
  }
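The five new `t5_1.1_*` entries are ordinary backbone presets, so they load through the usual API (a sketch; the `params` fields above give the expected sizes):

```python
import keras_hub

# Sketch: load the smallest of the new T5 1.1 checkpoints.
backbone = keras_hub.models.T5Backbone.from_preset("t5_1.1_small")
backbone.summary()  # ~60.5M parameters, matching the preset metadata
```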