keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297):
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,186 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
17
+ from keras_hub.src.utils.preset_utils import get_file
18
+ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
19
+ from keras_hub.src.utils.preset_utils import load_config
20
+ from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
21
+
22
+
23
def convert_backbone_config(transformers_config):
    """Translate a Hugging Face GPT-2 `config.json` dict into backbone kwargs.

    Args:
        transformers_config: The parsed Hugging Face GPT-2 config. Must
            contain `vocab_size`, `n_layer`, `n_head`, `n_embd`,
            `resid_pdrop` and `n_positions`; `n_inner` is optional.

    Returns:
        A dict of keyword arguments for the KerasHub GPT-2 backbone.
    """
    hidden_dim = transformers_config["n_embd"]
    # `n_inner` is optional in GPT-2 configs. When it is missing or null,
    # the feed-forward width is 4 * n_embd; otherwise it must be honored
    # or the MLP weights will not fit.
    intermediate_dim = transformers_config.get("n_inner")
    if intermediate_dim is None:
        intermediate_dim = 4 * hidden_dim
    return {
        "vocabulary_size": transformers_config["vocab_size"],
        "num_layers": transformers_config["n_layer"],
        "num_heads": transformers_config["n_head"],
        "hidden_dim": hidden_dim,
        "intermediate_dim": intermediate_dim,
        "dropout": transformers_config["resid_pdrop"],
        "max_sequence_length": transformers_config["n_positions"],
    }
33
+
34
+
35
def convert_weights(backbone, loader, transformers_config):
    """Copy Hugging Face GPT-2 safetensor weights into a KerasHub backbone.

    Args:
        backbone: The KerasHub GPT-2 backbone whose variables are assigned.
        loader: An object exposing
            `port_weight(keras_variable, hf_weight_key, hook_fn=None)`,
            e.g. a `SafetensorLoader`.
        transformers_config: The Hugging Face GPT-2 config; only `n_embd`
            is read here, to split the fused Q/K/V projection.

    Returns:
        The same `backbone`, with its weights populated.
    """
    # Loop-invariant: read once rather than once per layer.
    n_embd = transformers_config["n_embd"]

    def reshape_to_keras(hf_tensor, keras_shape):
        # Same data, laid out to the Keras variable's (possibly per-head)
        # shape.
        return np.reshape(hf_tensor, keras_shape)

    def fused_qkv_slice(start, stop):
        # HF packs Q, K and V side by side along the last axis of
        # `c_attn` (weight and bias). The slicing assumes the weight is
        # stored (in, out), as the original `[:, a:b]` slices did.
        # Bounds are captured as arguments, so every hook keeps its own
        # slice.
        def hook_fn(hf_tensor, keras_shape):
            return np.reshape(hf_tensor[..., start:stop], keras_shape)

        return hook_fn

    # Embeddings.
    loader.port_weight(
        keras_variable=backbone.token_embedding.embeddings,
        hf_weight_key="wte.weight",
    )
    loader.port_weight(
        keras_variable=backbone.position_embedding.position_embeddings,
        hf_weight_key="wpe.weight",
    )

    # Transformer blocks.
    for index in range(backbone.num_layers):
        decoder_layer = backbone.transformer_layers[index]
        prefix = f"h.{index}"

        # Layer norms: ln_1 precedes attention, ln_2 precedes the MLP.
        for keras_norm, hf_norm in (
            (decoder_layer._self_attention_layer_norm, "ln_1"),
            (decoder_layer._feedforward_layer_norm, "ln_2"),
        ):
            loader.port_weight(
                keras_variable=keras_norm.gamma,
                hf_weight_key=f"{prefix}.{hf_norm}.weight",
            )
            loader.port_weight(
                keras_variable=keras_norm.beta,
                hf_weight_key=f"{prefix}.{hf_norm}.bias",
            )

        # Query / key / value, sliced out of the fused `c_attn` tensors.
        attention = decoder_layer._self_attention_layer
        for position, dense in enumerate(
            (attention.query_dense, attention.key_dense, attention.value_dense)
        ):
            hook_fn = fused_qkv_slice(position * n_embd, (position + 1) * n_embd)
            loader.port_weight(
                keras_variable=dense.kernel,
                hf_weight_key=f"{prefix}.attn.c_attn.weight",
                hook_fn=hook_fn,
            )
            loader.port_weight(
                keras_variable=dense.bias,
                hf_weight_key=f"{prefix}.attn.c_attn.bias",
                hook_fn=hook_fn,
            )

        # Attention output projection.
        loader.port_weight(
            keras_variable=attention.output_dense.kernel,
            hf_weight_key=f"{prefix}.attn.c_proj.weight",
            hook_fn=reshape_to_keras,
        )
        loader.port_weight(
            keras_variable=attention.output_dense.bias,
            hf_weight_key=f"{prefix}.attn.c_proj.bias",
        )

        # MLP.
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
            hf_weight_key=f"{prefix}.mlp.c_fc.weight",
            hook_fn=reshape_to_keras,
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_intermediate_dense.bias,
            hf_weight_key=f"{prefix}.mlp.c_fc.bias",
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_output_dense.kernel,
            hf_weight_key=f"{prefix}.mlp.c_proj.weight",
            hook_fn=reshape_to_keras,
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_output_dense.bias,
            hf_weight_key=f"{prefix}.mlp.c_proj.bias",
        )

    # Final layer norm.
    loader.port_weight(
        keras_variable=backbone.layer_norm.gamma,
        hf_weight_key="ln_f.weight",
    )
    loader.port_weight(
        keras_variable=backbone.layer_norm.beta,
        hf_weight_key="ln_f.bias",
    )

    return backbone
