keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,79 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from keras_hub.src.api_export import keras_hub_export
15
+ from keras_hub.src.models.gemma.gemma_preprocessor import GemmaTokenizer
16
+
17
+
18
@keras_hub_export("keras_hub.models.PaliGemmaTokenizer")
class PaliGemmaTokenizer(GemmaTokenizer):
    """PaliGemma tokenizer layer based on SentencePiece.

    This tokenizer class will tokenize raw strings into integer sequences and
    is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike the
    underlying tokenizer, it will check for all special tokens needed by
    PaliGemma models and provides a `from_preset()` method to automatically
    download a matching vocabulary for a PaliGemma preset.

    This class defines no behavior of its own: everything is inherited from
    `GemmaTokenizer`. It gives PaliGemma a tokenizer exported under its own
    name (`keras_hub.models.PaliGemmaTokenizer`).

    If input is a batch of strings (rank > 0), the layer will output a
    `tf.RaggedTensor` where the last dimension of the output is ragged.

    If input is a scalar string (rank == 0), the layer will output a dense
    `tf.Tensor` with static shape `[None]`.

    Args:
        proto: Either a `string` path to a SentencePiece proto file, or a
            `bytes` object with a serialized SentencePiece proto. See the
            [SentencePiece repository](https://github.com/google/sentencepiece)
            for more details on the format.

    Examples:

    ```python
    import io

    import keras_hub
    import sentencepiece
    import tensorflow as tf

    # Unbatched input.
    tokenizer = keras_hub.models.PaliGemmaTokenizer.from_preset(
        "pali_gemma_3b_224"
    )
    tokenizer("The quick brown fox jumped.")

    # Batched input.
    tokenizer(["The quick brown fox jumped.", "The fox slept."])

    # Detokenization.
    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))

    # Custom vocabulary.
    bytes_io = io.BytesIO()
    ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
    sentencepiece.SentencePieceTrainer.train(
        sentence_iterator=ds.as_numpy_iterator(),
        model_writer=bytes_io,
        vocab_size=8,
        model_type="WORD",
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3,
        pad_piece="<pad>",
        bos_piece="<bos>",
        eos_piece="<eos>",
        unk_piece="<unk>",
    )
    tokenizer = keras_hub.models.PaliGemmaTokenizer(
        proto=bytes_io.getvalue(),
    )
    tokenizer("The quick brown fox jumped.")
    ```
    """
