keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/inpaint.py
@@ -0,0 +1,520 @@
+ import itertools
+ from functools import partial
+
+ import keras
+ from keras import ops
+ from keras import random
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.task import Task
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+ try:
+     import tensorflow as tf
+ except ImportError:
+     tf = None
+
+
+ @keras_hub_export("keras_hub.models.Inpaint")
+ class Inpaint(Task):
+     """Base class for inpainting tasks.
+
+     `Inpaint` tasks wrap a `keras_hub.models.Backbone` and
+     a `keras_hub.models.Preprocessor` to create a model that can be used for
+     generation and generative fine-tuning.
+
+     `Inpaint` tasks provide an additional, high-level `generate()` function
+     which can be used to generate images with an (image, mask, string) in,
+     image out signature.
+
+     All `Inpaint` tasks include a `from_preset()` constructor which can be
+     used to load a pre-trained config and weights.
+
+     Example:
+
+     ```python
+     # Load a Stable Diffusion 3 backbone with pre-trained weights.
+     reference_image = np.ones((1024, 1024, 3), dtype="float32")
+     reference_mask = np.ones((1024, 1024), dtype="float32")
+     inpaint = keras_hub.models.Inpaint.from_preset(
+         "stable_diffusion_3_medium",
+     )
+     inpaint.generate(
+         {
+             "images": reference_image,
+             "masks": reference_mask,
+             "prompts": "Astronaut in a jungle, cold color palette, "
+             "muted colors, detailed, 8k",
+         }
+     )
+
+     # Load a Stable Diffusion 3 backbone at bfloat16 precision.
+     inpaint = keras_hub.models.Inpaint.from_preset(
+         "stable_diffusion_3_medium",
+         dtype="bfloat16",
+     )
+     inpaint.generate(
+         {
+             "images": reference_image,
+             "masks": reference_mask,
+             "prompts": "Astronaut in a jungle, cold color palette, "
+             "muted colors, detailed, 8k",
+         }
+     )
+     ```
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         # Default compilation.
+         self.compile()
+
+     @property
+     def support_negative_prompts(self):
+         """Whether the model supports `negative_prompts` key in `generate()`."""
+         return True
+
+     @property
+     def image_shape(self):
+         return tuple(self.backbone.image_shape)
+
+     @property
+     def latent_shape(self):
+         return tuple(self.backbone.latent_shape)
+
+     def compile(
+         self,
+         optimizer="auto",
+         loss="auto",
+         *,
+         metrics="auto",
+         **kwargs,
+     ):
+         """Configures the `Inpaint` task for training.
+
+         The `Inpaint` task extends the default compilation signature of
+         `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+         `metrics`. To override these defaults, pass any value to these
+         arguments during compilation.
+
+         Args:
+             optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                 instance. Defaults to `"auto"`, which uses the default
+                 optimizer for the given model and task. See
+                 `keras.Model.compile` and `keras.optimizers` for more info on
+                 possible `optimizer` values.
+             loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                 Defaults to `"auto"`, where a
+                 `keras.losses.MeanSquaredError` loss will be applied. See
+                 `keras.Model.compile` and `keras.losses` for more info on
+                 possible `loss` values.
+             metrics: `"auto"`, or a list of metrics to be evaluated by
+                 the model during training and testing. Defaults to `"auto"`,
+                 where a `keras.metrics.MeanSquaredError` will be applied to
+                 track the loss of the model during training. See
+                 `keras.Model.compile` and `keras.metrics` for more info on
+                 possible `metrics` values.
+             **kwargs: See `keras.Model.compile` for a full list of arguments
+                 supported by the compile method.
+         """
+         # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414
+         if optimizer == "auto":
+             optimizer = keras.optimizers.AdamW(
+                 1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0
+             )
+         if loss == "auto":
+             loss = keras.losses.MeanSquaredError()
+         if metrics == "auto":
+             metrics = [keras.metrics.MeanSquaredError()]
+         super().compile(
+             optimizer=optimizer,
+             loss=loss,
+             metrics=metrics,
+             **kwargs,
+         )
+         self.generate_function = None
+
+     def generate_step(self, *args, **kwargs):
+         """Run generation on batches of input."""
+         raise NotImplementedError
+
+     def make_generate_function(self):
+         """Create or return the compiled generation function."""
+         if self.generate_function is not None:
+             return self.generate_function
+
+         self.generate_function = self.generate_step
+         if keras.config.backend() == "torch":
+             import torch
+
+             def wrapped_function(*args, **kwargs):
+                 with torch.no_grad():
+                     return self.generate_step(*args, **kwargs)
+
+             self.generate_function = wrapped_function
+         elif keras.config.backend() == "tensorflow" and not self.run_eagerly:
+             self.generate_function = tf.function(
+                 self.generate_step, jit_compile=self.jit_compile
+             )
+         elif keras.config.backend() == "jax" and not self.run_eagerly:
+             import jax
+
+             @partial(jax.jit)
+             def compiled_function(state, *args, **kwargs):
+                 (
+                     trainable_variables,
+                     non_trainable_variables,
+                 ) = state
+                 mapping = itertools.chain(
+                     zip(self.trainable_variables, trainable_variables),
+                     zip(self.non_trainable_variables, non_trainable_variables),
+                 )
+
+                 with keras.StatelessScope(state_mapping=mapping):
+                     outputs = self.generate_step(*args, **kwargs)
+                 return outputs
+
+             def wrapped_function(*args, **kwargs):
+                 # Create an explicit tuple of all variable state.
+                 state = (
+                     # Use the explicit variable.value to preserve the
+                     # sharding spec of distribution.
+                     [v.value for v in self.trainable_variables],
+                     [v.value for v in self.non_trainable_variables],
+                 )
+                 outputs = compiled_function(state, *args, **kwargs)
+                 return outputs
+
+             self.generate_function = wrapped_function
+         return self.generate_function
+
+     def _normalize_generate_images(self, inputs):
+         """Normalize user images to the generate function.
+
+         This function converts all inputs to tensors, adds a batch dimension
+         if necessary, and returns an iterable "dataset like" object (either
+         an actual `tf.data.Dataset` or a list with a single batch element).
+         """
+         if tf and isinstance(inputs, tf.data.Dataset):
+             return inputs.as_numpy_iterator(), False
+
+         def normalize(x):
+             data_format = getattr(
+                 self.backbone, "data_format", standardize_data_format(None)
+             )
+             input_is_scalar = False
+             x = ops.convert_to_tensor(x)
+             if len(ops.shape(x)) < 4:
+                 x = ops.expand_dims(x, axis=0)
+                 input_is_scalar = True
+             x = ops.image.resize(
+                 x,
+                 (self.backbone.image_shape[0], self.backbone.image_shape[1]),
+                 interpolation="nearest",
+                 data_format=data_format,
+             )
+             return x, input_is_scalar
+
+         if isinstance(inputs, dict):
+             for key in inputs:
+                 inputs[key], input_is_scalar = normalize(inputs[key])
+         else:
+             inputs, input_is_scalar = normalize(inputs)
+
+         return inputs, input_is_scalar
+
+     def _normalize_generate_masks(self, inputs):
+         """Normalize user masks to the generate function.
+
+         This function converts all inputs to tensors, adds a batch dimension
+         if necessary, and returns an iterable "dataset like" object (either
+         an actual `tf.data.Dataset` or a list with a single batch element).
+         """
+         if tf and isinstance(inputs, tf.data.Dataset):
+             return inputs.as_numpy_iterator(), False
+
+         def normalize(x):
+             data_format = getattr(
+                 self.backbone, "data_format", standardize_data_format(None)
+             )
+             input_is_scalar = False
+             x = ops.convert_to_tensor(x)
+             if len(ops.shape(x)) < 3:
+                 x = ops.expand_dims(x, axis=0)
+                 input_is_scalar = True
+             x = ops.expand_dims(x, axis=-1)
+             if keras.backend.standardize_dtype(x.dtype) == "bool":
+                 x = ops.cast(x, "float32")
+             x = ops.image.resize(
+                 x,
+                 (self.backbone.image_shape[0], self.backbone.image_shape[1]),
+                 interpolation="nearest",
+                 data_format=data_format,
+             )
+             x = ops.squeeze(x, axis=-1)
+             return x, input_is_scalar
+
+         if isinstance(inputs, dict):
+             for key in inputs:
+                 inputs[key], input_is_scalar = normalize(inputs[key])
+         else:
+             inputs, input_is_scalar = normalize(inputs)
+
+         return inputs, input_is_scalar
+
+     def _normalize_generate_inputs(self, inputs):
+         """Normalize user input to the generate function.
+
+         This function converts all inputs to tensors, adds a batch dimension
+         if necessary, and returns an iterable "dataset like" object (either
+         an actual `tf.data.Dataset` or a list with a single batch element).
+
+         The input format must be one of the following:
+         - A dict with "images", "masks", "prompts" and/or "negative_prompts"
+             keys
+         - A tf.data.Dataset with "images", "masks", "prompts" and/or
+             "negative_prompts" keys
+
+         The output will be a dict with "images", "masks", "prompts" and/or
+         "negative_prompts" keys.
+         """
+         if tf and isinstance(inputs, tf.data.Dataset):
+             _inputs = {
+                 "images": inputs.map(
+                     lambda x: x["images"]
+                 ).as_numpy_iterator(),
+                 "masks": inputs.map(lambda x: x["masks"]).as_numpy_iterator(),
+                 "prompts": inputs.map(
+                     lambda x: x["prompts"]
+                 ).as_numpy_iterator(),
+             }
+             if self.support_negative_prompts:
+                 _inputs["negative_prompts"] = inputs.map(
+                     lambda x: x["negative_prompts"]
+                 ).as_numpy_iterator()
+             return _inputs, False
+
+         def normalize(x):
+             if isinstance(x, str):
+                 return [x], True
+             if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
+                 return x[tf.newaxis], True
+             return x, False
+
+         def normalize_images(x):
+             data_format = getattr(
+                 self.backbone, "data_format", standardize_data_format(None)
+             )
+             input_is_scalar = False
+             x = ops.convert_to_tensor(x)
+             if len(ops.shape(x)) < 4:
+                 x = ops.expand_dims(x, axis=0)
+                 input_is_scalar = True
+             x = ops.image.resize(
+                 x,
+                 (self.backbone.image_shape[0], self.backbone.image_shape[1]),
+                 interpolation="nearest",
+                 data_format=data_format,
+             )
+             return x, input_is_scalar
+
+         def normalize_masks(x):
+             data_format = getattr(
+                 self.backbone, "data_format", standardize_data_format(None)
+             )
+             input_is_scalar = False
+             x = ops.convert_to_tensor(x)
+             if len(ops.shape(x)) < 3:
+                 x = ops.expand_dims(x, axis=0)
+                 input_is_scalar = True
+             x = ops.expand_dims(x, axis=-1)
+             if keras.backend.standardize_dtype(x.dtype) == "bool":
+                 x = ops.cast(x, "float32")
+             x = ops.image.resize(
+                 x,
+                 (self.backbone.image_shape[0], self.backbone.image_shape[1]),
+                 interpolation="nearest",
+                 data_format=data_format,
+             )
+             x = ops.squeeze(x, axis=-1)
+             return x, input_is_scalar
+
+         def get_dummy_prompts(x):
+             dummy_prompts = [""] * len(x)
+             if tf and isinstance(x, tf.Tensor):
+                 return tf.convert_to_tensor(dummy_prompts)
+             else:
+                 return dummy_prompts
+
+         for key in inputs:
+             if key == "images":
+                 inputs[key], input_is_scalar = normalize_images(inputs[key])
+             elif key == "masks":
+                 inputs[key], input_is_scalar = normalize_masks(inputs[key])
+             else:
+                 inputs[key], input_is_scalar = normalize(inputs[key])
+
+         if self.support_negative_prompts and "negative_prompts" not in inputs:
+             inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"])
+
+         return [inputs], input_is_scalar
+
+     def _normalize_generate_outputs(self, outputs, input_is_scalar):
+         """Normalize user output from the generate function.
+
+         This function converts all output to numpy with a value range of
+         `[0, 255]`. If a batch dimension was added to the input, it is
+         removed from the output.
+         """
+
+         def normalize(x):
+             outputs = ops.concatenate(x, axis=0)
+             outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0)
+             outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8")
+             outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs
+             return ops.convert_to_numpy(outputs)
+
+         if isinstance(outputs[0], dict):
+             normalized = {}
+             for key in outputs[0]:
+                 normalized[key] = normalize([x[key] for x in outputs])
+             return normalized
+         return normalize([x for x in outputs])
+
+     def generate(
+         self,
+         inputs,
+         num_steps,
+         strength,
+         guidance_scale=None,
+         seed=None,
+     ):
+         """Generate images based on the provided `inputs`.
+
+         Typically, `inputs` is a dict with `"images"`, `"masks"` and
+         `"prompts"` keys. `"images"` are reference images within a value
+         range of `[-1.0, 1.0]`, which will be resized to the height and width
+         from `self.backbone.image_shape`, then encoded into latent space by
+         the VAE encoder. `"masks"` are mask images with a boolean dtype,
+         where white pixels are repainted while black pixels are preserved.
+         `"prompts"` are strings that will be tokenized and encoded by the
+         text encoder.
+
+         Some models support a `"negative_prompts"` key, which helps steer the
+         model away from generating certain styles and elements. To enable
+         this, add `"negative_prompts"` to the input dict.
+
+         If `inputs` are a `tf.data.Dataset`, outputs will be generated
+         "batch-by-batch" and concatenated. Otherwise, all inputs will be
+         processed as batches.
+
+         Args:
+             inputs: python data, tensor data, or a `tf.data.Dataset`. The
+                 format must be one of the following:
+                 - A dict with `"images"`, `"masks"`, `"prompts"` and/or
+                     `"negative_prompts"` keys.
+                 - A `tf.data.Dataset` with `"images"`, `"masks"`, `"prompts"`
+                     and/or `"negative_prompts"` keys.
+             num_steps: int. The number of diffusion steps to take.
+             strength: float. Indicates the extent to which the reference
+                 `images` are transformed. Must be between `0.0` and `1.0`.
+                 When `strength=1.0`, `images` are essentially ignored, the
+                 added noise is maximum, and the denoising process runs for
+                 the full number of iterations specified in `num_steps`.
+             guidance_scale: Optional float. The classifier free guidance
+                 scale defined in [Classifier-Free Diffusion Guidance](
+                 https://arxiv.org/abs/2207.12598). A higher scale encourages
+                 generating images more closely related to the prompts,
+                 typically at the cost of lower image quality. Note that some
+                 models don't utilize classifier-free guidance.
+             seed: optional int. Used as a random seed.
+         """
+         num_steps = int(num_steps)
+         strength = float(strength)
+         guidance_scale = (
+             float(guidance_scale) if guidance_scale is not None else None
+         )
+         if strength < 0.0 or strength > 1.0:
+             raise ValueError(
+                 "`strength` must be between `0.0` and `1.0`. "
+                 f"Received strength={strength}."
+             )
+         if guidance_scale is not None and guidance_scale > 1.0:
+             guidance_scale = ops.convert_to_tensor(guidance_scale)
+         else:
+             guidance_scale = None
+         starting_step = int(num_steps * (1.0 - strength))
+         starting_step = ops.convert_to_tensor(starting_step, "int32")
+         num_steps = ops.convert_to_tensor(num_steps, "int32")
+
+         # Check `inputs` format.
+         required_keys = ["images", "masks", "prompts"]
+         if tf and isinstance(inputs, tf.data.Dataset):
+             spec = inputs.element_spec
+             if not all(key in spec for key in required_keys):
+                 raise ValueError(
+                     "Expected a `tf.data.Dataset` with the following keys: "
+                     f"{required_keys}. Received: inputs.element_spec={spec}"
+                 )
+         else:
+             if not isinstance(inputs, dict):
+                 raise ValueError(
+                     "Expected a `dict` or `tf.data.Dataset`. "
+                     f"Received: inputs={inputs} of type {type(inputs)}."
+                 )
+             if not all(key in inputs for key in required_keys):
+                 raise ValueError(
+                     "Expected a `dict` with the following keys: "
+                     f"{required_keys}. "
+                     f"Received: inputs.keys={list(inputs.keys())}"
+                 )
+
+         # Setup our three main passes.
+         # 1. Preprocessing strings to dense integer tensors.
+         # 2. Generate outputs via a compiled function on dense tensors.
+         # 3. Postprocess dense tensors to a value range of `[0, 255]`.
+         generate_function = self.make_generate_function()
+
+         def preprocess(x):
+             if self.preprocessor is not None:
+                 return self.preprocessor.generate_preprocess(x)
+             else:
+                 return x
+
+         def generate(images, masks, x):
+             token_ids = x[0] if self.support_negative_prompts else x
+
+             # Initialize noises.
+             if isinstance(token_ids, dict):
+                 arbitrary_key = list(token_ids.keys())[0]
+                 batch_size = ops.shape(token_ids[arbitrary_key])[0]
+             else:
+                 batch_size = ops.shape(token_ids)[0]
+             noise_shape = (batch_size,) + self.latent_shape[1:]
+             noises = random.normal(noise_shape, dtype="float32", seed=seed)
+
+             return generate_function(
+                 images,
+                 masks,
+                 noises,
+                 x,
+                 starting_step,
+                 num_steps,
+                 guidance_scale,
+             )
+
+         # Normalize and preprocess inputs.
+         inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
+         if self.support_negative_prompts:
+             images = [x["images"] for x in inputs]
+             masks = [x["masks"] for x in inputs]
+             token_ids = [preprocess(x["prompts"]) for x in inputs]
+             negative_token_ids = [
+                 preprocess(x["negative_prompts"]) for x in inputs
+             ]
+             # Tuple format: (images, masks, (token_ids, negative_token_ids)).
+             inputs = [
+                 x
+                 for x in zip(images, masks, zip(token_ids, negative_token_ids))
+             ]
+         else:
+             images = [x["images"] for x in inputs]
+             masks = [x["masks"] for x in inputs]
+             token_ids = [preprocess(x["prompts"]) for x in inputs]
+             # Tuple format: (images, masks, token_ids).
+             inputs = [x for x in zip(images, masks, token_ids)]
+
+         # Inpaint.
+         outputs = [generate(*x) for x in inputs]
+         return self._normalize_generate_outputs(outputs, input_is_scalar)
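
For reference, the dict calling convention that `generate()` validates above is easier to see in a short, self-contained sketch than in the method body. This is a minimal illustration, not part of the package diff: the preset name comes from the `Inpaint` docstring, while the shapes, prompt strings, and `num_steps`/`strength` values are arbitrary placeholders.

```python
import numpy as np
import keras_hub

# Load an inpainting task from a preset (weights download on first use).
inpaint = keras_hub.models.Inpaint.from_preset("stable_diffusion_3_medium")

# A reference image in [-1.0, 1.0] and a boolean mask, per the docstring.
# Both are unbatched; `_normalize_generate_inputs` adds the batch dimension
# and resizes to `backbone.image_shape`.
images = np.random.uniform(-1.0, 1.0, size=(512, 512, 3)).astype("float32")
masks = np.zeros((512, 512), dtype="bool")
masks[128:384, 128:384] = True  # True (white) pixels get repainted.

output = inpaint.generate(
    {
        "images": images,
        "masks": masks,
        "prompts": "a red sofa in a sunlit room",
        # Optional; a dummy "" is injected when the model supports negative
        # prompts but none are given.
        "negative_prompts": "blurry, low quality",
    },
    num_steps=30,
    strength=0.7,  # Skips the first 30% of steps, re-noising the rest.
)
# `output` is a uint8 numpy array in [0, 255]; the batch dimension is
# squeezed again because the inputs were unbatched.
```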
keras_hub/src/models/llama/llama_backbone.py
@@ -34,17 +34,18 @@ class LlamaBackbone(Backbone):
      num_layers (int): The number of transformer layers.
      num_query_heads (int): The number of query attention heads for
          each transformer.
-     hidden_dim (int): The size of the transformer encoding and pooling layers.
-     intermediate_dim (int): The output dimension of the first Dense layer in a
-         three-layer feedforward network for each transformer.
-     num_key_value_heads (int): The number of key and value attention heads for
-         each transformer.
-     rope_max_wavelength (int, optional): The maximum angular wavelength of the
-         sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-     rope_scaling_factor (float, optional): The scaling factor for calculation
-         of roatary embedding. Defaults to `1.0`.
-     layer_norm_epsilon (float, optional): Epsilon for the layer normalization
-         layers in the transformer decoder. Defaults to `1e-6`.
+     hidden_dim (int): The size of the transformer encoding and pooling
+         layers.
+     intermediate_dim (int): The output dimension of the first Dense layer in
+         a three-layer feedforward network for each transformer.
+     num_key_value_heads (int): The number of key and value attention heads
+         for each transformer.
+     rope_max_wavelength (int, optional): The maximum angular wavelength of
+         the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+     rope_scaling_factor (float, optional): The scaling factor for
+         calculation of rotary embedding. Defaults to `1.0`.
+     layer_norm_epsilon (float, optional): Epsilon for the layer
+         normalization layers in the transformer decoder. Defaults to `1e-6`.
      dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
          for model computations and weights. Note that some computations,
          such as softmax and layer normalization, will always be done at
@@ -59,7 +60,7 @@ class LlamaBackbone(Backbone):
      }

      # Pretrained Llama decoder.
-     model = keras_hub.models.LlamaBackbone.from_preset("llama7b_base_en")
+     model = keras_hub.models.LlamaBackbone.from_preset("llama2_7b_en")
      model(input_data)

      # Randomly initialized Llama decoder with custom config.
@@ -175,3 +176,128 @@ class LlamaBackbone(Backbone):
          }
      )
      return config
+
+     @staticmethod
+     def get_layout_map(
+         device_mesh,
+         model_parallel_dim_name="model",
+         data_parallel_dim_name="batch",
+     ):
+         """Get a `keras.distribution.LayoutMap` for model parallel distribution.
+
+         The returned `LayoutMap` contains the sharding spec for the Llama
+         backbone weights, so that you can use it to distribute weights across
+         the accelerators.
+
+         Example:
+         ```
+         # Feel free to change the mesh shape to balance data and model
+         # parallelism.
+         mesh = keras.distribution.DeviceMesh(
+             shape=(1, 8),
+             axis_names=('batch', 'model'),
+             devices=keras.distribution.list_devices(),
+         )
+         layout_map = LlamaBackbone.get_layout_map(
+             mesh,
+             model_parallel_dim_name="model",
+         )
+
+         distribution = keras.distribution.ModelParallel(
+             layout_map=layout_map,
+             batch_dim_name='batch',
+         )
+
+         with distribution.scope():
+             llama_model = keras_hub.models.LlamaCausalLM.from_preset(
+                 "llama2_7b_en"
+             )
+         ```
+
+         To see how the layout map was applied, load the model then run
+         (for one decoder block):
+         ```
+         embedding_layer = llama_model.backbone.get_layer("token_embedding")
+         decoder_block_1 = llama_model.backbone.get_layer('transformer_layer_0')
+         for variable in embedding_layer.weights + decoder_block_1.weights:
+             print(
+                 f'{variable.path:<58} {str(variable.shape):<16} '
+                 f'{str(variable.value.sharding.spec)}'
+             )
+         ```
+
+         Args:
+             device_mesh: The `keras.distribution.DeviceMesh` instance for
+                 distribution.
+             model_parallel_dim_name: The axis name of the device mesh that
+                 the weights should be partitioned on.
+             data_parallel_dim_name: The axis name of the device mesh that
+                 the data should be partitioned on.
+         Returns:
+             `keras.distribution.LayoutMap` that contains the sharding spec
+             for all the model weights.
+         """
+         # The weight paths and shapes of the Llama backbone are as follows:
+         # token_embedding/embeddings (128256, 2048)
+         # repeat block for decoder
+         # transformer_layer_0/self_attention/query/kernel (2048, 32, 64)
+         # transformer_layer_0/self_attention/key/kernel (2048, 8, 64)
+         # transformer_layer_0/self_attention/value/kernel (2048, 8, 64)
+         # transformer_layer_0/self_attention/attention_output/kernel
+         #     (32, 64, 2048)
+         # transformer_layer_0/self_attention_layernorm/scale (2048,)
+         # transformer_layer_0/feedforward_intermediate_dense/kernel
+         #     (2048, 8192)
+         # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192)
+         # transformer_layer_0/feedforward_output_dense/kernel (8192, 2048)
+         # transformer_layer_0/feedforward_layernorm/scale (2048,)
+
+         if not isinstance(device_mesh, keras.distribution.DeviceMesh):
+             raise ValueError(
+                 "Invalid device_mesh type. Expected "
+                 f"`keras.distribution.DeviceMesh`, got {type(device_mesh)}"
+             )
+         if model_parallel_dim_name not in device_mesh.axis_names:
+             raise ValueError(
+                 f"{model_parallel_dim_name} is not found in the "
+                 f"device_mesh.axis_names. {device_mesh.axis_names=}"
+             )
+         if data_parallel_dim_name not in device_mesh.axis_names:
+             raise ValueError(
+                 f"{data_parallel_dim_name} is not found in the "
+                 f"device_mesh.axis_names. {device_mesh.axis_names=}"
+             )
+         # Note that it is possible to further configure the mesh to be 3D,
+         # e.g. (data, seq, model). We leave it as 2D for now for simplicity.
+         data_dim = data_parallel_dim_name
+         model_dim = model_parallel_dim_name
+         # The sharding config is based on the Gemma team training config.
+         # See https://arxiv.org/abs/2403.08295
+         layout_map = keras.distribution.LayoutMap(device_mesh)
+         layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
+         layout_map[
+             "transformer_layer.*self_attention.*(query|key|value).kernel"
+         ] = (
+             model_dim,
+             data_dim,
+             None,
+         )
+         layout_map["transformer_layer.*attention_output.kernel"] = (
+             model_dim,
+             None,
+             data_dim,
+         )
+         layout_map[
+             "transformer_layer.*feedforward_intermediate_dense.kernel"
+         ] = (
+             data_dim,
+             model_dim,
+         )
+         layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = (
+             data_dim,
+             model_dim,
+         )
+         layout_map["transformer_layer.*feedforward_output_dense.kernel"] = (
+             model_dim,
+             data_dim,
+         )
+
+         return layout_map
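
For reference, a minimal sketch (again, not part of the package diff) of how the returned layout map is consumed. `LayoutMap` keys are regexes matched against variable paths, which is why the entries above use patterns like `transformer_layer.*attention_output.kernel`. The sketch assumes a JAX backend with eight visible accelerators and a recent Keras 3 `ModelParallel` API; the preset name is taken from the docstring earlier in this diff.

```python
import keras
import keras_hub

# A 1x8 mesh: no data parallelism, 8-way model parallelism.
devices = keras.distribution.list_devices()  # Assumes 8 accelerators.
mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=("batch", "model"), devices=devices
)
layout_map = keras_hub.models.LlamaBackbone.get_layout_map(mesh)

# Regex lookup: this concrete path matches the "(query|key|value).kernel"
# pattern registered above, resolving to the ("model", "batch", None) layout.
print(layout_map["transformer_layer_5/self_attention/query/kernel"])

distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch"
)
with distribution.scope():
    # Weights are created (and loaded) directly with the requested sharding.
    backbone = keras_hub.models.LlamaBackbone.from_preset("llama2_7b_en")
```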