167
+
168
+
169
def load_gpt2_backbone(cls, preset, load_weights):
    """Instantiate a GPT-2 backbone from a Hugging Face preset.

    Args:
        cls: The backbone class, constructed from the converted config.
        preset: Preset handle forwarded to `load_config` and
            `SafetensorLoader`.
        load_weights: When falsy, the backbone is returned without reading
            any safetensor weights.

    Returns:
        The constructed backbone, with weights ported when requested.
    """
    hf_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(hf_config))
    if not load_weights:
        return backbone
    jax_memory_cleanup(backbone)
    with SafetensorLoader(preset) as loader:
        convert_weights(backbone, loader, hf_config)
    return backbone
178
+
179
+
180
def load_gpt2_tokenizer(cls, preset):
    """Build a GPT-2 tokenizer from the preset's `vocab.json` and `merges.txt`.

    Args:
        cls: The tokenizer class to instantiate.
        preset: Preset handle forwarded to `get_file`.

    Returns:
        A `cls` instance built from the fetched vocabulary and merges files.
    """
    tokenizer_files = {
        "vocabulary": get_file(preset, "vocab.json"),
        "merges": get_file(preset, "merges.txt"),
    }
    return cls(**tokenizer_files)
@@ -0,0 +1,136 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
17
+ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
18
+ from keras_hub.src.utils.preset_utils import load_config
19
+ from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
20
+
21
+
22
def convert_backbone_config(transformers_config):
    """Translate a Hugging Face Llama 3 `config.json` dict into backbone kwargs.

    Args:
        transformers_config: The parsed Hugging Face Llama config. Must
            contain `vocab_size`, `num_hidden_layers`, `num_attention_heads`,
            `hidden_size`, `intermediate_size` and `num_key_value_heads`.
            `rope_theta` and `rms_norm_eps` are forwarded when present.

    Returns:
        A dict of keyword arguments for the KerasHub Llama 3 backbone.
    """
    keras_config = {
        "vocabulary_size": transformers_config["vocab_size"],
        "num_layers": transformers_config["num_hidden_layers"],
        "num_query_heads": transformers_config["num_attention_heads"],
        "hidden_dim": transformers_config["hidden_size"],
        "intermediate_dim": transformers_config["intermediate_size"],
        "num_key_value_heads": transformers_config["num_key_value_heads"],
    }
    # Llama 3 checkpoints use a RoPE base and RMSNorm epsilon that differ
    # from the backbone's defaults. Dropping them silently changes the
    # model's outputs, so forward them whenever the checkpoint states them.
    if "rope_theta" in transformers_config:
        keras_config["rope_max_wavelength"] = transformers_config["rope_theta"]
    if "rms_norm_eps" in transformers_config:
        keras_config["layer_norm_epsilon"] = transformers_config["rms_norm_eps"]
    return keras_config