@@ -0,0 +1,566 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+ from keras import ops
16
+
17
+
18
class PaliGemmaVitEmbeddings(keras.layers.Layer):
    """Patchify an image batch and add learned position embeddings.

    Images are cut into non-overlapping `patch_size x patch_size` patches by
    a `Conv2D` whose kernel size and stride both equal `patch_size`. The
    resulting grid is flattened into a sequence of patch vectors, and a
    learned embedding is added for every patch position.

    The patch count is computed as `(image_size // patch_size) ** 2`, so
    inputs are assumed to be square images of side `image_size`.

    Args:
        image_size: int. Side length of the input images.
        patch_size: int. Side length of each patch (and the conv stride).
        hidden_dim: int. Size of each patch embedding vector.
        num_channels: int. Number of image channels. It is only stored and
            serialized; the conv infers its input channels at build time.
        dtype: Dtype policy for this layer and its sublayers.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        hidden_dim,
        num_channels=3,
        dtype=None,
        **kwargs,
    ):
        # Forward `dtype` so this layer's own policy, and therefore the
        # `dtype` entry written by `get_config()`, matches its sublayers.
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_dim = hidden_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.patch_embedding = keras.layers.Conv2D(
            filters=self.hidden_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="valid",
            activation=None,
            dtype=dtype,
            name="embedding_conv",
        )
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = keras.layers.Embedding(
            self.num_positions,
            self.hidden_dim,
            dtype=dtype,
            name="position_embedding",
        )

        # Shape (1, num_positions): a single row of position ids that
        # broadcasts over the batch when the embeddings are added.
        self.position_ids = ops.expand_dims(
            ops.arange(self.num_positions), axis=0
        )

    def build(self, input_shape):
        self.patch_embedding.build(input_shape)
        self.position_embedding.build([1, self.num_positions])
        self.built = True

    def call(self, input_tokens):
        """Embed an image batch as a sequence of patch vectors.

        Args:
            input_tokens: Image batch fed to the patch `Conv2D`.

        Returns:
            Tensor of shape `(batch_size, num_patches, hidden_dim)`.
        """
        x = self.patch_embedding(input_tokens)
        # Flatten the spatial patch grid into one sequence axis.
        input_shape = ops.shape(x)
        x = ops.reshape(x, [input_shape[0], -1, input_shape[-1]])
        x = x + self.position_embedding(self.position_ids)
        return x

    def compute_output_shape(self, input_shape):
        return (
            input_shape[0],
            self.num_patches,
            self.hidden_dim,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_size": self.image_size,
                "patch_size": self.patch_size,
                "hidden_dim": self.hidden_dim,
                "num_channels": self.num_channels,
            }
        )
        return config
73
+
74
+
75
class PaliGemmaVitAttention(keras.layers.Layer):
    """Multi-head self-attention used by the PaliGemma ViT encoder blocks.

    Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py # noqa: E501

    Args:
        hidden_dim: int. Model width. Must be divisible by `num_heads`.
        num_heads: int. Number of attention heads.
        dropout: float. Dropout rate applied to the attention probabilities.
        dtype: Dtype policy for this layer and its sublayers.
        **kwargs: Forwarded to `keras.layers.Layer`.

    Raises:
        ValueError: If `hidden_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        dropout=0.0,
        dtype=None,
        **kwargs,
    ):
        # Forward `dtype` so this layer's own policy, and therefore the
        # `dtype` entry written by `get_config()`, matches its sublayers.
        super().__init__(dtype=dtype, **kwargs)

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = self.hidden_dim // self.num_heads
        if self.head_dim * self.num_heads != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`"
                f": {self.hidden_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.dropout_layer = keras.layers.Dropout(
            self.dropout,
            dtype=dtype,
            name="dropout",
        )
        # Public attribute equal to 1 / sqrt(head_dim). `call()` applies the
        # same factor by dividing the raw scores by sqrt(head_dim).
        self.scale = self.head_dim**-0.5
        self.query_proj = keras.layers.Dense(
            units=self.hidden_dim,
            dtype=dtype,
            name="query_proj",
        )
        self.key_proj = keras.layers.Dense(
            units=self.hidden_dim,
            dtype=dtype,
            name="key_proj",
        )
        self.value_proj = keras.layers.Dense(
            units=self.hidden_dim,
            dtype=dtype,
            name="value_proj",
        )
        self.out_proj = keras.layers.Dense(
            units=self.hidden_dim,
            dtype=dtype,
            name="out_proj",
        )

    def build(self, input_shape):
        # All projections map hidden_dim -> hidden_dim, so `input_shape`
        # is not needed.
        self.query_proj.build([None, None, self.hidden_dim])
        self.key_proj.build([None, None, self.hidden_dim])
        self.value_proj.build([None, None, self.hidden_dim])
        self.out_proj.build([None, None, self.hidden_dim])
        self.built = True

    def _transpose_for_scores(self, tensor, batch_size):
        """Split the last axis into heads and move heads before sequence.

        Adapted from https://github.com/huggingface/transformers/blob/8e164c5400b7b413c7b8fb32e35132001effc970/src/transformers/models/bert/modeling_tf_bert.py#L252 # noqa: E501
        """
        # [batch_size, seq_len, all_head_dim] ->
        # [batch_size, seq_len, num_heads, head_dim]
        tensor = ops.reshape(
            tensor, (batch_size, -1, self.num_heads, self.head_dim)
        )
        # [batch_size, seq_len, num_heads, head_dim] ->
        # [batch_size, num_heads, seq_len, head_dim]
        return ops.transpose(tensor, axes=[0, 2, 1, 3])

    def call(
        self,
        x,
        attention_mask=None,
        return_attention_scores=None,
        training=False,
    ):
        """Apply multi-head self-attention to `x`.

        Args:
            x: Input whose last axis is `hidden_dim`. Assumed to be
                `(batch_size, seq_len, hidden_dim)`.
            attention_mask: Optional additive mask. It is added to the
                scaled scores of shape
                `(batch_size, num_heads, seq_len, seq_len)` before the
                softmax, so it must broadcast to that shape.
            return_attention_scores: Unused. The attention probabilities are
                always returned.
            training: bool. Enables dropout on the attention probabilities.

        Returns:
            Tuple `(attn_output, attention_probs)`. `attn_output` has shape
            `(batch_size, seq_len, hidden_dim)`. `attention_probs` holds the
            softmax probabilities before dropout, with shape
            `(batch_size, num_heads, seq_len, seq_len)`.
        """
        batch_size = ops.shape(x)[0]
        mixed_query_layer = self.query_proj(inputs=x)
        mixed_key_layer = self.key_proj(inputs=x)
        mixed_value_layer = self.value_proj(inputs=x)
        query_layer = self._transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self._transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self._transpose_for_scores(mixed_value_layer, batch_size)

        # Scaled dot product between key and query = raw attention scores.
        attention_scores = ops.matmul(
            query_layer, ops.transpose(key_layer, axes=[0, 1, 3, 2])
        )
        dk = ops.cast(ops.sqrt(self.head_dim), dtype=attention_scores.dtype)
        attention_scores = ops.divide(
            attention_scores, dk
        )  # (batch_size, num_heads, seq_len_q, seq_len_k)

        if attention_mask is not None:
            # Additive mask: large negative entries suppress positions.
            attention_scores = ops.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_probs = ops.softmax(attention_scores, axis=-1)

        # This drops out entire tokens to attend to. That can seem unusual,
        # but it follows the original Transformer paper.
        dropout_attention_probs = self.dropout_layer(
            inputs=attention_probs, training=training
        )

        attn_output = ops.matmul(dropout_attention_probs, value_layer)
        attn_output = ops.transpose(attn_output, axes=[0, 2, 1, 3])

        # Merge heads: (batch_size, seq_len_q, hidden_dim)
        attn_output = ops.reshape(
            attn_output, (batch_size, -1, self.hidden_dim)
        )

        attn_output = self.out_proj(attn_output, training=training)
        return (attn_output, attention_probs)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "dropout": self.dropout,
            }
        )
        return config
