keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,173 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
17
+ from keras_hub.src.utils.preset_utils import HF_TOKENIZER_CONFIG_FILE
18
+ from keras_hub.src.utils.preset_utils import get_file
19
+ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
20
+ from keras_hub.src.utils.preset_utils import load_config
21
+ from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
22
+
23
+
24
def convert_backbone_config(transformers_config):
    """Translate a Hugging Face BERT ``config.json`` into backbone kwargs.

    Args:
        transformers_config: dict parsed from the Hugging Face config file.

    Returns:
        A dict of keyword arguments for the KerasHub backbone constructor.

    Raises:
        KeyError: if one of the required size keys (``vocab_size``,
            ``num_hidden_layers``, ``num_attention_heads``, ``hidden_size``,
            ``intermediate_size``) is missing.

    The position-table size, segment count and dropout rate are taken from
    the checkpoint too. This ensures checkpoints with a non-default
    ``max_position_embeddings`` or ``type_vocab_size`` produce embedding
    tables whose shapes match the weights that are ported afterwards. When
    those keys are absent, the Hugging Face ``BertConfig`` defaults
    (512, 2 and 0.1) are used.
    """
    return {
        "vocabulary_size": transformers_config["vocab_size"],
        "num_layers": transformers_config["num_hidden_layers"],
        "num_heads": transformers_config["num_attention_heads"],
        "hidden_dim": transformers_config["hidden_size"],
        "intermediate_dim": transformers_config["intermediate_size"],
        "dropout": transformers_config.get("hidden_dropout_prob", 0.1),
        "max_sequence_length": transformers_config.get(
            "max_position_embeddings", 512
        ),
        "num_segments": transformers_config.get("type_vocab_size", 2),
    }