31
+
32
+
33
def convert_weights(backbone, loader, transformers_config):
    """Copy Hugging Face Llama 3 safetensor weights into a KerasHub backbone.

    Args:
        backbone: The KerasHub Llama 3 backbone. Its sub-layers are fetched
            by name via `get_layer`.
        loader: An object exposing
            `port_weight(keras_variable, hf_weight_key, hook_fn=None)`,
            e.g. a `SafetensorLoader`.
        transformers_config: The Hugging Face config (unused here; kept for
            a uniform converter signature).

    Returns:
        The same `backbone`, with its weights populated.
    """

    def swap_axes(hf_tensor, _):
        # Flip a 2-D weight's two axes to the layout the Keras variable
        # expects.
        return np.transpose(hf_tensor, axes=(1, 0))

    def transpose_and_reshape(x, shape):
        # Transpose, then split the fused head axis to the Keras kernel shape.
        return np.reshape(np.transpose(x), shape)

    embedding_layer = backbone.get_layer("token_embedding")
    loader.port_weight(
        keras_variable=embedding_layer.embeddings,
        hf_weight_key="model.embed_tokens.weight",
    )
    loader.port_weight(
        keras_variable=embedding_layer.reverse_embeddings,
        hf_weight_key="lm_head.weight",
        hook_fn=swap_axes,
    )

    for layer_index in range(backbone.num_layers):
        block = backbone.get_layer(f"transformer_layer_{layer_index}")
        hf_prefix = f"model.layers.{layer_index}"

        # RMS norms before attention and before the MLP.
        loader.port_weight(
            keras_variable=block._self_attention_layernorm.scale,
            hf_weight_key=f"{hf_prefix}.input_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=block._feedforward_layernorm.scale,
            hf_weight_key=f"{hf_prefix}.post_attention_layernorm.weight",
        )

        # Attention projections.
        attention = block._self_attention_layer
        attention_weights = (
            (attention._query_dense, "q_proj"),
            (attention._key_dense, "k_proj"),
            (attention._value_dense, "v_proj"),
            (attention._output_dense, "o_proj"),
        )
        for keras_dense, hf_name in attention_weights:
            loader.port_weight(
                keras_variable=keras_dense.kernel,
                hf_weight_key=f"{hf_prefix}.self_attn.{hf_name}.weight",
                hook_fn=transpose_and_reshape,
            )

        # Gated MLP projections.
        mlp_weights = (
            (block._feedforward_gate_dense, "gate_proj"),
            (block._feedforward_intermediate_dense, "up_proj"),
            (block._feedforward_output_dense, "down_proj"),
        )
        for keras_dense, hf_name in mlp_weights:
            loader.port_weight(
                keras_variable=keras_dense.kernel,
                hf_weight_key=f"{hf_prefix}.mlp.{hf_name}.weight",
                hook_fn=swap_axes,
            )

    # Final RMS norm.
    loader.port_weight(
        keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
        hf_weight_key="model.norm.weight",
    )

    return backbone
112
+
113
+
114
def load_llama3_backbone(cls, preset, load_weights):
    """Instantiate a Llama 3 backbone from a Hugging Face preset.

    Args:
        cls: The backbone class, constructed from the converted config.
        preset: Preset handle forwarded to `load_config` and
            `SafetensorLoader`.
        load_weights: When falsy, the backbone is returned without reading
            any safetensor weights.

    Returns:
        The constructed backbone, with weights ported when requested.
    """
    hf_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(hf_config))
    if not load_weights:
        return backbone
    jax_memory_cleanup(backbone)
    with SafetensorLoader(preset) as loader:
        convert_weights(backbone, loader, hf_config)
    return backbone
