keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py
@@ -0,0 +1,258 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_hub.src.models.gpt_neo_x.gpt_neo_x_attention import GPTNeoXAttention
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+class GPTNeoXDecoder(keras.layers.Layer):
+    """GPTNeoX decoder.
+
+    This class follows the architecture of the GPT-NeoX decoder layer in the
+    paper [GPT-NeoX-20B: An Open-Source Autoregressive Language Model](https://arxiv.org/abs/2204.06745).
+    Users can instantiate multiple instances of this class to stack up a decoder.
+
+    This layer will always apply a causal mask to the decoder attention layer.
+
+    Args:
+        intermediate_dim: int, the hidden size of feedforward network.
+        num_heads: int, the number of heads for multi-head attention.
+        dropout: float. the dropout value, shared by
+            the multi-head attention and feedforward layers.
+        activation: string or `keras.activations`. the activation function of
+            feedforward network.
+        layer_norm_epsilon: float. The epsilon value in layer
+            normalization components.
+        kernel_initializer: string or `keras.initializers` initializer. The
+            kernel initializer for the dense and multi-head attention layers.
+        bias_initializer: string or `keras.initializers` initializer. The bias
+            initializer for the dense and multi-head attention layers.
+        rotary_max_wavelength: int. The maximum angular wavelength of the
+            sine/cosine curves, for rotary embeddings.
+        rotary_percentage: float. The percentage by which query, key, value
+            matrices are to be rotated.
+        max_sequence_length: int. The maximum sequence length that this encoder
+            can consume. If `None`, `max_sequence_length` uses the value from
+            sequence length. This determines the variable shape for positional
+            embeddings.
+        name: string. The name of the layer.
+    """
+
+    def __init__(
+        self,
+        intermediate_dim,
+        num_heads,
+        dropout=0.0,
+        activation="relu",
+        layer_norm_epsilon=1e-5,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        rotary_percentage=0.25,
+        rotary_max_wavelength=10000,
+        max_sequence_length=512,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.intermediate_dim = intermediate_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.rotary_percentage = rotary_percentage
+        self.rotary_max_wavelength = rotary_max_wavelength
+        self.max_sequence_length = max_sequence_length
+        self.activation = keras.activations.get(activation)
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
+        self.supports_masking = True
+        self.rotary_percentage = rotary_percentage
+        self._decoder_sequence_shape = None
+
+    def build(self, decoder_sequence_shape):
+        self._decoder_sequence_shape = decoder_sequence_shape
+        hidden_dim = decoder_sequence_shape[-1]
+        # Self attention layers.
+        self._self_attention_layer = GPTNeoXAttention(
+            num_heads=self.num_heads,
+            hidden_dim=hidden_dim,
+            dropout=self.dropout,
+            rotary_percentage=self.rotary_percentage,
+            rotary_max_wavelength=self.rotary_max_wavelength,
+            max_sequence_length=self.max_sequence_length,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="self_attention",
+        )
+        self._self_attention_layer.build(decoder_sequence_shape)
+
+        self._self_attention_layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="self_attention_layer_norm",
+        )
+        self._self_attention_layer_norm.build(decoder_sequence_shape)
+
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention_dropout",
+        )
+
+        # Feedforward layers.
+        self._feedforward_intermediate_dense = keras.layers.Dense(
+            self.intermediate_dim,
+            activation=self.activation,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="feedforward_intermediate_dense",
+        )
+        self._feedforward_intermediate_dense.build(decoder_sequence_shape)
+
+        self._feedforward_output_dense = keras.layers.Dense(
+            hidden_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="feedforward_output_dense",
+        )
+
+        intermediate_shape = list(decoder_sequence_shape)
+        intermediate_shape[-1] = self.intermediate_dim
+        self._feedforward_output_dense.build(tuple(intermediate_shape))
+
+        self._feedforward_layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layer_norm",
+        )
+        self._feedforward_layer_norm.build(decoder_sequence_shape)
+
+        self._feedforward_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="feedforward_dropout",
+        )
+        self.built = True
+
+    def call(
+        self,
+        decoder_sequence,
+        decoder_padding_mask=None,
+        decoder_attention_mask=None,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+    ):
+        self_attention_mask = self._compute_self_attention_mask(
+            decoder_sequence=decoder_sequence,
+            decoder_padding_mask=decoder_padding_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            self_attention_cache=self_attention_cache,
+            self_attention_cache_update_index=self_attention_cache_update_index,
+        )
+
+        residual = decoder_sequence
+
+        x = self._self_attention_layer_norm(decoder_sequence)
+
+        # Self attention block.
+        x, self_attention_cache = self._self_attention_layer(
+            hidden_states=x,
+            attention_mask=self_attention_mask,
+            cache=self_attention_cache,
+            cache_update_index=self_attention_cache_update_index,
+        )
+        x = self._self_attention_dropout(x)
+        attention_output = x
+
+        x = self._feedforward_layer_norm(decoder_sequence)
+        x = self._feedforward_intermediate_dense(x)
+        x = self._feedforward_output_dense(x)
+        feedforward_output = x
+        x = feedforward_output + attention_output + residual
+
+        if self_attention_cache is not None:
+            return (x, self_attention_cache)
+        else:
+            return x
+
+    def _compute_self_attention_mask(
+        self,
+        decoder_sequence,
+        decoder_padding_mask,
+        decoder_attention_mask,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+    ):
+        decoder_mask = merge_padding_and_attention_mask(
+            decoder_sequence, decoder_padding_mask, decoder_attention_mask
+        )
+        batch_size = ops.shape(decoder_sequence)[0]
+        input_length = output_length = ops.shape(decoder_sequence)[1]
+        # We need to handle a rectangular causal mask when doing cached
+        # decoding. For generative inference, `decoder_sequence` will
+        # generally be length 1, and `cache` will be the full generation length.
+        if self_attention_cache is not None:
+            input_length = ops.shape(self_attention_cache)[2]
+
+        causal_mask = compute_causal_mask(
+            batch_size,
+            input_length,
+            output_length,
+            (
+                0
+                if self_attention_cache_update_index is None
+                else self_attention_cache_update_index
+            ),
+        )
+        return (
+            ops.minimum(decoder_mask, causal_mask)
+            if decoder_mask is not None
+            else causal_mask
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "dropout": self.dropout,
+                "rotary_percentage": self.rotary_percentage,
+                "rotary_max_wavelength": self.rotary_max_wavelength,
+                "max_sequence_length": self.max_sequence_length,
+                "activation": keras.activations.serialize(self.activation),
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+                "decoder_sequence_shape": self._decoder_sequence_shape,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, decoder_sequence_shape):
+        return decoder_sequence_shape
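The decoder layer above infers its hidden size from the input shape and is meant to be stacked on top of a token embedding. Below is a minimal, hedged sketch of that pattern; the vocabulary size, hidden size, head count, and layer count are illustrative values chosen here, not settings shipped with this wheel.

import keras
import numpy as np

from keras_hub.src.models.gpt_neo_x.gpt_neo_x_decoder import GPTNeoXDecoder

# Illustrative sizes only.
vocab_size, hidden_dim, num_layers = 1000, 64, 2

token_ids = keras.Input(shape=(None,), dtype="int32")
x = keras.layers.Embedding(vocab_size, hidden_dim)(token_ids)
for _ in range(num_layers):
    # Each layer applies a causal mask internally and reads the hidden
    # size from its input shape; rotary embeddings supply position info.
    x = GPTNeoXDecoder(intermediate_dim=4 * hidden_dim, num_heads=4)(x)
model = keras.Model(token_ids, x)

batch = np.random.randint(0, vocab_size, size=(2, 16)).astype("int32")
hidden_states = model(batch)
print(hidden_states.shape)  # (2, 16, 64)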
keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py
@@ -0,0 +1,145 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+
+
+@keras_hub_export("keras_hub.models.GPTNeoXPreprocessor")
+class GPTNeoXPreprocessor(Preprocessor):
+    """GPTNeoX preprocessing layer which tokenizes and packs inputs.
+
+    This preprocessing layer will do 2 things:
+
+    - Tokenize the inputs using the `tokenizer`.
+    - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can
+      be passed directly to a `keras_hub.models.GPTNeoXBackbone`.
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    The call method of this layer accepts three arguments, `x`, `y`, and
+    `sample_weight`. `x` can be a python string or tensor representing a single
+    segment, a list of python strings representing a batch of single segments,
+    or a list of tensors representing multiple segments to be packed together.
+    `y` and `sample_weight` are both optional, can have any format, and will be
+    passed through unaltered.
+
+    `GPTNeoXPreprocessor` forces the input to have only one segment, as GPTNeoX is
+    mainly used for generation tasks. For tasks having multi-segment inputs
+    like "glue/mnli", please use a model designed for classification purposes
+    such as BERT or RoBERTa.
+
+    Args:
+        tokenizer: A `keras_hub.models.GPTNeoXTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through unaltered.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+    """
+
+    tokenizer_cls = GPTNeoXTokenizer
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=1024,
+        add_start_token=True,
+        add_end_token=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tokenizer = tokenizer
+        self.packer = None
+        self.sequence_length = sequence_length
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+        self.packer = StartEndPacker(
+            start_value=self.tokenizer.start_token_id,
+            end_value=self.tokenizer.end_token_id,
+            pad_value=self.tokenizer.pad_token_id,
+            sequence_length=self.sequence_length,
+            return_padding_mask=True,
+        )
+        self.built = True
+
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        x = convert_inputs_to_list_of_tensor_segments(x)
+        if len(x) != 1:
+            raise ValueError(
+                "GPTNeoX requires each input feature to contain only "
+                f"one segment, but received {len(x)}. If you are using GPTNeoX "
+                "for a multi-segment classification task, please refer to "
+                "classification models like BERT or RoBERTa."
+            )
+        sequence_length = sequence_length or self.sequence_length
+        token_ids, padding_mask = self.packer(
+            self.tokenizer(x[0]),
+            sequence_length=sequence_length,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        x = {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sequence_length": self.sequence_length,
+                "add_start_token": self.add_start_token,
+                "add_end_token": self.add_end_token,
+            }
+        )
+        return config
+
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self._sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self._sequence_length = value
+        if self.packer is not None:
+            self.packer.sequence_length = value
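As a usage sketch (not code from the wheel), the preprocessor can wrap a tokenizer and be mapped over a `tf.data.Dataset`, as the docstring above describes. The toy vocabulary and merge rules below are invented for illustration; the file listing above includes no GPT-NeoX presets file, so a real vocabulary would have to be supplied by the user.

import tensorflow as tf
import keras_hub

# Toy byte-pair vocabulary and merges, invented for illustration.
vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]

tokenizer = keras_hub.models.GPTNeoXTokenizer(vocabulary=vocab, merges=merges)
preprocessor = keras_hub.models.GPTNeoXPreprocessor(tokenizer, sequence_length=8)

# Single strings or batches map to {"token_ids", "padding_mask"} dicts.
features = preprocessor("a quick fox")

# The same layer works inside a tf.data pipeline.
ds = tf.data.Dataset.from_tensor_slices(["a quick fox", "a fox"])
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)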
keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py
@@ -0,0 +1,88 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_hub_export("keras_hub.models.GPTNeoXTokenizer")
+class GPTNeoXTokenizer(BytePairTokenizer):
+    """A GPTNeoX tokenizer using Byte-Pair Encoding subword segmentation.
+
+    This tokenizer class will tokenize raw strings into integer sequences and
+    is based on `keras_hub.tokenizers.BytePairTokenizer`. Unlike the
+    underlying tokenizer, it will check for all special tokens needed by GPTNeoX
+    models and provides a `from_preset()` method to automatically download
+    a matching vocabulary for a GPTNeoX preset.
+
+    This tokenizer does not provide truncation or padding of inputs.
+
+    If input is a batch of strings (rank > 0), the layer will output a
+    `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+    If input is a scalar string (rank == 0), the layer will output a dense
+    `tf.Tensor` with static shape `[None]`.
+
+    Args:
+        vocabulary: string or dict, maps token to integer ids. If it is a
+            string, it should be the file path to a json file.
+        merges: string or list, contains the merge rule. If it is a string,
+            it should be the file path to merge rules. The merge rule file
+            should have one merge rule per line. Every merge rule contains
+            merge entities separated by a space.
+    """
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        **kwargs,
+    ):
+        # GPTNeoX uses the same start as end token, i.e., "<|endoftext|>".
+        self.end_token = self.start_token = "<|endoftext|>"
+
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            unsplittable_tokens=[self.end_token],
+            **kwargs,
+        )
+
+    def set_vocabulary_and_merges(self, vocabulary, merges):
+        super().set_vocabulary_and_merges(vocabulary, merges)
+
+        if vocabulary is not None:
+            # Check for necessary special tokens.
+            if self.end_token not in self.get_vocabulary():
+                raise ValueError(
+                    f"Cannot find token `'{self.end_token}'` in the provided "
+                    f"`vocabulary`. Please provide `'{self.end_token}'` in "
+                    "your `vocabulary` or use a pretrained `vocabulary` name."
+                )
+
+            self.end_token_id = self.token_to_id(self.end_token)
+            self.start_token_id = self.end_token_id
+            self.pad_token_id = 0
+        else:
+            self.end_token_id = None
+            self.start_token_id = None
+            self.pad_token_id = None
+
+    def get_config(self):
+        config = super().get_config()
+        # In the constructor, we pass the list of special tokens to the
+        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
+        # delete it from the config here.
+        del config["unsplittable_tokens"]
+        return config
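A small, hedged example of the tokenizer used in isolation, again with an invented vocabulary; note that `"<|endoftext|>"` must appear in any vocabulary passed in, per the check above.

import keras_hub

# Invented vocabulary and merges, for illustration only.
vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]

tokenizer = keras_hub.models.GPTNeoXTokenizer(vocabulary=vocab, merges=merges)
token_ids = tokenizer("a quick fox")  # scalar input -> dense int tensor
text = tokenizer.detokenize(token_ids)  # round-trips back to a string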
keras_hub/src/models/image_classifier.py
@@ -0,0 +1,90 @@
+# Copyright 2023 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+
+
+@keras_hub_export("keras_hub.models.ImageClassifier")
+class ImageClassifier(Task):
+    """Base class for all image classification tasks.
+
+    `ImageClassifier` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    image classification. `ImageClassifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
+
+    All `ImageClassifier` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `ImageClassifier` task for training.
+
+        The `ImageClassifier` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default optimizer
+                for the given model and task. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.SparseCategoricalCrossentropy` loss will be
+                applied for the classification task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.SparseCategoricalAccuracy` will be
+                applied to track the accuracy of the model during training.
+                See `keras.Model.compile` and `keras.metrics` for
+                more info on possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(5e-5)
+        if loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.softmax
+            loss = keras.losses.SparseCategoricalCrossentropy(from_logits)
+        if metrics == "auto":
+            metrics = [keras.metrics.SparseCategoricalAccuracy()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )
+
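To make the `"auto"` defaults concrete: for a task whose head does not end in a softmax activation, the default compilation performed above is roughly equivalent to the explicit call below. This is a sketch, not package code; `classifier` stands in for any already-constructed `ImageClassifier` subclass instance (for example one of the ResNet, DenseNet, or VGG classifiers listed above).

import keras

# Hypothetical: `classifier` is some concrete ImageClassifier instance.
classifier.compile(
    optimizer=keras.optimizers.Adam(5e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)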
keras_hub/src/models/llama/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_hub.src.models.llama.llama_backbone import LlamaBackbone
+from keras_hub.src.models.llama.llama_presets import backbone_presets
+from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (LlamaBackbone, LlamaTokenizer))
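The `register_presets` call is what makes the Llama presets defined in `llama_presets.py` discoverable through `from_preset()`. A hedged sketch of the resulting workflow follows; the preset name used here is only an example and should be checked against `LlamaBackbone.presets` for the names this build actually registers.

import keras_hub

# Inspect the preset names registered for the Llama classes.
print(keras_hub.models.LlamaBackbone.presets.keys())

# Load config, weights, and vocabulary for one preset (name is illustrative).
backbone = keras_hub.models.LlamaBackbone.from_preset("llama2_7b_en")
tokenizer = keras_hub.models.LlamaTokenizer.from_preset("llama2_7b_en")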