32
+
33
+
34
def convert_weights(backbone, loader, transformers_config):
    """Copy Hugging Face BERT safetensors weights into a KerasHub backbone.

    Each transfer is a ``loader.port_weight`` call. Some calls pass a hook
    that turns the Hugging Face array into the layout of the Keras
    variable; the hooks use the ``(hf_tensor, keras_shape)`` calling
    convention.

    Args:
        backbone: The KerasHub backbone whose variables are overwritten.
            Layers are looked up by name (``token_embedding``,
            ``position_embedding``, ``segment_embedding``,
            ``embeddings_layer_norm``, ``transformer_layer_{i}``,
            ``pooled_dense``).
        loader: An open ``SafetensorLoader`` for the preset.
        transformers_config: The Hugging Face config dict. It is not read
            here; it is accepted because ``load_bert_backbone`` passes it.

    NOTE(review): LayerNorm parameters are requested under the legacy
    ``LayerNorm.gamma`` / ``LayerNorm.beta`` key names. Checkpoints that
    store ``LayerNorm.weight`` / ``LayerNorm.bias`` would presumably not
    resolve. Confirm against the checkpoints this converter targets.
    """

    def port(keras_variable, hf_weight_key, hook_fn=None):
        # Only forward ``hook_fn`` when one was supplied, so hook-less
        # transfers call ``port_weight`` with just the two keywords.
        extra = {} if hook_fn is None else {"hook_fn": hook_fn}
        loader.port_weight(
            keras_variable=keras_variable,
            hf_weight_key=hf_weight_key,
            **extra,
        )

    def split_heads_kernel(hf_tensor, keras_shape):
        # Transpose the HF matrix, then reshape to the Keras variable shape.
        return np.reshape(np.transpose(hf_tensor), keras_shape)

    def split_heads_bias(hf_tensor, keras_shape):
        # Biases only need reshaping to the Keras variable shape.
        return np.reshape(hf_tensor, keras_shape)

    def dense_kernel(hf_tensor, _):
        # Plain dense kernels are just transposed; the target shape is unused.
        return np.transpose(hf_tensor, axes=(1, 0))

    # Embedding layers.
    port(
        backbone.get_layer("token_embedding").embeddings,
        "bert.embeddings.word_embeddings.weight",
    )
    port(
        backbone.get_layer("position_embedding").position_embeddings,
        "bert.embeddings.position_embeddings.weight",
    )
    port(
        backbone.get_layer("segment_embedding").embeddings,
        "bert.embeddings.token_type_embeddings.weight",
    )
    embeddings_norm = backbone.get_layer("embeddings_layer_norm")
    port(embeddings_norm.beta, "bert.embeddings.LayerNorm.beta")
    port(embeddings_norm.gamma, "bert.embeddings.LayerNorm.gamma")

    # Transformer layers.
    for layer_index in range(backbone.num_layers):
        layer = backbone.get_layer(f"transformer_layer_{layer_index}")
        attention = layer._self_attention_layer
        prefix = f"bert.encoder.layer.{layer_index}"

        # Query, key and value projections share one key pattern.
        for proj_name in ("query", "key", "value"):
            projection = getattr(attention, f"{proj_name}_dense")
            port(
                projection.kernel,
                f"{prefix}.attention.self.{proj_name}.weight",
                split_heads_kernel,
            )
            port(
                projection.bias,
                f"{prefix}.attention.self.{proj_name}.bias",
                split_heads_bias,
            )
        port(
            attention.output_dense.kernel,
            f"{prefix}.attention.output.dense.weight",
            split_heads_kernel,
        )
        port(
            attention.output_dense.bias,
            f"{prefix}.attention.output.dense.bias",
            split_heads_bias,
        )

        attention_norm = layer._self_attention_layer_norm
        port(
            attention_norm.beta,
            f"{prefix}.attention.output.LayerNorm.beta",
        )
        port(
            attention_norm.gamma,
            f"{prefix}.attention.output.LayerNorm.gamma",
        )

        # Feed-forward block.
        port(
            layer._feedforward_intermediate_dense.kernel,
            f"{prefix}.intermediate.dense.weight",
            dense_kernel,
        )
        port(
            layer._feedforward_intermediate_dense.bias,
            f"{prefix}.intermediate.dense.bias",
        )
        port(
            layer._feedforward_output_dense.kernel,
            f"{prefix}.output.dense.weight",
            dense_kernel,
        )
        port(
            layer._feedforward_output_dense.bias,
            f"{prefix}.output.dense.bias",
        )
        output_norm = layer._feedforward_layer_norm
        port(output_norm.beta, f"{prefix}.output.LayerNorm.beta")
        port(output_norm.gamma, f"{prefix}.output.LayerNorm.gamma")

    # Pooler.
    pooler = backbone.get_layer("pooled_dense")
    port(pooler.kernel, "bert.pooler.dense.weight", dense_kernel)
    port(pooler.bias, "bert.pooler.dense.bias")
155
+
156
+
157
def load_bert_backbone(cls, preset, load_weights):
    """Instantiate ``cls`` from a Hugging Face BERT preset.

    Args:
        cls: Backbone class; it is called with the kwargs produced by
            ``convert_backbone_config``.
        preset: Preset handle forwarded to ``load_config`` and
            ``SafetensorLoader``.
        load_weights: When truthy, the preset's safetensors weights are
            ported into the new backbone. Otherwise it is returned as
            constructed.

    Returns:
        The constructed backbone.
    """
    hf_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(hf_config))
    if not load_weights:
        return backbone

    jax_memory_cleanup(backbone)
    with SafetensorLoader(preset) as loader:
        convert_weights(backbone, loader, hf_config)
    return backbone
166
+
167
+
168
def load_bert_tokenizer(cls, preset):
    """Build a tokenizer from a Hugging Face BERT preset.

    Args:
        cls: Tokenizer class. It is called with the path of the preset's
            ``vocab.txt`` and a ``lowercase`` keyword.
        preset: Preset handle forwarded to ``load_config`` and ``get_file``.

    Returns:
        The constructed tokenizer.

    ``do_lower_case`` is optional in ``tokenizer_config.json``. When it is
    absent, ``True`` is used (the Hugging Face ``BertTokenizer`` default)
    instead of raising ``KeyError``.
    """
    transformers_config = load_config(preset, HF_TOKENIZER_CONFIG_FILE)
    return cls(
        get_file(preset, "vocab.txt"),
        lowercase=transformers_config.get("do_lower_case", True),
    )