123
+
124
+
125
def load_llama3_tokenizer(cls, preset):
    """Build a Llama 3 tokenizer from the preset's `tokenizer.json`.

    The begin-of-text and end-of-text special tokens are registered in the
    vocabulary before the tokenizer is constructed.

    Args:
        cls: The tokenizer class to instantiate.
        preset: Preset handle forwarded to `load_config`.

    Returns:
        A `cls` instance built from the BPE vocabulary and merges.

    Raises:
        KeyError / IndexError: if `tokenizer.json` lacks the expected
            `model` or `added_tokens` entries.
    """
    tokenizer_config = load_config(preset, "tokenizer.json")
    vocab = tokenizer_config["model"]["vocab"]
    # Newer `tokenizers` releases serialize each merge as a ["a", "b"]
    # pair instead of the "a b" string the byte-pair tokenizer expects.
    merges = [
        " ".join(merge) if isinstance(merge, (list, tuple)) else merge
        for merge in tokenizer_config["model"]["merges"]
    ]

    added_tokens = tokenizer_config["added_tokens"]
    added_by_content = {token["content"]: token for token in added_tokens}
    # Look the special tokens up by name rather than trusting their
    # position. Fall back to the historical positional convention
    # (0 = begin of text, 1 = end of text) when the names are absent.
    bot = added_by_content.get("<|begin_of_text|>")
    if bot is None:
        bot = added_tokens[0]
    eot = added_by_content.get("<|end_of_text|>")
    if eot is None:
        eot = added_tokens[1]

    vocab[bot["content"]] = bot["id"]
    vocab[eot["content"]] = eot["id"]

    return cls(vocabulary=vocab, merges=merges)
@@ -0,0 +1,303 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
17
+ from keras_hub.src.utils.preset_utils import get_file
18
+ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
19
+ from keras_hub.src.utils.preset_utils import load_config
20
+ from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
21
+
22
+
23
def convert_backbone_config(transformers_config):
    """Translate a Hugging Face PaliGemma config dict into backbone kwargs.

    Args:
        transformers_config: Parsed HF config. Must contain
            `image_token_index`, a `text_config` dict and a `vision_config`
            dict.

    Returns:
        A dict of keyword arguments for the Keras PaliGemma backbone.

    Raises:
        KeyError: if a required HF config field is missing.
    """
    text_config = transformers_config["text_config"]
    vision_config = transformers_config["vision_config"]
    return {
        # The text vocabulary size is taken to be `image_token_index`; the
        # HF embedding table is truncated to this many rows on load.
        "vocabulary_size": transformers_config["image_token_index"],
        "image_size": vision_config.get("image_size", 224),
        "num_layers": text_config["num_hidden_layers"],
        "num_query_heads": text_config["num_attention_heads"],
        "num_key_value_heads": text_config["num_key_value_heads"],
        "hidden_dim": text_config["hidden_size"],
        # NOTE(review): the Keras decoder is given twice the HF
        # `intermediate_size` — presumably its gated FFW uses half per
        # branch; confirm against the decoder implementation.
        "intermediate_dim": text_config["intermediate_size"] * 2,
        # `num_image_tokens` is the number of image patches and scales with
        # resolution (256 at 224px, 1024 at 448px), so it cannot stand in
        # for the attention head size. HF Gemma configs default `head_dim`
        # to 256 and only serialize it when it differs.
        "head_dim": text_config.get("head_dim", 256),
        "vit_patch_size": vision_config["patch_size"],
        "vit_num_heads": vision_config["num_attention_heads"],
        "vit_hidden_dim": vision_config["hidden_size"],
        "vit_num_layers": vision_config["num_hidden_layers"],
        "vit_intermediate_dim": vision_config["intermediate_size"],
    }