206
+
207
+
208
class PaliGemmaVitEncoderBlock(keras.layers.Layer):
    """Pre-norm transformer encoder block for the PaliGemma ViT.

    Computes `y = x + attn(layer_norm_1(x))` and then
    `y + mlp(layer_norm_2(y))`. The MLP is
    `Dense(intermediate_dim) -> GELU(approximate=True) -> Dense(hidden_dim)`.
    `hidden_dim` is read from the last axis of the input shape at build time.

    Args:
        num_heads: int. Number of attention heads.
        intermediate_dim: int. Width of the MLP hidden layer.
        **kwargs: Forwarded to `keras.layers.Layer`. A `dtype` given here
            becomes `self.dtype_policy`, which the sublayers reuse.
    """

    def __init__(
        self,
        num_heads,
        intermediate_dim,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim

    def compute_attention(self, x, mask=None):
        """Run self-attention on `x`, honoring an optional additive mask.

        Args:
            x: The normalized block input.
            mask: Optional additive attention mask. It must broadcast to the
                attention scores `(batch_size, num_heads, seq_len, seq_len)`
                and is cast to `x.dtype` before being added to them.

        Returns:
            The attention output only, without the attention probabilities.
        """
        if mask is not None:
            mask = ops.cast(mask, dtype=x.dtype)
        return self.attn(x, attention_mask=mask)[0]

    def build(self, input_shape):
        hidden_dim = input_shape[-1]
        self.attn = PaliGemmaVitAttention(
            hidden_dim,
            self.num_heads,
            dtype=self.dtype_policy,
            name="multi_head_attention",
        )
        self.layer_norm_1 = keras.layers.LayerNormalization(
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="layer_norm_1",
        )
        self.mlp_dense_1 = keras.layers.Dense(
            self.intermediate_dim,
            dtype=self.dtype_policy,
            name="mlp_dense_1",
        )
        self.mlp_dense_2 = keras.layers.Dense(
            hidden_dim,
            dtype=self.dtype_policy,
            name="mlp_dense_2",
        )
        self.layer_norm_2 = keras.layers.LayerNormalization(
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="layer_norm_2",
        )
        # The attention layer's build ignores its input shape.
        self.attn.build(None)
        self.layer_norm_1.build([None, None, hidden_dim])
        self.mlp_dense_1.build([None, None, hidden_dim])
        self.mlp_dense_2.build([None, None, self.intermediate_dim])
        self.layer_norm_2.build([None, None, hidden_dim])
        self.built = True

    def call(self, x, mask=None):
        # Attention sub-block with residual connection.
        residual = x
        x = self.layer_norm_1(x)
        x = self.compute_attention(x, mask)
        x = x + residual
        # MLP sub-block with residual connection.
        residual = x
        x = self.mlp_dense_1(self.layer_norm_2(residual))
        x = keras.activations.gelu(x, approximate=True)
        x = self.mlp_dense_2(x)
        return residual + x

    def compute_output_shape(self, inputs_shape):
        return inputs_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
            }
        )
        return config