@@ -0,0 +1,184 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
17
+ from keras_hub.src.utils.preset_utils import HF_TOKENIZER_CONFIG_FILE
18
+ from keras_hub.src.utils.preset_utils import get_file
19
+ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
20
+ from keras_hub.src.utils.preset_utils import load_config
21
+ from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
22
+
23
+
24
def convert_backbone_config(transformers_config):
    """Map a Hugging Face DistilBERT ``config.json`` onto backbone kwargs.

    Args:
        transformers_config: dict parsed from the Hugging Face config file.

    Returns:
        dict of keyword arguments for the KerasHub backbone constructor.

    Raises:
        KeyError: if any of the required Hugging Face keys is missing.
    """
    # (KerasHub argument, Hugging Face key) pairs. Note the name swap:
    # HF ``dim`` feeds ``hidden_dim``, and HF ``hidden_dim`` feeds
    # ``intermediate_dim``.
    key_map = (
        ("vocabulary_size", "vocab_size"),
        ("num_layers", "n_layers"),
        ("num_heads", "n_heads"),
        ("hidden_dim", "dim"),
        ("intermediate_dim", "hidden_dim"),
        ("dropout", "dropout"),
        ("max_sequence_length", "max_position_embeddings"),
    )
    return {
        keras_key: transformers_config[hf_key] for keras_key, hf_key in key_map
    }
34
+
35
+
36
def convert_weights(backbone, loader):
    """Copy Hugging Face DistilBERT safetensors weights into a backbone.

    Each transfer is a ``loader.port_weight`` call. Some calls pass a hook
    that turns the Hugging Face array into the layout of the Keras
    variable; the hooks use the ``(hf_tensor, keras_shape)`` calling
    convention.

    Args:
        backbone: The KerasHub backbone whose variables are overwritten.
            It reads ``get_layer("token_and_position_embedding")``,
            ``transformer_layers[i]``, ``num_layers`` and
            ``embeddings_layer_norm``.
        loader: An open ``SafetensorLoader``; all transfers go through its
            ``port_weight`` method.

    Returns:
        ``backbone`` itself.
    """

    def port(keras_variable, hf_weight_key, hook_fn=None):
        # Only forward ``hook_fn`` when one was supplied, so hook-less
        # transfers call ``port_weight`` with just the two keywords.
        extra = {} if hook_fn is None else {"hook_fn": hook_fn}
        loader.port_weight(
            keras_variable=keras_variable,
            hf_weight_key=hf_weight_key,
            **extra,
        )

    def split_heads_kernel(hf_tensor, keras_shape):
        # Transpose the HF matrix, then reshape to the Keras variable shape.
        return np.reshape(np.transpose(hf_tensor), keras_shape)

    def split_heads_bias(hf_tensor, keras_shape):
        # Biases only need reshaping to the Keras variable shape.
        return np.reshape(hf_tensor, keras_shape)

    def dense_kernel(hf_tensor, _):
        # Plain dense kernels are just transposed; the target shape is unused.
        return np.transpose(hf_tensor, axes=(1, 0))

    # Token and position embeddings.
    embedding = backbone.get_layer("token_and_position_embedding")
    port(
        embedding.token_embedding.embeddings,
        "distilbert.embeddings.word_embeddings.weight",
    )
    port(
        embedding.position_embedding.position_embeddings,
        "distilbert.embeddings.position_embeddings.weight",
    )

    # Transformer layers.
    for layer_index in range(backbone.num_layers):
        layer = backbone.transformer_layers[layer_index]
        prefix = f"distilbert.transformer.layer.{layer_index}"

        # Layer norms.
        port(
            layer._self_attention_layer_norm.gamma,
            f"{prefix}.sa_layer_norm.weight",
        )
        port(
            layer._self_attention_layer_norm.beta,
            f"{prefix}.sa_layer_norm.bias",
        )
        port(
            layer._feedforward_layer_norm.gamma,
            f"{prefix}.output_layer_norm.weight",
        )
        port(
            layer._feedforward_layer_norm.beta,
            f"{prefix}.output_layer_norm.bias",
        )

        # Query, key and value projections: (Keras name, HF name) pairs.
        attention = layer._self_attention_layer
        for keras_name, hf_name in (
            ("query", "q_lin"),
            ("key", "k_lin"),
            ("value", "v_lin"),
        ):
            projection = getattr(attention, f"{keras_name}_dense")
            port(
                projection.kernel,
                f"{prefix}.attention.{hf_name}.weight",
                split_heads_kernel,
            )
            port(
                projection.bias,
                f"{prefix}.attention.{hf_name}.bias",
                split_heads_bias,
            )

        # Attention output projection; its bias is ported without a hook.
        port(
            attention.output_dense.kernel,
            f"{prefix}.attention.out_lin.weight",
            split_heads_kernel,
        )
        port(
            attention.output_dense.bias,
            f"{prefix}.attention.out_lin.bias",
        )

        # Feed-forward block.
        port(
            layer._feedforward_intermediate_dense.kernel,
            f"{prefix}.ffn.lin1.weight",
            dense_kernel,
        )
        port(
            layer._feedforward_intermediate_dense.bias,
            f"{prefix}.ffn.lin1.bias",
        )
        port(
            layer._feedforward_output_dense.kernel,
            f"{prefix}.ffn.lin2.weight",
            dense_kernel,
        )
        port(
            layer._feedforward_output_dense.bias,
            f"{prefix}.ffn.lin2.bias",
        )

    # Embedding layer norm (ported last).
    port(
        backbone.embeddings_layer_norm.gamma,
        "distilbert.embeddings.LayerNorm.weight",
    )
    port(
        backbone.embeddings_layer_norm.beta,
        "distilbert.embeddings.LayerNorm.bias",
    )

    return backbone
