keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/flux/flux_maths.py
@@ -0,0 +1,225 @@
+import keras
+from keras import ops
+
+
+class TimestepEmbedding(keras.layers.Layer):
+    """Creates sinusoidal timestep embeddings.
+
+    Call arguments:
+        t: Tensor of shape (N,), representing N indices, one per batch
+            element. These values may be fractional.
+        dim: int. The dimension of the output.
+        max_period: int, optional. Controls the minimum frequency of the
+            embeddings. Defaults to 10000.
+        time_factor: float, optional. A scaling factor applied to `t`.
+            Defaults to 1000.0.
+
+    Returns:
+        A tensor of shape (N, D) representing the positional embeddings,
+        where N is the number of batch elements and D is the specified
+        dimension `dim`.
+    """
+
+    def call(self, t, dim, max_period=10000, time_factor=1000.0):
+        t = time_factor * t
+        half_dim = dim // 2
+        freqs = ops.exp(
+            ops.cast(-ops.log(max_period), dtype=t.dtype)
+            * ops.arange(half_dim, dtype=t.dtype)
+            / half_dim
+        )
+        args = t[:, None] * freqs[None]
+        embedding = ops.concatenate([ops.cos(args), ops.sin(args)], axis=-1)
+
+        if dim % 2 != 0:
+            embedding = ops.concatenate(
+                [embedding, ops.zeros_like(embedding[:, :1])], axis=-1
+            )
+
+        return embedding
+
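The layer computes the standard sinusoidal embedding: each timestep is scaled by `time_factor` and projected onto `dim // 2` geometrically spaced frequencies, with cosines in the first half of the output and sines in the second. A minimal shape-checking sketch, assuming the import path shown in this diff (input values are illustrative):

```python
from keras import ops

from keras_hub.src.models.flux.flux_maths import TimestepEmbedding

embedder = TimestepEmbedding()
t = ops.convert_to_tensor([0.0, 0.5, 1.0])  # shape (N,) = (3,)
emb = embedder(t, dim=8)  # t is scaled by time_factor, then sin/cos pairs
print(ops.shape(emb))  # (3, 8): first half cosines, second half sines
```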
+
+class RotaryPositionalEmbedding(keras.layers.Layer):
+    """Applies Rotary Positional Embedding (RoPE) to the input tensor.
+
+    Call arguments:
+        pos: KerasTensor. The positional tensor with shape (..., n, d).
+        dim: int. The embedding dimension, should be even.
+        theta: int. The base frequency.
+
+    Returns:
+        KerasTensor: The tensor with applied RoPE transformation.
+    """
+
+    def call(self, pos, dim, theta):
+        scale = ops.arange(0, dim, 2, dtype="float32") / dim
+        omega = 1.0 / (theta**scale)
+        out = ops.einsum("...n,d->...nd", pos, omega)
+        out = ops.stack(
+            [ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1
+        )
+        out = ops.reshape(out, ops.shape(out)[:-1] + (2, 2))
+        return ops.cast(out, dtype="float32")
+
+
+class ApplyRoPE(keras.layers.Layer):
+    """Applies the RoPE transformation to the query and key tensors.
+
+    Call arguments:
+        xq: KerasTensor. The query tensor of shape (..., L, D).
+        xk: KerasTensor. The key tensor of shape (..., L, D).
+        freqs_cis: KerasTensor. The frequency complex numbers tensor with
+            shape `(..., 2)`.
+
+    Returns:
+        tuple[KerasTensor, KerasTensor]: The transformed query and key
+        tensors.
+    """
+
+    def call(self, xq, xk, freqs_cis):
+        xq_ = ops.reshape(xq, (*ops.shape(xq)[:-1], -1, 1, 2))
+        xk_ = ops.reshape(xk, (*ops.shape(xk)[:-1], -1, 1, 2))
+
+        xq_out = (
+            freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+        )
+        xk_out = (
+            freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+        )
+
+        return ops.reshape(xq_out, ops.shape(xq)), ops.reshape(
+            xk_out, ops.shape(xk)
+        )
+
+
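Together these two layers implement RoPE as explicit 2x2 rotation matrices: `RotaryPositionalEmbedding` builds one rotation per position and frequency (output shape `(..., n, dim // 2, 2, 2)`), and `ApplyRoPE` applies them to query/key pairs. A small sketch under assumed toy shapes (in the full model, `freqs_cis` comes from the `EmbedND` layer in `flux_layers.py`):

```python
import keras
from keras import ops

from keras_hub.src.models.flux.flux_maths import ApplyRoPE
from keras_hub.src.models.flux.flux_maths import RotaryPositionalEmbedding

# Illustrative shapes: batch of 2 sequences of length 4 with head dim 8.
B, L, D = 2, 4, 8
pos = ops.broadcast_to(ops.arange(L, dtype="float32"), (B, L))
q = keras.random.normal((B, L, D))
k = keras.random.normal((B, L, D))

# One 2x2 rotation matrix per position and frequency:
# shape (B, L, D // 2, 2, 2).
freqs_cis = RotaryPositionalEmbedding()(pos, dim=D, theta=10_000)

# Rotate query/key pairs in place; shapes are unchanged.
q_rot, k_rot = ApplyRoPE()(q, k, freqs_cis)
print(ops.shape(q_rot), ops.shape(k_rot))  # (2, 4, 8) (2, 4, 8)
```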
+class FluxRoPEAttention(keras.layers.Layer):
+    """Computes the attention mechanism with RoPE.
+
+    Args:
+        dropout_p: float, optional. Dropout probability. Defaults to 0.0.
+        is_causal: bool, optional. If True, applies causal masking.
+            Defaults to False.
+
+    Call arguments:
+        q: KerasTensor. Query tensor of shape (..., L, D).
+        k: KerasTensor. Key tensor of shape (..., S, D).
+        v: KerasTensor. Value tensor of shape (..., S, D).
+        positional_encoding: KerasTensor. Positional encoding tensor.
+
+    Returns:
+        KerasTensor: The resulting tensor from the attention mechanism.
+    """
+
+    def __init__(self, dropout_p=0.0, is_causal=False):
+        super().__init__()
+        self.dropout_p = dropout_p
+        self.is_causal = is_causal
+
+    def call(self, q, k, v, positional_encoding):
+        # Apply the RoPE transformation.
+        q, k = ApplyRoPE()(q, k, positional_encoding)
+
+        # Scaled dot-product attention.
+        x = scaled_dot_product_attention(
+            q, k, v, dropout_p=self.dropout_p, is_causal=self.is_causal
+        )
+        x = ops.transpose(x, (0, 2, 1, 3))
+        b, s, h, d = ops.shape(x)
+        return ops.reshape(x, (b, s, h * d))
+
+
+# TODO: This is probably already implemented in several places, but is needed
+# to ensure numeric equivalence to the original implementation. It uses
+# torch.functional.scaled_dot_product_attention() - do we have an equivalent
+# already in Keras?
+def scaled_dot_product_attention(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+):
+    """Computes the scaled dot-product attention.
+
+    Args:
+        query: KerasTensor. Query tensor of shape (..., L, D).
+        key: KerasTensor. Key tensor of shape (..., S, D).
+        value: KerasTensor. Value tensor of shape (..., S, D).
+        attn_mask: KerasTensor, optional. Attention mask tensor. Defaults
+            to None.
+        dropout_p: float, optional. Dropout probability. Defaults to 0.0.
+        is_causal: bool, optional. If True, applies causal masking. Defaults
+            to False.
+        scale: float, optional. Scale factor for attention. Defaults to None.
+
+    Returns:
+        KerasTensor: The output tensor from the attention mechanism.
+    """
+    L, S = ops.shape(query)[-2], ops.shape(key)[-2]
+    scale_factor = (
+        1 / ops.sqrt(ops.cast(ops.shape(query)[-1], dtype=query.dtype))
+        if scale is None
+        else scale
+    )
+    attn_bias = ops.zeros((L, S), dtype=query.dtype)
+
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = ops.ones((L, S), dtype="bool")
+        temp_mask = ops.tril(temp_mask, diagonal=0)
+        attn_bias = ops.where(temp_mask, attn_bias, float("-inf"))
+
+    if attn_mask is not None:
+        if ops.shape(attn_mask)[-1] == 1:  # If the mask is 3D
+            attn_bias += attn_mask
+        else:
+            attn_bias = ops.where(attn_mask, attn_bias, float("-inf"))
+
+    # Compute attention weights.
+    attn_weight = (
+        ops.matmul(query, ops.transpose(key, axes=[0, 1, 3, 2])) * scale_factor
+    )
+    attn_weight += attn_bias
+    attn_weight = keras.activations.softmax(attn_weight, axis=-1)
+
+    if dropout_p > 0.0:
+        attn_weight = keras.layers.Dropout(dropout_p)(
+            attn_weight, training=True
+        )
+
+    return ops.matmul(attn_weight, value)
+
+
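This helper mirrors `torch.nn.functional.scaled_dot_product_attention`. On the TODO above: recent Keras 3 releases expose `keras.ops.dot_product_attention`, which may be the built-in equivalent, though the port would still need a numerics check. A shape sketch of the function as written; the hard-coded transpose axes `[0, 1, 3, 2]` assume 4D `(batch, heads, seq, head_dim)` inputs (toy values):

```python
import keras
from keras import ops

from keras_hub.src.models.flux.flux_maths import scaled_dot_product_attention

B, H, L, D = 2, 4, 6, 8
query = keras.random.normal((B, H, L, D))
key = keras.random.normal((B, H, L, D))
value = keras.random.normal((B, H, L, D))

# Causal masking builds a lower-triangular additive bias internally.
out = scaled_dot_product_attention(query, key, value, is_causal=True)
print(ops.shape(out))  # (2, 4, 6, 8)
```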
+def rearrange_symbolic_tensors(qkv, K, H):
+    """Splits the qkv tensor into query (q), key (k), and value (v)
+    components.
+
+    Mimics rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=num_heads),
+    for graph-mode TensorFlow support when doing functional subclassing
+    models.
+
+    Args:
+        qkv: KerasTensor. Input tensor of shape (B, L, K*H*D).
+        K: int. Number of components (q, k, v).
+        H: int. Number of attention heads.
+
+    Returns:
+        tuple: q, k, v tensors of shape (B, H, L, D).
+    """
+    # Get the shape of qkv and calculate L and D.
+    B, L, dim = ops.shape(qkv)
+    D = dim // (K * H)
+
+    # Reshape and transpose the qkv tensor.
+    qkv_reshaped = ops.reshape(qkv, (B, L, K, H, D))
+    qkv_transposed = ops.transpose(qkv_reshaped, (2, 0, 3, 1, 4))
+
+    # Split q, k, v along the first dimension (K).
+    qkv_splits = ops.split(qkv_transposed, K, axis=0)
+    q, k, v = [ops.squeeze(split, 0) for split in qkv_splits]
+
+    return q, k, v
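A quick sketch of the fused-projection split with toy sizes: a `(B, L, 3 * H * D)` tensor, as produced by a fused qkv `Dense` layer, comes back as three `(B, H, L, D)` tensors:

```python
import keras
from keras import ops

from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors

# Fused qkv projection output: (B, L, K * H * D) with K=3 components.
B, L, K, H, D = 2, 10, 3, 4, 16
qkv = keras.random.normal((B, L, K * H * D))

q, k, v = rearrange_symbolic_tensors(qkv, K=K, H=H)
print(ops.shape(q))  # (2, 4, 10, 16), i.e. (B, H, L, D)
```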
keras_hub/src/models/flux/flux_model.py
@@ -0,0 +1,236 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.flux.flux_layers import DoubleStreamBlock
+from keras_hub.src.models.flux.flux_layers import EmbedND
+from keras_hub.src.models.flux.flux_layers import LastLayer
+from keras_hub.src.models.flux.flux_layers import MLPEmbedder
+from keras_hub.src.models.flux.flux_layers import SingleStreamBlock
+from keras_hub.src.models.flux.flux_maths import TimestepEmbedding
+
+
+@keras_hub_export("keras_hub.models.FluxBackbone")
+class FluxBackbone(Backbone):
+    """Transformer model for flow matching on sequences.
+
+    The model processes image and text data with associated positional and
+    timestep embeddings, and optionally applies guidance embedding.
+    Double-stream blocks handle separate image and text streams, while
+    single-stream blocks combine these streams. Ported from:
+    https://github.com/black-forest-labs/flux
+
+    Args:
+        input_channels: int. The number of input channels.
+        hidden_size: int. The hidden size of the transformer, must be
+            divisible by `num_heads`.
+        mlp_ratio: float. The ratio of the MLP dimension to the hidden size.
+        num_heads: int. The number of attention heads.
+        depth: int. The number of double-stream blocks.
+        depth_single_blocks: int. The number of single-stream blocks.
+        axes_dim: list[int]. A list of dimensions for the positional
+            embedding axes.
+        theta: int. The base frequency for positional embeddings.
+        use_bias: bool. Whether to apply bias to the query, key, and value
+            projections.
+        guidance_embed: bool. If True, applies guidance embedding in the
+            model.
+
+    Call arguments:
+        image: KerasTensor. Image input tensor of shape (N, L, D) where N
+            is the batch size, L is the sequence length, and D is the
+            feature dimension.
+        image_ids: KerasTensor. Image ID input tensor of shape (N, L, D)
+            corresponding to the image sequences.
+        text: KerasTensor. Text input tensor of shape (N, L, D).
+        text_ids: KerasTensor. Text ID input tensor of shape (N, L, D)
+            corresponding to the text sequences.
+        timesteps: KerasTensor. Timestep tensor used to compute positional
+            embeddings.
+        y: KerasTensor. Additional vector input, such as target values.
+        guidance: KerasTensor, optional. Guidance input tensor used in
+            guidance-embedded models.
+
+    Raises:
+        ValueError: If `hidden_size` is not divisible by `num_heads`, or if
+            `sum(axes_dim)` is not equal to the positional embedding
+            dimension.
+    """
+
+    def __init__(
+        self,
+        input_channels,
+        hidden_size,
+        mlp_ratio,
+        num_heads,
+        depth,
+        depth_single_blocks,
+        axes_dim,
+        theta,
+        use_bias,
+        guidance_embed=False,
+        # These will be inferred from the CLIP/T5 encoders later.
+        image_shape=(None, 768, 3072),
+        text_shape=(None, 768, 3072),
+        image_ids_shape=(None, 768, 3072),
+        text_ids_shape=(None, 768, 3072),
+        y_shape=(None, 128),
+        **kwargs,
+    ):
+        # === Layers ===
+        self.positional_embedder = EmbedND(theta=theta, axes_dim=axes_dim)
+        self.image_input_embedder = keras.layers.Dense(
+            hidden_size, use_bias=True
+        )
+        self.time_input_embedder = MLPEmbedder(hidden_dim=hidden_size)
+        self.vector_embedder = MLPEmbedder(hidden_dim=hidden_size)
+        self.guidance_input_embedder = (
+            MLPEmbedder(hidden_dim=hidden_size)
+            if guidance_embed
+            else keras.layers.Identity()
+        )
+        self.text_input_embedder = keras.layers.Dense(hidden_size)
+
+        self.double_blocks = [
+            DoubleStreamBlock(
+                hidden_size,
+                num_heads,
+                mlp_ratio=mlp_ratio,
+                use_bias=use_bias,
+            )
+            for _ in range(depth)
+        ]
+
+        self.single_blocks = [
+            SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+            for _ in range(depth_single_blocks)
+        ]
+
+        self.final_layer = LastLayer(hidden_size, 1, input_channels)
+        self.timestep_embedding = TimestepEmbedding()
+        self.guidance_embed = guidance_embed
+
+        # === Functional Model ===
+        image_input = keras.Input(shape=image_shape, name="image")
+        image_ids = keras.Input(shape=image_ids_shape, name="image_ids")
+        text_input = keras.Input(shape=text_shape, name="text")
+        text_ids = keras.Input(shape=text_ids_shape, name="text_ids")
+        y = keras.Input(shape=y_shape, name="y")
+        timesteps_input = keras.Input(shape=(), name="timesteps")
+        guidance_input = keras.Input(shape=(), name="guidance")
+
+        # Running on image sequences.
+        image = self.image_input_embedder(image_input)
+        modulation_encoding = self.time_input_embedder(
+            self.timestep_embedding(timesteps_input, dim=256)
+        )
+        if self.guidance_embed:
+            if guidance_input is None:
+                raise ValueError(
+                    "Didn't get guidance strength for guidance distilled model."
+                )
+            modulation_encoding = (
+                modulation_encoding
+                + self.guidance_input_embedder(
+                    self.timestep_embedding(guidance_input, dim=256)
+                )
+            )
+
+        modulation_encoding = modulation_encoding + self.vector_embedder(y)
+        text = self.text_input_embedder(text_input)
+
+        ids = keras.ops.concatenate((text_ids, image_ids), axis=1)
+        positional_encoding = self.positional_embedder(ids)
+
+        for block in self.double_blocks:
+            image, text = block(
+                image=image,
+                text=text,
+                modulation_encoding=modulation_encoding,
+                positional_encoding=positional_encoding,
+            )
+
+        image = keras.ops.concatenate((text, image), axis=1)
+        for block in self.single_blocks:
+            image = block(
+                image,
+                modulation_encoding=modulation_encoding,
+                positional_encoding=positional_encoding,
+            )
+        image = image[:, text.shape[1] :, ...]
+
+        image = self.final_layer(
+            image, modulation_encoding
+        )  # (N, T, patch_size ** 2 * output_channels)
+
+        super().__init__(
+            inputs={
+                "image": image_input,
+                "image_ids": image_ids,
+                "text": text_input,
+                "text_ids": text_ids,
+                "y": y,
+                "timesteps": timesteps_input,
+                "guidance": guidance_input,
+            },
+            outputs=image,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.input_channels = input_channels
+        self.output_channels = self.input_channels
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.image_shape = image_shape
+        self.text_shape = text_shape
+        self.image_ids_shape = image_ids_shape
+        self.text_ids_shape = text_ids_shape
+        self.y_shape = y_shape
+        self.mlp_ratio = mlp_ratio
+        self.depth = depth
+        self.depth_single_blocks = depth_single_blocks
+        self.axes_dim = axes_dim
+        self.theta = theta
+        self.use_bias = use_bias
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "input_channels": self.input_channels,
+                "hidden_size": self.hidden_size,
+                "mlp_ratio": self.mlp_ratio,
+                "num_heads": self.num_heads,
+                "depth": self.depth,
+                "depth_single_blocks": self.depth_single_blocks,
+                "axes_dim": self.axes_dim,
+                "theta": self.theta,
+                "use_bias": self.use_bias,
+                "guidance_embed": self.guidance_embed,
+                "image_shape": self.image_shape,
+                "text_shape": self.text_shape,
+                "image_ids_shape": self.image_ids_shape,
+                "text_ids_shape": self.text_ids_shape,
+                "y_shape": self.y_shape,
+            }
+        )
+        return config
+
+    def encode_text_step(self, token_ids, negative_token_ids):
+        raise NotImplementedError("Not implemented yet")
+
+    def encode(self, token_ids):
+        raise NotImplementedError("Not implemented yet")
+
+    def encode_image_step(self, images):
+        raise NotImplementedError("Not implemented yet")
+
+    def add_noise_step(self, latents, noises, step, num_steps):
+        raise NotImplementedError("Not implemented yet")
+
+    def denoise_step(self):
+        raise NotImplementedError("Not implemented yet")
+
+    def decode_step(self, latents):
+        raise NotImplementedError("Not implemented yet")
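For orientation, a hypothetical construction sketch with deliberately tiny dimensions; no Flux presets ship in this release (see `flux_presets.py` below), and every value here is an illustrative assumption rather than the reference FLUX configuration. Per the docstring, `hidden_size` must be divisible by `num_heads`, and `sum(axes_dim)` must equal the per-head positional embedding dimension:

```python
from keras_hub.src.models.flux.flux_model import FluxBackbone

# Toy values (assumptions): head dim is hidden_size // num_heads == 32,
# and sum(axes_dim) == 32.
backbone = FluxBackbone(
    input_channels=64,
    hidden_size=128,
    mlp_ratio=4.0,
    num_heads=4,
    depth=1,
    depth_single_blocks=1,
    axes_dim=[8, 12, 12],
    theta=10_000,
    use_bias=True,
    guidance_embed=False,
    # Small symbolic shapes (batch dim excluded). The ids are assumed to
    # carry one coordinate channel per entry of `axes_dim`.
    image_shape=(96, 64),
    text_shape=(96, 64),
    image_ids_shape=(96, 3),
    text_ids_shape=(96, 3),
    y_shape=(128,),
)
```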
keras_hub/src/models/flux/flux_presets.py
@@ -0,0 +1,3 @@
+"""FLUX model preset configurations."""
+
+presets = {}
keras_hub/src/models/flux/flux_text_to_image.py
@@ -0,0 +1,146 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.flux.flux_model import FluxBackbone
+from keras_hub.src.models.flux.flux_text_to_image_preprocessor import (
+    FluxTextToImagePreprocessor,
+)
+from keras_hub.src.models.text_to_image import TextToImage
+
+
+@keras_hub_export("keras_hub.models.FluxTextToImage")
+class FluxTextToImage(TextToImage):
+    """An end-to-end Flux model for text-to-image generation.
+
+    This model has a `generate()` method, which generates images based on a
+    prompt.
+
+    Args:
+        backbone: A `keras_hub.models.FluxBackbone` instance.
+        preprocessor: A
+            `keras_hub.models.FluxTextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    prompt = (
+        "Astronaut in a jungle, cold color palette, muted colors, "
+        "detailed, 8k"
+    )
+    text_to_image = keras_hub.models.FluxTextToImage.from_preset(
+        "TBA", height=512, width=512
+    )
+    text_to_image.generate(
+        prompt
+    )
+
+    # Generate with batched prompts.
+    text_to_image.generate(
+        ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+    )
+
+    # Generate with different `num_steps` and `guidance_scale`.
+    text_to_image.generate(
+        prompt,
+        num_steps=50,
+        guidance_scale=5.0,
+    )
+
+    # Generate with `negative_prompts`.
+    text_to_image.generate(
+        {
+            "prompts": prompt,
+            "negative_prompts": "green color",
+        }
+    )
+    ```
+    """
+
+    backbone_cls = FluxBackbone
+    preprocessor_cls = FluxTextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for `FluxTextToImage`."
+        )
+
+    def generate_step(
+        self,
+        latents,
+        token_ids,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for a batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation
+        function for batched inputs.
+
+        Args:
+            latents: A (batch_size, height, width, channels) tensor
+                containing the latents to start generation from. Typically,
+                this tensor is sampled from a Gaussian distribution.
+            token_ids: A pair of (batch_size, num_tokens) tensors containing
+                the tokens based on the input prompts and negative prompts.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier-free guidance scale
+                defined in [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). A higher scale encourages
+                generating images more closely linked to the prompts,
+                usually at the expense of lower image quality.
+        """
+        token_ids, negative_token_ids = token_ids
+
+        # Encode prompts.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            return self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+        latents = ops.fori_loop(0, num_steps, body_fun, latents)
+
+        # Decode.
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        inputs,
+        num_steps=28,
+        guidance_scale=7.0,
+        seed=None,
+    ):
+        return super().generate(
+            inputs,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            seed=seed,
+        )
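`generate_step()` above drives the sampler with `keras.ops.fori_loop`, a counted, XLA-compilable loop that threads the latents through `body_fun(step, latents)`. A standalone sketch of that pattern, with a scalar stand-in for the latents:

```python
from keras import ops

def body_fun(step, value):
    # Stand-in for one denoising step: fold the step index into the state.
    step = ops.cast(step, "float32")
    return value + 1.0 / (step + 1.0)

# fori_loop(lower, upper, body_fun, init_val) runs body_fun for
# step = 0, 1, ..., upper - 1, carrying `value` between iterations.
state = ops.fori_loop(0, 4, body_fun, ops.convert_to_tensor(0.0))
# state == 1/1 + 1/2 + 1/3 + 1/4, roughly 2.083
```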
keras_hub/src/models/flux/flux_text_to_image_preprocessor.py
@@ -0,0 +1,73 @@
+import keras
+from keras import layers
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.flux.flux_model import FluxBackbone
+from keras_hub.src.models.preprocessor import Preprocessor
+
+
+@keras_hub_export("keras_hub.models.FluxTextToImagePreprocessor")
+class FluxTextToImagePreprocessor(Preprocessor):
+    """Flux text-to-image model preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_hub.models.FluxTextToImage`.
+
+    For use with generation, the layer exposes one method,
+    `generate_preprocess()`.
+
+    Args:
+        clip_l_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance.
+        t5_preprocessor: An optional `keras_hub.models.T5Preprocessor`
+            instance.
+    """
+
+    backbone_cls = FluxBackbone
+
+    def __init__(
+        self,
+        clip_l_preprocessor,
+        t5_preprocessor=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.clip_l_preprocessor = clip_l_preprocessor
+        self.t5_preprocessor = t5_preprocessor
+
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self.clip_l_preprocessor.sequence_length
+
+    def build(self, input_shape):
+        self.built = True
+
+    def generate_preprocess(self, x):
+        token_ids = {}
+        token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
+        if self.t5_preprocessor is not None:
+            token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
+        return token_ids
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "clip_l_preprocessor": layers.serialize(
+                    self.clip_l_preprocessor
+                ),
+                "t5_preprocessor": layers.serialize(self.t5_preprocessor),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        for layer_name in (
+            "clip_l_preprocessor",
+            "t5_preprocessor",
+        ):
+            if layer_name in config and isinstance(config[layer_name], dict):
+                config[layer_name] = keras.layers.deserialize(
+                    config[layer_name]
+                )
+        return cls(**config)
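A hypothetical usage sketch of the preprocessor contract: `generate_preprocess()` returns a dict of token-id tensors keyed per text encoder. The preset name below is a placeholder, since this release ships no Flux presets:

```python
import keras_hub

# Placeholder preset name -- assumes a CLIP preprocessor is available.
clip_l = keras_hub.models.CLIPPreprocessor.from_preset("<clip_preset>")
preprocessor = keras_hub.models.FluxTextToImagePreprocessor(
    clip_l_preprocessor=clip_l,
)
token_ids = preprocessor.generate_preprocess(["cute wallpaper art of a cat"])
# {"clip_l": <(1, sequence_length) token ids>}, plus a "t5" entry when a
# t5_preprocessor is configured.
```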