45
+
46
+
47
def convert_weights(backbone, loader, transformers_config):
    """Copy PaliGemma weights from a HF safetensors checkpoint into `backbone`.

    Every Keras variable is filled through `loader.port_weight`, optionally
    with a hook that rearranges the HF tensor into the Keras layout.

    Args:
        backbone: The Keras PaliGemma backbone whose variables are filled.
        loader: An open `SafetensorLoader` for the preset.
        transformers_config: The parsed HF config. Not read by this function.

    Returns:
        The same `backbone` instance.
    """

    def transpose_2d(hf_tensor, _):
        # torch `Linear` weights are (out, in); Keras `Dense` kernels are
        # (in, out).
        return np.transpose(hf_tensor, axes=(1, 0))

    def conv_to_keras(hf_tensor, _):
        # torch Conv2D (out, in, kh, kw) -> Keras Conv2D (kh, kw, in, out).
        return np.transpose(hf_tensor, axes=(2, 3, 1, 0))

    def split_heads(hf_tensor, keras_shape):
        # Reshape to (s0, s2, s1), then swap the last two axes to get
        # keras_shape. NOTE(review): keras_shape is presumably
        # (num_heads, hidden_dim, head_dim) for the Q/K/V EinsumDense.
        first, middle, last = keras_shape[0], keras_shape[1], keras_shape[2]
        stacked = np.reshape(hf_tensor, (first, last, middle))
        return np.transpose(stacked, axes=(0, 2, 1))

    def merge_heads(hf_tensor, keras_shape):
        # Reshape to (s2, s0, s1), then rotate the leading axis to the end
        # to get keras_shape. NOTE(review): keras_shape is presumably
        # (num_heads, head_dim, hidden_dim) for the output EinsumDense.
        first, middle, last = keras_shape[0], keras_shape[1], keras_shape[2]
        stacked = np.reshape(hf_tensor, (last, first, middle))
        return np.transpose(stacked, axes=(1, 2, 0))

    def truncate_rows(hf_tensor, keras_shape):
        # Keep only as many embedding rows as the Keras table has.
        return hf_tensor[: keras_shape[0]]

    def port(keras_variable, hf_weight_key, hook_fn=None):
        # Forward `hook_fn` only when one is given, mirroring call sites
        # that omit it.
        if hook_fn is None:
            loader.port_weight(
                keras_variable=keras_variable,
                hf_weight_key=hf_weight_key,
            )
        else:
            loader.port_weight(
                keras_variable=keras_variable,
                hf_weight_key=hf_weight_key,
                hook_fn=hook_fn,
            )

    # ------------------------------------------------------------------
    # Image tower
    # ------------------------------------------------------------------
    image_encoder = backbone.vit_encoder.get_layer("image_encoder")
    embeddings = image_encoder.vision_embeddings
    vit = "vision_tower.vision_model"

    port(
        embeddings.patch_embedding.bias,
        f"{vit}.embeddings.patch_embedding.bias",
    )
    port(
        embeddings.patch_embedding.kernel,
        f"{vit}.embeddings.patch_embedding.weight",
        conv_to_keras,
    )
    port(
        embeddings.position_embedding.embeddings,
        f"{vit}.embeddings.position_embedding.weight",
    )
    port(image_encoder.encoder_layer_norm.gamma, f"{vit}.post_layernorm.weight")
    port(image_encoder.encoder_layer_norm.beta, f"{vit}.post_layernorm.bias")

    # Keras attention sub-layer name -> HF projection name, in port order.
    vit_attention_pairs = (
        ("key_proj", "k_proj"),
        ("out_proj", "out_proj"),
        ("query_proj", "q_proj"),
        ("value_proj", "v_proj"),
    )
    for layer_idx in range(image_encoder.num_layers):
        block = image_encoder.resblocks[layer_idx]
        hf_layer = f"{vit}.encoder.layers.{layer_idx}"

        port(block.layer_norm_1.beta, f"{hf_layer}.layer_norm1.bias")
        port(block.layer_norm_1.gamma, f"{hf_layer}.layer_norm1.weight")
        port(block.layer_norm_2.beta, f"{hf_layer}.layer_norm2.bias")
        port(block.layer_norm_2.gamma, f"{hf_layer}.layer_norm2.weight")

        port(block.mlp_dense_1.kernel, f"{hf_layer}.mlp.fc1.weight", transpose_2d)
        port(block.mlp_dense_1.bias, f"{hf_layer}.mlp.fc1.bias")
        port(block.mlp_dense_2.kernel, f"{hf_layer}.mlp.fc2.weight", transpose_2d)
        port(block.mlp_dense_2.bias, f"{hf_layer}.mlp.fc2.bias")

        for keras_name, hf_name in vit_attention_pairs:
            projection = getattr(block.attn, keras_name)
            hf_proj = f"{hf_layer}.self_attn.{hf_name}"
            port(projection.bias, f"{hf_proj}.bias")
            port(projection.kernel, f"{hf_proj}.weight", transpose_2d)

    # Multi-modal projection from image features into the text stream.
    projector = backbone.vit_encoder.get_layer("image_classifier")
    port(projector.kernel, "multi_modal_projector.linear.weight", transpose_2d)
    port(projector.bias, "multi_modal_projector.linear.bias")

    # ------------------------------------------------------------------
    # Language tower
    # ------------------------------------------------------------------
    for layer_idx in range(backbone.num_layers):
        decoder_layer = backbone.transformer_layers[layer_idx]
        attention = decoder_layer.attention
        hf_layer = f"language_model.model.layers.{layer_idx}"

        port(
            decoder_layer.pre_attention_norm.scale,
            f"{hf_layer}.input_layernorm.weight",
        )
        port(
            decoder_layer.pre_ffw_norm.scale,
            f"{hf_layer}.post_attention_layernorm.weight",
        )

        port(
            attention.query_dense.kernel,
            f"{hf_layer}.self_attn.q_proj.weight",
            split_heads,
        )
        port(
            attention.key_dense.kernel,
            f"{hf_layer}.self_attn.k_proj.weight",
            split_heads,
        )
        port(
            attention.value_dense.kernel,
            f"{hf_layer}.self_attn.v_proj.weight",
            split_heads,
        )
        port(
            attention.output_dense.kernel,
            f"{hf_layer}.self_attn.o_proj.weight",
            merge_heads,
        )

        port(
            decoder_layer.gating_ffw.kernel,
            f"{hf_layer}.mlp.gate_proj.weight",
            transpose_2d,
        )
        port(
            decoder_layer.gating_ffw_2.kernel,
            f"{hf_layer}.mlp.up_proj.weight",
            transpose_2d,
        )
        port(
            decoder_layer.ffw_linear.kernel,
            f"{hf_layer}.mlp.down_proj.weight",
            transpose_2d,
        )

    port(backbone.layer_norm.scale, "language_model.model.norm.weight")
    port(
        backbone.token_embedding.embeddings,
        "language_model.model.embed_tokens.weight",
        truncate_rows,
    )

    return backbone