166
+
167
+
168
def load_distilbert_backbone(cls, preset, load_weights):
    """Instantiate ``cls`` from a Hugging Face DistilBERT preset.

    Args:
        cls: Backbone class; it is called with the kwargs produced by
            ``convert_backbone_config``.
        preset: Preset handle forwarded to ``load_config`` and
            ``SafetensorLoader``.
        load_weights: When truthy, the preset's safetensors weights are
            ported into the new backbone. Otherwise it is returned as
            constructed.

    Returns:
        The constructed backbone.
    """
    hf_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(hf_config))
    if not load_weights:
        return backbone

    jax_memory_cleanup(backbone)
    with SafetensorLoader(preset) as loader:
        convert_weights(backbone, loader)
    return backbone
177
+
178
+
179
def load_distilbert_tokenizer(cls, preset):
    """Build a tokenizer from a Hugging Face DistilBERT preset.

    Args:
        cls: Tokenizer class. It is called with the path of the preset's
            ``vocab.txt`` and a ``lowercase`` keyword.
        preset: Preset handle forwarded to ``load_config`` and ``get_file``.

    Returns:
        The constructed tokenizer.

    ``do_lower_case`` is optional in ``tokenizer_config.json``. When it is
    absent, ``True`` is used (the Hugging Face BERT-family tokenizer
    default) instead of raising ``KeyError``.
    """
    transformers_config = load_config(preset, HF_TOKENIZER_CONFIG_FILE)
    return cls(
        get_file(preset, "vocab.txt"),
        lowercase=transformers_config.get("do_lower_case", True),
    )
