keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,84 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.api_export import keras_hub_export
16
+ from keras_hub.src.models.llama.llama_backbone import LlamaBackbone
17
+
18
+
19
+ # LLaMA 3 shares the same architecture as its predecessors
20
+ # So, we simply create an alias for API consistency
21
@keras_hub_export("keras_hub.models.Llama3Backbone")
class Llama3Backbone(LlamaBackbone):
    """The Llama 3 Transformer core architecture with hyperparameters.

    Llama 3 shares the decoder architecture of its predecessors, so this class
    is a pure alias of `keras_hub.models.LlamaBackbone` that exists for API
    consistency; it adds no layers or arguments of its own.

    This network implements a Transformer-based decoder network,
    Llama 3, as published in the
    [Meta Llama 3 repository](https://github.com/meta-llama/llama3).
    It includes the embedding lookups and transformer layers.

    The default constructor gives a fully customizable, randomly initialized
    Llama 3 model with any number of layers, heads, and embedding
    dimensions. To load preset architectures and weights, use the `from_preset`
    constructor.

    Args:
        vocabulary_size (int): The size of the token vocabulary.
        num_layers (int): The number of transformer layers.
        num_query_heads (int): The number of query attention heads for
            each transformer.
        hidden_dim (int): The size of the transformer encoding and pooling
            layers.
        intermediate_dim (int): The output dimension of the first Dense layer
            in a three-layer feedforward network for each transformer.
        num_key_value_heads (int): The number of key and value attention heads
            for each transformer.
        rope_max_wavelength (int, optional): The maximum angular wavelength of
            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
        rope_scaling_factor (float, optional): The scaling factor for
            calculation of rotary embedding. Defaults to `1.0`.
        layer_norm_epsilon (float, optional): Epsilon for the layer
            normalization layers in the transformer decoder. Defaults to
            `1e-6`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.

    Examples:

    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Pretrained Llama 3 decoder.
    model = keras_hub.models.Llama3Backbone.from_preset("llama3_8b_en")
    model(input_data)

    # Randomly initialized Llama 3 decoder with custom config.
    model = keras_hub.models.Llama3Backbone(
        vocabulary_size=10,
        hidden_dim=512,
        num_layers=2,
        num_query_heads=32,
        num_key_value_heads=8,
        intermediate_dim=1024,
        layer_norm_epsilon=1e-6,
        dtype="float32"
    )
    model(input_data)
    ```
    """
@@ -0,0 +1,46 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from keras_hub.src.api_export import keras_hub_export
15
+ from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
16
+ from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
17
+ Llama3CausalLMPreprocessor,
18
+ )
19
+ from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM
20
+
21
+
22
@keras_hub_export("keras_hub.models.Llama3CausalLM")
class Llama3CausalLM(LlamaCausalLM):
    """Llama 3 model with a causal language modeling head, end to end.

    Causal language models predict every token from the tokens that come
    before it. Used that way, this task can train the model unsupervised on
    plain text, or autoregressively generate text that resembles its training
    data. Calling `fit()` is all that is needed to pre-train or fine-tune a
    LLaMA 3 model with it.

    Text generation from a prompt is available through `generate()`. The
    decoding strategy is selected by the `sampler` argument of `compile()`;
    recompile with a different `keras_hub.samplers` object to change it.
    `"top_k"` sampling is used when no sampler is given.

    Args:
        backbone: A `keras_hub.models.Llama3Backbone` instance.
        preprocessor: A `keras_hub.models.Llama3CausalLMPreprocessor` or
            `None`. With `None`, no preprocessing is applied and inputs must
            already be preprocessed when passed to the model.
    """

    # All task logic is inherited from `LlamaCausalLM`; this subclass only
    # swaps in the Llama 3 companion classes.
    preprocessor_cls = Llama3CausalLMPreprocessor
    backbone_cls = Llama3Backbone