279
+
280
+
281
def load_pali_gemma_backbone(cls, preset, load_weights):
    """Instantiate a PaliGemma backbone from a Hugging Face-format preset.

    Args:
        cls: Backbone class; constructed from the kwargs produced by
            `convert_backbone_config` on the preset's HF config file.
        preset: Preset handle forwarded to `load_config` and
            `SafetensorLoader`.
        load_weights: When truthy, checkpoint weights are also copied into
            the new backbone; otherwise it is returned as constructed.

    Returns:
        The backbone instance.
    """
    hf_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(hf_config))
    if not load_weights:
        return backbone

    # `jax_memory_cleanup` runs before any weight is read; what it frees is
    # defined in preset_utils (not visible here).
    jax_memory_cleanup(backbone)
    with SafetensorLoader(preset) as loader:
        convert_weights(backbone, loader, hf_config)
    return backbone
290
+
291
+
292
def load_pali_gemma_tokenizer(cls, preset):
    """Instantiate the PaliGemma tokenizer for `preset`.

    Args:
        cls: Tokenizer class; called with whatever `get_file` returns for
            the preset's `tokenizer.model` file as its only argument.
        preset: Preset handle forwarded to `get_file`.

    Returns:
        The constructed tokenizer instance.
    """
    model_file = get_file(preset, "tokenizer.model")
    return cls(model_file)