@@ -0,0 +1,187 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
17
+ from keras_hub.src.utils.preset_utils import get_file
18
+ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
19
+ from keras_hub.src.utils.preset_utils import load_config
20
+ from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
21
+
22
+
23
def convert_backbone_config(transformers_config):
    """Translate a Hugging Face Gemma / Gemma 2 config into backbone kwargs.

    Args:
        transformers_config: Dict parsed from the Hugging Face `config.json`.
            Its `"model_type"` selects the mapping; `"gemma"` and `"gemma2"`
            are supported.

    Returns:
        A dict of keyword arguments for the KerasHub Gemma backbone.

    Raises:
        ValueError: If `"model_type"` is neither `"gemma"` nor `"gemma2"`.
            Previously an empty dict was returned, which only failed later
            with an opaque missing-argument error from the backbone.
    """
    model_type = transformers_config["model_type"]
    if model_type not in ("gemma", "gemma2"):
        raise ValueError(
            "Unsupported Hugging Face `model_type` for Gemma conversion: "
            f"{model_type!r}. Expected 'gemma' or 'gemma2'."
        )

    # Fields shared by both Gemma generations.
    backbone_config = {
        "vocabulary_size": transformers_config["vocab_size"],
        "num_layers": transformers_config["num_hidden_layers"],
        "num_query_heads": transformers_config["num_attention_heads"],
        "num_key_value_heads": transformers_config["num_key_value_heads"],
        "hidden_dim": transformers_config["hidden_size"],
        # HF's `intermediate_size` is doubled; assumed to match how the Keras
        # backbone splits `intermediate_dim` across its gating projections
        # (TODO confirm against the backbone implementation).
        "intermediate_dim": transformers_config["intermediate_size"] * 2,
        "head_dim": transformers_config["head_dim"],
    }

    if model_type == "gemma2":
        # Gemma 2 adds post-norms, logit soft-capping and sliding-window
        # attention on top of the shared fields.
        backbone_config.update(
            {
                # Enabled only when `query_pre_attn_scalar` equals `head_dim`.
                "query_head_dim_normalize": (
                    transformers_config["head_dim"]
                    == transformers_config["query_pre_attn_scalar"]
                ),
                "use_post_ffw_norm": True,
                "use_post_attention_norm": True,
                "final_logit_soft_cap": transformers_config[
                    "final_logit_softcapping"
                ],
                "attention_logit_soft_cap": transformers_config[
                    "attn_logit_softcapping"
                ],
                "sliding_window_size": transformers_config["sliding_window"],
                "use_sliding_window_attention": True,
            }
        )
    return backbone_config
62
+
63
+
64
def _split_heads_kernel(hf_tensor, keras_shape):
    """Hook mapping a HF q/k/v projection matrix onto a 3D Keras kernel.

    Reshapes `hf_tensor` to `(keras_shape[0], keras_shape[2], keras_shape[1])`
    and swaps its last two axes, so the result has shape `keras_shape`.
    """
    return np.transpose(
        np.reshape(
            hf_tensor, (keras_shape[0], keras_shape[2], keras_shape[1])
        ),
        axes=(0, 2, 1),
    )


def _merge_heads_kernel(hf_tensor, keras_shape):
    """Hook mapping a HF attention output matrix onto a 3D Keras kernel.

    Reshapes `hf_tensor` to `(keras_shape[2], keras_shape[0], keras_shape[1])`
    and moves the first axis last, so the result has shape `keras_shape`.
    """
    return np.transpose(
        np.reshape(
            hf_tensor, (keras_shape[2], keras_shape[0], keras_shape[1])
        ),
        axes=(1, 2, 0),
    )


def _transpose_dense_kernel(hf_tensor, _):
    """Hook transposing a 2D HF linear weight into Keras dense layout."""
    return np.transpose(hf_tensor, axes=(1, 0))