@@ -0,0 +1,173 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from absl import logging
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.models.llama3.llama3_preprocessor import Llama3Preprocessor
20
+ from keras_hub.src.utils.keras_utils import (
21
+ convert_inputs_to_list_of_tensor_segments,
22
+ )
23
+ from keras_hub.src.utils.tensor_utils import strip_to_ragged
24
+
25
+
26
@keras_hub_export("keras_hub.models.Llama3CausalLMPreprocessor")
class Llama3CausalLMPreprocessor(Llama3Preprocessor):
    """Llama 3 Causal LM preprocessor.

    This preprocessing layer is meant for use with
    `keras_hub.models.Llama3CausalLM`. By default, it will take in batches of
    strings, and return outputs in a `(x, y, sample_weight)` format, where the
    `y` label is the next token id in the `x` sequence.

    For use with generation, the layer also exposes two methods
    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
    is attached to a `keras_hub.models.Llama3CausalLM` instance, these methods
    will be called implicitly in `generate()`. They can also be called
    standalone (e.g. to precompute preprocessing inputs for generation in a
    separate process).

    Args:
        tokenizer: A `keras_hub.models.Llama3Tokenizer` instance.
        sequence_length: The length of the packed inputs.
        add_start_token: If `True`, the preprocessor will prepend the tokenizer
            start token to each input sequence. This class defines no
            constructor of its own, so the default value is the one set by
            the inherited `Llama3Preprocessor` constructor.
        add_end_token: If `True`, the preprocessor will append the tokenizer
            end token to each input sequence. The default value is likewise
            inherited from the `Llama3Preprocessor` constructor.

    Call arguments:
        x: A string, `tf.Tensor` or list of python strings.
        y: Label data. Should always be `None` as the layer generates labels.
            If a value is passed, it is ignored and a warning is logged.
        sample_weight: Label weights. Should always be `None` as the layer
            generates label weights. If a value is passed, it is ignored and
            a warning is logged.
        sequence_length: Pass to override the configured `sequence_length` of
            the layer.

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
        "llama3_8b_en"
    )

    # Tokenize and pack a single sentence.
    sentence = tf.constant("League of legends")
    preprocessor(sentence)
    # Same output.
    preprocessor("League of legends")

    # Tokenize a batch of sentences.
    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
    preprocessor(sentences)
    # Same output.
    preprocessor(["Taco tuesday", "Fish taco please!"])

    # Map a labeled dataset. The labels are ignored (and a warning is
    # logged), since this layer derives `y` from the input tokens.
    features = tf.constant(
        [
            "Avatar 2 is amazing!",
            "Well, I am not sure.",
        ]
    )
    labels = tf.constant([1, 0])
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

    # Map a dataset to preprocess unlabeled sentences.
    ds = tf.data.Dataset.from_tensor_slices(features)
    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
    ```
    """

    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        """Tokenize and pack `x`, deriving next-token labels and weights.

        Returns `keras.utils.pack_x_y_sample_weight(x, y, sample_weight)`.
        Here `x` is a dict with `"token_ids"` and `"padding_mask"` (with the
        last position dropped), `y` is `token_ids` shifted left by one, and
        `sample_weight` is the padding mask shifted the same way. Any `y` or
        `sample_weight` passed in is discarded after logging a warning.
        """
        if y is not None or sample_weight is not None:
            logging.warning(
                "`Llama3CausalLMPreprocessor` generates `y` and "
                "`sample_weight` based on your input data, but your data "
                "already contains `y` or `sample_weight`. Your `y` and "
                "`sample_weight` will be ignored."
            )
        sequence_length = sequence_length or self.sequence_length

        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        # Pad with one extra token to account for the truncation below.
        token_ids, padding_mask = self.packer(
            x,
            sequence_length=sequence_length + 1,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        # The last token does not have a next token, so we truncate it out.
        x = {
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
        }
        # Target `y` will be the next token.
        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

    def generate_preprocess(
        self,
        x,
        sequence_length=None,
    ):
        """Convert strings to integer token input for generation.

        Similar to calling the layer for training, this method takes in strings
        or tensor strings, tokenizes and packs the input, and computes a padding
        mask masking all inputs not filled in with a padded value.

        Unlike calling the layer for training, this method does not compute
        labels and will never append a `tokenizer.end_token_id` to the end of
        the sequence (as generation is expected to continue at the end of the
        inputted prompt).

        `sequence_length` is forwarded to `self.packer` as given. A `None`
        value is not replaced by `self.sequence_length` here.
        """
        # Generation may be invoked before the layer was ever called.
        if not self.built:
            self.build(None)

        x = convert_inputs_to_list_of_tensor_segments(x)[0]
        x = self.tokenizer(x)
        token_ids, padding_mask = self.packer(
            x, sequence_length=sequence_length, add_end_value=False
        )
        return {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }

    def generate_postprocess(
        self,
        x,
    ):
        """Convert integer token output to strings for generation.

        This method reverses `generate_preprocess()`, by first removing all
        padding and start/end tokens, and then converting the integer sequence
        back to a string.
        """
        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        ids_to_strip = (
            self.tokenizer.end_token_id,
            self.tokenizer.start_token_id,
        )
        token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
        return self.tokenizer.detokenize(token_ids)
@@ -0,0 +1,21 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from keras_hub.src.api_export import keras_hub_export
15
+ from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
16
+ from keras_hub.src.models.llama.llama_preprocessor import LlamaPreprocessor
17
+
18
+
19
@keras_hub_export("keras_hub.models.Llama3Preprocessor")
class Llama3Preprocessor(LlamaPreprocessor):
    """Llama 3 preprocessor.

    Behaves exactly like `keras_hub.models.LlamaPreprocessor`. The only
    difference is that its associated tokenizer class (`tokenizer_cls`) is
    `keras_hub.models.Llama3Tokenizer`.
    """

    # The sole override: bind the Llama 3 tokenizer class.
    tokenizer_cls = Llama3Tokenizer
@@ -0,0 +1,69 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Llama 3 model preset configurations."""
15
+
16
+ # Metadata for loading pretrained model weights.
17
# Every Llama 3 preset shares the same official name, path and model card;
# only the preset name, description, parameter count and Kaggle handle vary,
# so those are listed one row per preset and expanded into the full entry.
backbone_presets = {
    preset_name: {
        "metadata": {
            "description": description,
            "params": params,
            "official_name": "LLaMA 3",
            "path": "llama3",
            "model_card": "https://github.com/meta-llama/llama3",
        },
        "kaggle_handle": kaggle_handle,
    }
    for preset_name, description, params, kaggle_handle in (
        (
            "llama3_8b_en",
            "8 billion parameter, 32-layer, base LLaMA 3 model.",
            8030261248,
            "kaggle://keras/llama3/keras/llama3_8b_en/3",
        ),
        (
            "llama3_8b_en_int8",
            (
                "8 billion parameter, 32-layer, base LLaMA 3 model with "
                "activation and weights quantized to int8."
            ),
            8031894016,
            "kaggle://keras/llama3/keras/llama3_8b_en_int8/1",
        ),
        (
            "llama3_instruct_8b_en",
            (
                "8 billion parameter, 32-layer, instruction tuned LLaMA 3 "
                "model."
            ),
            8030261248,
            "kaggle://keras/llama3/keras/llama3_instruct_8b_en/3",
        ),
        (
            "llama3_instruct_8b_en_int8",
            (
                "8 billion parameter, 32-layer, instruction tuned LLaMA 3 "
                "model with activation and weights quantized to int8."
            ),
            8031894016,
            "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/1",
        ),
    )
}
@@ -0,0 +1,63 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.api_export import keras_hub_export
16
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
17
+
18
+
19
@keras_hub_export("keras_hub.models.Llama3Tokenizer")
class Llama3Tokenizer(BytePairTokenizer):
    """Byte-pair tokenizer for Llama 3 models.

    Uses `"<|begin_of_text|>"` as the start token and `"<|end_of_text|>"`
    as the end token. Both are passed to the `BytePairTokenizer`
    constructor as `unsplittable_tokens`, and their ids are resolved in
    `set_vocabulary_and_merges()`.

    Args:
        vocabulary: Forwarded unchanged to `BytePairTokenizer`.
        merges: Forwarded unchanged to `BytePairTokenizer`.
        **kwargs: Forwarded unchanged to `BytePairTokenizer`.
    """

    def __init__(
        self,
        vocabulary=None,
        merges=None,
        **kwargs,
    ):
        # The special tokens are assigned before `super().__init__()`,
        # presumably because the base constructor calls
        # `set_vocabulary_and_merges()`, which reads them — confirm in
        # `BytePairTokenizer`.
        self.start_token = "<|begin_of_text|>"
        self.end_token = "<|end_of_text|>"

        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            unsplittable_tokens=[self.start_token, self.end_token],
            **kwargs,
        )

    def set_vocabulary_and_merges(self, vocabulary, merges):
        """Set the vocabulary and merges, then resolve special token ids.

        When `vocabulary` is `None`, `start_token_id`, `end_token_id` and
        `pad_token_id` are all set to `None`.

        Raises:
            ValueError: if `vocabulary` is not `None` and the resulting
                vocabulary lacks the end token or the start token.
        """
        super().set_vocabulary_and_merges(vocabulary, merges)

        if vocabulary is None:
            self.end_token_id = None
            self.start_token_id = None
            self.pad_token_id = None
            return

        # Validate every special token up front so a missing one produces a
        # descriptive error instead of failing later inside `token_to_id()`.
        # The end token is checked first so its error message is unchanged
        # for vocabularies that were already rejected before.
        known_tokens = self.get_vocabulary()
        for special_token in (self.end_token, self.start_token):
            if special_token not in known_tokens:
                raise ValueError(
                    f"Cannot find token `'{special_token}'` in the provided "
                    f"`vocabulary`. Please provide `'{special_token}'` in "
                    "your `vocabulary` or use a pretrained `vocabulary` name."
                )

        self.start_token_id = self.token_to_id(self.start_token)
        self.end_token_id = self.token_to_id(self.end_token)
        # NOTE(review): the pad id is hard-coded to 0 rather than looked up
        # from the vocabulary — confirm this is intended for Llama 3.
        self.pad_token_id = 0

    def get_config(self):
        """Return the layer config without `unsplittable_tokens`.

        The constructor always supplies the special tokens as
        `unsplittable_tokens` itself, so that entry is dropped here. That way
        `from_config()` does not pass the argument twice.
        """
        config = super().get_config()
        # `pop` with a default tolerates a base config that omits the key.
        config.pop("unsplittable_tokens", None)
        return config
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.models.task import Task
18
+
19
+
20
@keras_hub_export("keras_hub.models.MaskedLM")
class MaskedLM(Task):
    """Base class for masked language modeling tasks.

    `MaskedLM` tasks wrap a `keras_hub.models.Backbone` and
    a `keras_hub.models.Preprocessor` to create a model that can be used for
    unsupervised fine-tuning with a masked language modeling loss.

    When calling `fit()`, all input will be tokenized, and random tokens in
    the input sequence will be masked. The positions of these masked tokens
    are fed to the model as an additional input, and the model is trained to
    predict the original values of the tokens at those positions.

    All `MaskedLM` tasks include a `from_preset()` constructor which can be
    used to load a pre-trained config and weights.

    Every instance is compiled with the defaults described in `compile()`
    when it is constructed; call `compile()` again to override them.

    Example:
    ```python
    # Load a Bert MaskedLM with pre-trained weights.
    masked_lm = keras_hub.models.MaskedLM.from_preset(
        "bert_base_en",
    )
    masked_lm.fit(train_ds)
    ```
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Default compilation, so the task can be passed straight to `fit()`.
        self.compile()

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        weighted_metrics="auto",
        **kwargs,
    ):
        """Configures the `MaskedLM` task for training.

        The `MaskedLM` task extends the default compilation signature of
        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
        `weighted_metrics`. To override these defaults, pass any value
        to these arguments during compilation.

        Note that because training inputs include padded tokens which are
        excluded from the loss, it is almost always a good idea to compile with
        `weighted_metrics` and not `metrics`.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses
                `keras.optimizers.Adam` with a learning rate of `5e-5`. See
                `keras.Model.compile` and `keras.optimizers` for more info on
                possible `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, where a
                `keras.losses.SparseCategoricalCrossentropy(from_logits=True)`
                loss will be applied for the token classification `MaskedLM`
                task. See `keras.Model.compile` and `keras.losses` for more
                info on possible `loss` values.
            weighted_metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to `"auto"`,
                where a `keras.metrics.SparseCategoricalAccuracy` will be
                applied to track the accuracy of the model at guessing masked
                token values. See `keras.Model.compile` and `keras.metrics` for
                more info on possible `weighted_metrics` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if optimizer == "auto":
            optimizer = keras.optimizers.Adam(5e-5)
        if loss == "auto":
            # `from_logits=True`: the loss expects unnormalized logits from
            # the model rather than probabilities.
            loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        if weighted_metrics == "auto":
            weighted_metrics = [keras.metrics.SparseCategoricalAccuracy()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            weighted_metrics=weighted_metrics,
            **kwargs,
        )
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone
16
+ from keras_hub.src.models.mistral.mistral_presets import backbone_presets
17
+ from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer
18
+ from keras_hub.src.utils.preset_utils import register_presets
19
+
20
# Register the Mistral `backbone_presets` with both the backbone class and
# the tokenizer class.
register_presets(
    backbone_presets,
    (MistralBackbone, MistralTokenizer),
)