284
+
285
+
286
class PaliGemmaVitEncoder(keras.layers.Layer):
    """PaliGemma ViT encoder: patch embedding, encoder blocks, final norm.

    Maps an image batch to one `hidden_dim` vector per image patch.

    Args:
        patch_size: int. Side length of each square image patch.
        image_size: int. Side length of the input images.
        hidden_dim: int. Width of the patch embeddings and encoder blocks.
        num_layers: int. Number of `PaliGemmaVitEncoderBlock`s.
        num_heads: int. Attention heads per encoder block.
        intermediate_dim: int. MLP hidden width per encoder block.
        dtype: Dtype policy for this layer and its sublayers.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self,
        patch_size,
        image_size,
        hidden_dim,
        num_layers,
        num_heads,
        intermediate_dim,
        dtype=None,
        **kwargs,
    ):
        # Forward `dtype` so this layer's own policy, and therefore the
        # `dtype` entry written by `get_config()`, matches its sublayers.
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim
        self.patch_size = patch_size
        self.image_size = image_size
        self.encoder_layer_norm = keras.layers.LayerNormalization(
            epsilon=1e-6,
            dtype=dtype,
            name="encoder_layer_norm",
        )
        self.vision_embeddings = PaliGemmaVitEmbeddings(
            hidden_dim=hidden_dim,
            patch_size=patch_size,
            image_size=image_size,
            dtype=dtype,
            name="encoder_embeddings",
        )
        self.resblocks = [
            PaliGemmaVitEncoderBlock(
                self.num_heads,
                self.intermediate_dim,
                dtype=dtype,
                name=f"encoder_block_{i}",
            )
            for i in range(self.num_layers)
        ]

    def build(self, input_shape):
        self.vision_embeddings.build(input_shape)
        for block in self.resblocks:
            block.build([None, None, self.hidden_dim])
        self.encoder_layer_norm.build([None, None, self.hidden_dim])
        self.built = True

    def call(
        self,
        x,
        mask=None,
    ):
        """Encode an image batch.

        Args:
            x: Image batch passed to the patch-embedding `Conv2D`.
            mask: Optional additive attention mask forwarded to every
                encoder block.

        Returns:
            Tensor of shape `(batch_size, num_patches, hidden_dim)`.
        """
        x = self.vision_embeddings(x)
        for block in self.resblocks:
            x = block(x, mask=mask)
        x = self.encoder_layer_norm(x)
        return x

    def compute_output_shape(self, inputs_shape):
        # `inputs_shape` describes the image batch. The output has one
        # vector per patch, so the sequence length is the patch count, not
        # `inputs_shape[1]`, which is the image height.
        return [
            inputs_shape[0],
            self.vision_embeddings.num_patches,
            self.hidden_dim,
        ]

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "patch_size": self.patch_size,
                "image_size": self.image_size,
            }
        )
        return config
361
+
362
+
363
class MultiHeadAttentionPooling(keras.layers.Layer):
    """Pool a token sequence into one vector with a learned attention probe.

    A single trainable "probe" token attends over the input sequence. The
    attended result goes through a residual block (layer norm -> MLP), and
    the probe position is returned, giving one vector per example.

    Args:
        hidden_dim: int. The hidden size of the MLP. When `None`, it is set
            to `4 * input_dim` at build time.
        num_heads: int. The number of attention heads. Each head uses a key
            dimension of `input_dim // num_heads`. Defaults to `12`.
        dropout: float. The dropout rate applied inside the MLP. Defaults
            to `0.0`.
        **kwargs: Forwarded to `keras.layers.Layer` (e.g. `name`, `dtype`).
    """

    def __init__(
        self,
        hidden_dim=None,
        num_heads=12,
        dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dropout = dropout

    def build(self, input_shape):
        input_dim = input_shape[-1]
        if self.hidden_dim is None:
            self.hidden_dim = input_dim * 4
        self.probe = self.add_weight(
            shape=(1, 1, input_dim),
            initializer="glorot_uniform",
            # `add_weight` needs a concrete variable dtype. Passing the
            # dtype policy itself breaks for mixed policies such as
            # "mixed_bfloat16", whose name is not a valid dtype.
            dtype=self.variable_dtype,
        )
        self.mha = keras.layers.MultiHeadAttention(
            key_dim=input_dim // self.num_heads,
            num_heads=self.num_heads,
            dtype=self.dtype_policy,
        )
        self.layer_norm = keras.layers.LayerNormalization(
            epsilon=1e-6,
            dtype=self.dtype_policy,
        )
        self.mlp_block = keras.Sequential(
            [
                keras.layers.Dense(
                    self.hidden_dim,
                    activation="gelu",
                    dtype=self.dtype_policy,
                ),
                keras.layers.Dropout(
                    self.dropout,
                    dtype=self.dtype_policy,
                ),
                keras.layers.Dense(
                    input_dim,
                    dtype=self.dtype_policy,
                ),
            ]
        )

    def call(self, x):
        """Pool `x` (presumably `(batch, seq_len, dim)`) to `(batch, dim)`."""
        batch_size = ops.shape(x)[0]
        # One copy of the learned probe per example: (batch, 1, dim).
        probe = ops.tile(self.probe, [batch_size, 1, 1])
        x = self.mha(probe, x)
        y = self.layer_norm(x)
        x = x + self.mlp_block(y)
        # Drop the length-1 probe axis.
        return x[:, 0]

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                # Serialize the resolved value, so a layer built with
                # `hidden_dim=None` round-trips to the same architecture.
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "dropout": self.dropout,
            }
        )
        return config
418
+
419
+
420
class PaliGemmaVit(keras.Model):
    """Vision Transformer (ViT) model for PaliGemma.

    Args:
        image_size: int. The height/width of the image. Both height and width
            are expected to be the same.
        patch_size: int. The size of each square patch in the input image.
        num_heads: int. The number of attention heads for the vision (image)
            transformer encoder.
        hidden_dim: int. The size of the transformer hidden state at the end
            of each vision transformer layer.
        num_layers: int. The number of transformer layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for transformer.
        num_classes: int. The output dimension of the final Dense layer. If
            this model is used as an image classifier, this value would
            correspond to the number of output classes.
        include_rescaling: bool. If true, the image input will be rescaled from
            the range `[0, 255]` to the range `[-1, 1]`. Defaults to `True`.
        pooling: string. The encoded vision embeddings are pooled using the
            specified pooling setting. The accepted values are `"map"`
            (multi-head attention pooling), `"gap"` (mean over the sequence
            axis), `"zero"` (take the first token) or `None` (no pooling).
            Defaults to `None`.
        classifier_activation: activation function. The activation that is
            used for the final output classification. Defaults to `None`.
        dtype: string or `keras.dtype_policies.DTypePolicy`. The dtype to use
            for the model's computations and weights. Note that some
            computations, such as softmax and layer normalization, will always
            be done in float32 precision regardless of dtype.

    Raises:
        ValueError: If `pooling` is not one of `"map"`, `"gap"`, `"zero"` or
            `None`.

    Example:
    ```python
    images = np.random.uniform(0, 255, size=(1, 224, 224, 3))
    vit_model = PaliGemmaVit(
        image_size=224,
        patch_size=14,
        num_heads=4,
        hidden_dim=64,
        num_layers=2,
        intermediate_dim=128,
        num_classes=32,
    )
    # With `pooling=None` the output has shape
    # `(batch_size, image_sequence_length, num_classes)`, here `(1, 256, 32)`.
    output = vit_model(images)
    ```
    """

    def __init__(
        self,
        image_size,
        patch_size,
        num_heads,
        hidden_dim,
        num_layers,
        intermediate_dim,
        num_classes,
        include_rescaling=True,
        pooling=None,
        classifier_activation=None,
        dtype=None,
        **kwargs,
    ):
        # Validate up front, before building any layers.
        if pooling not in ("map", "gap", "zero", None):
            raise ValueError(
                "Invalid value for argument `pooling`. "
                "Expected one of 'map', 'gap', 'zero', None. "
                f"Received: pooling={pooling}"
            )

        # === Functional Model ===
        image_input = keras.Input(
            shape=(image_size, image_size, 3), name="images"
        )
        x = image_input  # Intermediate result.
        if include_rescaling:
            # Maps [0, 255] to [-1, 1]: x / 127.5 - 1.
            rescaling = keras.layers.Rescaling(
                scale=1.0 / 127.5, offset=-1.0, name="rescaling"
            )
            x = rescaling(image_input)
        x = PaliGemmaVitEncoder(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            intermediate_dim=intermediate_dim,
            patch_size=patch_size,
            image_size=image_size,
            dtype=dtype,
            name="image_encoder",
        )(x)
        if pooling == "map":
            x = MultiHeadAttentionPooling(
                num_heads=num_heads,
                hidden_dim=hidden_dim,
                dtype=dtype,
                name="pooling",
            )(x)
        elif pooling == "gap":
            x = ops.mean(x, axis=1)
        elif pooling == "zero":
            x = x[:, 0]
        # `pooling is None`: keep the full (batch, sequence, hidden) output.
        outputs = keras.layers.Dense(
            num_classes,
            activation=classifier_activation,
            dtype=dtype,
            name="image_classifier",
        )(x)
        super().__init__(
            inputs=image_input,
            outputs=outputs,
            **kwargs,
        )

        # === Config ===
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.intermediate_dim = intermediate_dim
        self.pooling = pooling
        self.num_classes = num_classes
        self.image_size = image_size
        self.include_rescaling = include_rescaling
        self.patch_size = patch_size
        self.classifier_activation = keras.activations.get(
            classifier_activation
        )
        # Integer division counts whole patches. Float division would
        # over-count when `image_size` is not a multiple of `patch_size`.
        self.image_sequence_length = (image_size // patch_size) ** 2
        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
        if hasattr(keras.dtype_policies, "get"):
            self.dtype_policy = keras.dtype_policies.get(dtype)
        else:
            if isinstance(dtype, keras.dtype_policies.DTypePolicy):
                dtype = dtype.name
            dtype = dtype or keras.config.dtype_policy().name
            self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "num_layers": self.num_layers,
                "intermediate_dim": self.intermediate_dim,
                "pooling": self.pooling,
                "num_classes": self.num_classes,
                "classifier_activation": keras.activations.serialize(
                    self.classifier_activation
                ),
                "image_size": self.image_size,
                "include_rescaling": self.include_rescaling,
                "patch_size": self.patch_size,
            }
        )
        return config
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone
16
+ from keras_hub.src.models.phi3.phi3_presets import backbone_presets
17
+ from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
18
+ from keras_hub.src.utils.preset_utils import register_presets
19
+
20
# Register the Phi3 backbone presets together with the Phi3Backbone and
# Phi3Tokenizer classes imported above.
+ register_presets(backbone_presets, (Phi3Backbone, Phi3Tokenizer))