def convert_weights(backbone, loader, transformers_config):
    """Copy Hugging Face Gemma / Gemma 2 weights into a KerasHub backbone.

    Args:
        backbone: The KerasHub Gemma backbone to populate. Its
            `token_embedding`, `decoder_block_{i}` and `final_normalization`
            layers are looked up by name.
        loader: Object exposing `port_weight(keras_variable=...,
            hf_weight_key=..., hook_fn=...)`; in this module a
            `SafetensorLoader` opened on the preset.
        transformers_config: The parsed HF `config.json`. Its `"model_type"`
            decides which HF key feeds each block's `pre_ffw_norm`.

    Returns:
        `backbone`.

    Raises:
        ValueError: If `"model_type"` is neither `"gemma"` nor `"gemma2"`.
            Previously such configs silently left every `pre_ffw_norm`
            unloaded.
    """
    model_type = transformers_config["model_type"]
    # Gemma 1 checkpoints store the weight for the Keras `pre_ffw_norm` under
    # `post_attention_layernorm`; Gemma 2 has a dedicated key.
    pre_ffw_norm_names = {
        "gemma": "post_attention_layernorm",
        "gemma2": "pre_feedforward_layernorm",
    }
    if model_type not in pre_ffw_norm_names:
        # Fail before porting anything rather than leave the model half-loaded.
        raise ValueError(
            "Unsupported Hugging Face `model_type` for Gemma weight "
            f"conversion: {model_type!r}. Expected 'gemma' or 'gemma2'."
        )
    pre_ffw_norm_name = pre_ffw_norm_names[model_type]

    # Embedding layer
    loader.port_weight(
        keras_variable=backbone.get_layer("token_embedding").embeddings,
        hf_weight_key="model.embed_tokens.weight",
    )

    # Decoder blocks
    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")
        prefix = f"model.layers.{i}"

        # Norm layers
        loader.port_weight(
            keras_variable=decoder_layer.pre_attention_norm.scale,
            hf_weight_key=f"{prefix}.input_layernorm.weight",
        )
        if decoder_layer.use_post_attention_norm:
            loader.port_weight(
                keras_variable=decoder_layer.post_attention_norm.scale,
                hf_weight_key=f"{prefix}.post_attention_layernorm.weight",
            )
        loader.port_weight(
            keras_variable=decoder_layer.pre_ffw_norm.scale,
            hf_weight_key=f"{prefix}.{pre_ffw_norm_name}.weight",
        )
        if decoder_layer.use_post_ffw_norm:
            loader.port_weight(
                keras_variable=decoder_layer.post_ffw_norm.scale,
                hf_weight_key=f"{prefix}.post_feedforward_layernorm.weight",
            )

        # Attention projections (query, key, value, then output)
        attention = decoder_layer.attention
        for keras_dense, hf_name in (
            (attention.query_dense, "q_proj"),
            (attention.key_dense, "k_proj"),
            (attention.value_dense, "v_proj"),
        ):
            loader.port_weight(
                keras_variable=keras_dense.kernel,
                hf_weight_key=f"{prefix}.self_attn.{hf_name}.weight",
                hook_fn=_split_heads_kernel,
            )
        loader.port_weight(
            keras_variable=attention.output_dense.kernel,
            hf_weight_key=f"{prefix}.self_attn.o_proj.weight",
            hook_fn=_merge_heads_kernel,
        )

        # MLP layers (gate, up, down)
        for keras_dense, hf_name in (
            (decoder_layer.gating_ffw, "gate_proj"),
            (decoder_layer.gating_ffw_2, "up_proj"),
            (decoder_layer.ffw_linear, "down_proj"),
        ):
            loader.port_weight(
                keras_variable=keras_dense.kernel,
                hf_weight_key=f"{prefix}.mlp.{hf_name}.weight",
                hook_fn=_transpose_dense_kernel,
            )

    # Final normalization layer
    loader.port_weight(
        keras_variable=backbone.get_layer("final_normalization").scale,
        hf_weight_key="model.norm.weight",
    )

    return backbone
173
+
174
+
175
def load_gemma_backbone(cls, preset, load_weights):
    """Build a KerasHub Gemma backbone from a Hugging Face preset.

    Args:
        cls: Backbone class, instantiated with the converted config kwargs.
        preset: Preset handle forwarded to `load_config` and
            `SafetensorLoader`.
        load_weights: When truthy, the checkpoint weights are copied into
            the new backbone; otherwise no checkpoint weights are loaded.

    Returns:
        The constructed `cls` instance.
    """
    hf_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(hf_config))
    if not load_weights:
        return backbone
    jax_memory_cleanup(backbone)
    with SafetensorLoader(preset) as loader:
        # The raw HF config is passed through so the weight mapping can
        # branch on `model_type`.
        convert_weights(backbone, loader, hf_config)
    return backbone
184
+
185
+
186
def load_gemma_tokenizer(cls, preset):
    """Build a Gemma tokenizer from a Hugging Face preset.

    Args:
        cls: Tokenizer class; receives the path of the preset's
            `tokenizer.model` file positionally.
        preset: Preset handle forwarded to `get_file`.

    Returns:
        The constructed `cls` instance.
    """
    proto_path = get_file(preset, "tokenizer.model")
    return cls(proto_path)