keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,313 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.causal_lm import CausalLM
+ from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
+     PaliGemmaBackbone,
+ )
+ from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import (
+     PaliGemmaCausalLMPreprocessor,
+ )
+ from keras_hub.src.utils.tensor_utils import any_equal
+
+
+ @keras_hub_export("keras_hub.models.PaliGemmaCausalLM")
+ class PaliGemmaCausalLM(CausalLM):
+     """An end-to-end multimodal PaliGemma model for causal language modeling.
+
+     A causal language model (LM) predicts the next token based on previous
+     tokens. This task setup can be used to train the model unsupervised on
+     image and plain text input, or to autoregressively generate plain text
+     similar to the data used for training.
+
+     This model has a `generate()` method, which generates text based on a
+     prompt. The generation strategy used is controlled by an additional
+     `sampler` argument on `compile()`. You can recompile the model with
+     different `keras_hub.samplers` objects to control the generation. By
+     default, `"greedy"` sampling will be used.
+
+     This model can optionally be configured with a `preprocessor` layer, in
+     which case it will automatically apply preprocessing to string inputs during
+     `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
+     when creating the model with `from_preset()`.
+
+     Args:
+         backbone: A `keras_hub.models.PaliGemmaBackbone` instance.
+         preprocessor: A `keras_hub.models.PaliGemmaCausalLMPreprocessor` or
+             `None`. If `None`, this model will not apply preprocessing, and
+             inputs should be preprocessed before calling the model.
+
+     Examples:
+
+     Use `generate()` to do text generation.
+     ```python
+     image = np.random.rand(224, 224, 3)
+     pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
+         "pali_gemma_3b_mix_224"
+     )
+     pali_gemma_lm.generate(
+         {
+             "images": image,
+             "text": ["answer en where is the cow standing?\\n"]
+         }
+     )
+
+     # Generate with batched prompts.
+     pali_gemma_lm.generate(
+         {
+             "images": [image, image],
+             "text": ["answer en where is the cow standing?\\n", "caption en\\n"]
+         }
+     )
+     ```
+
+     Use `generate()` without preprocessing.
+     ```python
+     image = np.random.rand(224, 224, 3)
+     inputs = {
+         "images": [image, image],
+         # Token ids for "<bos> Keras is".
+         "token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2),
+         # Use `"padding_mask"` to indicate values that should not be overridden.
+         "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2),
+     }
+
+     pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
+         "pali_gemma_3b_mix_224",
+         preprocessor=None,
+     )
+     pali_gemma_lm.generate(inputs)
+     ```
+
+     Custom backbone and vocabulary.
+     ```python
+     tokenizer = keras_hub.models.PaliGemmaTokenizer(
+         proto="proto.spm",
+     )
+     preprocessor = keras_hub.models.PaliGemmaCausalLMPreprocessor(
+         tokenizer=tokenizer,
+         sequence_length=128,
+     )
+     backbone = keras_hub.models.PaliGemmaBackbone()
+     pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM(
+         backbone=backbone,
+         preprocessor=preprocessor,
+     )
+     ```
+     """
+
+     backbone_cls = PaliGemmaBackbone
+     preprocessor_cls = PaliGemmaCausalLMPreprocessor
+
+     def __init__(
+         self,
+         preprocessor,
+         backbone,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.preprocessor = preprocessor
+         self.backbone = backbone
+
+         # === Functional Model ===
+         inputs = backbone.inputs
+         hidden_state = backbone(inputs=inputs)
+         outputs = backbone.token_embedding(hidden_state, reverse=True)
+         outputs = outputs[:, backbone.image_sequence_length :, :]
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+     def compile(
+         self,
+         optimizer="auto",
+         loss="auto",
+         *,
+         weighted_metrics="auto",
+         sampler="greedy",
+         **kwargs,
+     ):
+         super().compile(
+             optimizer=optimizer,
+             loss=loss,
+             weighted_metrics=weighted_metrics,
+             sampler=sampler,
+             **kwargs,
+         )
+
+     def call_with_cache(
+         self,
+         token_ids,
+         cache,
+         cache_update_index,
+         img_embeddings=None,
+         padding_mask=None,
+     ):
+         """Forward pass of `PaliGemmaCausalLM` with cache.
+
+         `call_with_cache` adds an additional forward pass for the model for
+         autoregressive inference. Unlike calling the model directly, this method
+         allows caching previous key/value tensors in the multi-head attention
+         layers, and avoids recomputing the outputs of seen tokens.
+
+         Args:
+             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+             cache: a dense float Tensor, the cache of key and value.
+             cache_update_index: int, or int Tensor. The index of current inputs
+                 in the whole sequence.
+             img_embeddings: a dense float Tensor with shape
+                 `(batch_size, image_sequence_length, hidden_dim)`.
+             padding_mask: a dense int Tensor with shape
+                 `(batch_size, max_length)`.
+
+         Returns:
+             A (logits, hidden_states, cache) tuple, where `logits` is the
+             language model logits for the input token_ids, `hidden_states` is
+             the final hidden representation of the input tokens, and `cache` is
+             the decoding cache.
+         """
+         text_embeddings = self.backbone.token_embedding(token_ids)
+         text_embeddings = text_embeddings * ops.cast(
+             ops.sqrt(self.backbone.hidden_dim), text_embeddings.dtype
+         )
+
+         if img_embeddings is not None:
+             x = ops.concatenate((img_embeddings, text_embeddings), axis=1)
+         else:
+             x = text_embeddings
+
+         # Each decoder layer has a cache; we update them separately.
+         caches = []
+         for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+             current_cache = cache[:, i, ...]
+             x, next_cache = transformer_layer(
+                 x,
+                 cache=current_cache,
+                 cache_update_index=cache_update_index,
+                 padding_mask=padding_mask,
+             )
+             caches.append(next_cache)
+         cache = ops.stack(caches, axis=1)
+         hidden_states = x = self.backbone.layer_norm(x)
+         logits = self.backbone.token_embedding(x, reverse=True)
+         return logits, hidden_states, cache
+
+     def _build_cache(self, token_ids, img_embeddings, padding_mask):
+         """Build an empty cache for use with `call_with_cache()`."""
+         batch_size = ops.shape(token_ids)[0]
+         max_length = (
+             ops.shape(token_ids)[1] + self.backbone.image_sequence_length
+         )
+         num_layers = self.backbone.num_layers
+         num_heads = self.backbone.num_key_value_heads
+         head_dim = self.backbone.head_dim
+         shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
+         cache = ops.zeros(shape, dtype=self.compute_dtype)
+         # Seed the cache.
+         logits, hidden_states, cache = self.call_with_cache(
+             token_ids=token_ids,
+             img_embeddings=img_embeddings,
+             cache=cache,
+             cache_update_index=0,
+             padding_mask=padding_mask,
+         )
+         return hidden_states, cache
+
+     def generate_step(self, inputs, stop_token_ids=None):
+         """A compilable generation function for a single batch of inputs.
+
+         This function represents the inner, XLA-compilable, generation function
+         for a single batch of inputs. Inputs should have the same structure as
+         model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
+
+         Args:
+             inputs: A dictionary with two keys `"token_ids"` and
+                 `"padding_mask"` and batched tensor values.
+             stop_token_ids: Tuple of ids of end tokens to stop on. If all
+                 sequences have produced a new stop token, generation
+                 will stop.
+         """
+         token_ids, padding_mask, images = (
+             inputs["token_ids"],
+             inputs["padding_mask"],
+             inputs["images"],
+         )
+         if len(ops.shape(images)) == 3:
+             # Handle an unbatched image. Unlike `token_ids` and `padding_mask`,
+             # this will not automatically be upranked.
+             images = ops.expand_dims(images, axis=0)
+         img_embeddings = self.backbone.vit_encoder(images)
+
+         # Create and seed cache with a single forward pass.
+         hidden_states, cache = self._build_cache(
+             token_ids, img_embeddings, padding_mask
+         )
+         # Compute the lengths of all user-inputted token ids.
+         row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+         # Start at the first index that has no user-inputted id.
+         index = ops.min(row_lengths)
+
+         def next(prompt, cache, index):
+             # The cache index is the index of our previous token.
+             cache_update_index = index - 1 + self.backbone.image_sequence_length
+             batch_size = ops.shape(prompt)[0]
+             prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
+             logits, hidden_states, cache = self.call_with_cache(
+                 token_ids=prompt,
+                 cache=cache,
+                 cache_update_index=cache_update_index,
+             )
+             return (
+                 ops.squeeze(logits, axis=1),
+                 ops.squeeze(hidden_states, axis=1),
+                 cache,
+             )
+
+         token_ids = self.sampler(
+             next=next,
+             prompt=token_ids,
+             cache=cache,
+             index=index,
+             mask=padding_mask,
+             stop_token_ids=stop_token_ids,
+             hidden_states=hidden_states,
+             model=self,
+         )
+
+         # Compute an output padding mask with the token ids we updated.
+         if stop_token_ids is not None:
+             # Build a mask of `stop_token_ids` locations not in the original
+             # prompt (not in locations where `padding_mask` is True).
+             end_locations = any_equal(
+                 token_ids, stop_token_ids, ops.logical_not(padding_mask)
+             )
+
+             end_locations = ops.cast(end_locations, "int32")
+             # Use cumsum to get ones in all locations after end_locations.
+             cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+             overflow = cumsum - end_locations
+             # Our padding mask is the inverse of these overflow locations.
+             padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+         else:
+             # Without early stopping, all locations will have been updated.
+             padding_mask = ops.ones_like(token_ids, dtype="bool")
+         return {
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+             "images": images,
+         }
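For readers following `compile()` and `generate_step()` above, here is a minimal, hedged sketch of generating with a non-default sampler. It assumes the `pali_gemma_3b_mix_224` preset and the `keras_hub.samplers.TopKSampler` export are available in this build, and it passes a `"prompts"` key to match the `generate_preprocess()` contract of the preprocessor file that follows.

```python
import numpy as np

import keras_hub

# Assumed: the preset weights are downloadable (Kaggle credentials configured).
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
    "pali_gemma_3b_mix_224"
)

# Swap the default "greedy" sampler for top-k sampling; `sampler` is
# forwarded through `compile()` exactly as in the method above.
pali_gemma_lm.compile(sampler=keras_hub.samplers.TopKSampler(k=5))

# A rank-3 image is upranked to a batch of one inside `generate_step()`.
image = np.random.rand(224, 224, 3)
output = pali_gemma_lm.generate(
    {"images": image, "prompts": ["describe en\n"]},
    max_length=64,
)
print(output)
```

Recompiling only swaps the sampler; the optimizer and loss keep their `"auto"` defaults from the `compile()` signature above.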
@@ -0,0 +1,147 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import keras
+ from absl import logging
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
+     MultiSegmentPacker,
+ )
+ from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
+     GemmaCausalLMPreprocessor,
+ )
+ from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
+     PaliGemmaTokenizer,
+ )
+ from keras_hub.src.utils.keras_utils import (
+     convert_inputs_to_list_of_tensor_segments,
+ )
+
+
+ @keras_hub_export("keras_hub.models.PaliGemmaCausalLMPreprocessor")
+ class PaliGemmaCausalLMPreprocessor(GemmaCausalLMPreprocessor):
+     tokenizer_cls = PaliGemmaTokenizer
+
+     def __init__(
+         self,
+         tokenizer,
+         sequence_length=512,
+         add_start_token=True,
+         add_end_token=True,
+         **kwargs,
+     ):
+         super().__init__(
+             tokenizer, sequence_length, add_start_token, add_end_token, **kwargs
+         )
+
+     def build(self, input_shape):
+         # Defer packer creation to `build()` so that we can be sure tokenizer
+         # assets have loaded when restoring a saved model.
+         self.packer = MultiSegmentPacker(
+             start_value=self.tokenizer.start_token_id,
+             end_value=self.tokenizer.end_token_id,
+             pad_value=self.tokenizer.pad_token_id,
+             sep_value=[],
+             sequence_length=self.sequence_length,
+         )
+         self.built = True
+
+     def call(
+         self,
+         x,
+         y=None,
+         sample_weight=None,
+         sequence_length=None,
+     ):
+         if y is not None or sample_weight is not None:
+             logging.warning(
+                 "`PaliGemmaCausalLMPreprocessor` generates `y` and `sample_weight` "
+                 "based on your input data, but your data already contains `y` "
+                 "or `sample_weight`. Your `y` and `sample_weight` will be "
+                 "ignored."
+             )
+         sequence_length = sequence_length or self.sequence_length
+
+         images, prompts, responses = x["images"], x["prompts"], x["responses"]
+         if keras.config.backend() == "tensorflow":
+             # The TensorFlow backend needs uniform output types.
+             images = ops.convert_to_tensor(images)
+         prompts = convert_inputs_to_list_of_tensor_segments(prompts)[0]
+         prompts = self.tokenizer(prompts)
+         responses = convert_inputs_to_list_of_tensor_segments(responses)[0]
+         responses = self.tokenizer(responses)
+         # Pad with one extra token to account for the truncation below.
+         token_ids, segment_ids = self.packer(
+             (prompts, responses),
+             sequence_length=sequence_length + 1,
+             add_start_value=self.add_start_token,
+             add_end_value=self.add_end_token,
+         )
+         padding_mask = token_ids != self.tokenizer.pad_token_id
+         response_mask = segment_ids == 1
+         # The last token does not have a next token, so we truncate it out.
+         x = {
+             "token_ids": token_ids[..., :-1],
+             "response_mask": response_mask[..., :-1],
+             "padding_mask": padding_mask[..., :-1],
+             "images": images,
+         }
+         # Target `y` will be the next token.
+         y = token_ids[..., 1:]
+         # Only compute the loss for labels in the response.
+         sample_weight = response_mask[..., 1:]
+         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+     def generate_preprocess(
+         self,
+         x,
+         sequence_length=None,
+     ):
+         """Convert strings to integer token input for generation.
+
+         Similar to calling the layer for training, this method takes in strings
+         or tensor strings, tokenizes and packs the input, and computes a padding
+         mask masking all inputs not filled in with a padded value.
+
+         Unlike calling the layer for training, this method does not compute
+         labels and will never append a `tokenizer.end_token_id` to the end of
+         the sequence (as generation is expected to continue at the end of the
+         inputted prompt).
+         """
+         if not self.built:
+             self.build(None)
+         sequence_length = sequence_length or self.sequence_length
+
+         images, prompts = x["images"], x["prompts"]
+         prompts = convert_inputs_to_list_of_tensor_segments(prompts)[0]
+         prompts = self.tokenizer(prompts)
+         segments = [prompts]
+         if "responses" in x:
+             responses = x["responses"]
+             responses = convert_inputs_to_list_of_tensor_segments(responses)[0]
+             segments.append(self.tokenizer(responses))
+         token_ids, segment_ids = self.packer(
+             segments,
+             sequence_length=sequence_length,
+             add_end_value=False,
+         )
+         padding_mask = token_ids != self.tokenizer.pad_token_id
+         response_mask = segment_ids == 1
+         return {
+             "images": images,
+             "token_ids": token_ids,
+             "response_mask": response_mask,
+             "padding_mask": padding_mask,
+         }
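As a quick, hedged illustration of the `call()` contract above (an `x` dict with `"images"`, `"prompts"`, and `"responses"` keys), the sketch below builds the preprocessor from a preset and inspects the packed outputs. The `from_preset()` constructor on the preprocessor and the `sequence_length` override are assumptions based on how other keras_hub preprocessors behave.

```python
import numpy as np

import keras_hub

# Assumed: preprocessors expose `from_preset()` like the task classes do.
preprocessor = keras_hub.models.PaliGemmaCausalLMPreprocessor.from_preset(
    "pali_gemma_3b_mix_224",
    sequence_length=32,
)

batch = {
    # One 224x224 RGB image per example.
    "images": np.random.rand(2, 224, 224, 3).astype("float32"),
    "prompts": ["answer en where is the cow standing?\n", "caption en\n"],
    "responses": ["The cow is standing in a field.", "A cow in a field."],
}

# `call()` packs prompts and responses, shifts the targets by one token,
# and uses `response_mask` so the loss only covers the response segment.
x, y, sample_weight = preprocessor(batch)
print(x["token_ids"].shape, y.shape, sample_weight.shape)  # (2, 32) each
```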
@@ -0,0 +1,160 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     compute_causal_mask,
+ )
+ from keras_hub.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
+
+
+ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
+     """PaliGemma mixed decoder block.
+
+     This class implements a decoder block of the PaliGemma Architecture: a
+     mixed transformer decoder block. Intended to be used with an input
+     sequence comprised of both embedded image and text data, this block
+     functions largely identically to the `GemmaDecoderBlock` class, with a
+     notable exception in the computation of attention masks.
+
+     Specifically, this decoder block will use causal self-attention on the
+     text portion of the input, while using full self-attention for image
+     data. It is expected that any image data occurs before text data in the
+     input.
+
+     Args:
+         hidden_dim: int. The size of the transformer hidden state at the end
+             of the block.
+         intermediate_dim: int. The output dimension of the first Dense layer in
+             the two-layer feedforward network.
+         head_dim: int. The size of each attention head.
+         num_query_heads: int. The number of heads for the query projections in
+             the attention layer.
+         num_key_value_heads: int. The number of heads for the key and value
+             projections in the attention layer.
+         layer_norm_epsilon: float. The epsilon hyperparameter used for layer
+             normalization.
+         dropout: float. The dropout rate for the transformer attention layer.
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         intermediate_dim,
+         head_dim,
+         num_query_heads,
+         num_key_value_heads,
+         layer_norm_epsilon=1e-6,
+         dropout=0,
+         **kwargs,
+     ):
+         super().__init__(
+             hidden_dim=hidden_dim,
+             intermediate_dim=intermediate_dim,
+             head_dim=head_dim,
+             num_query_heads=num_query_heads,
+             num_key_value_heads=num_key_value_heads,
+             layer_norm_epsilon=layer_norm_epsilon,
+             dropout=dropout,
+             **kwargs,
+         )
+
+     def call(
+         self,
+         x,
+         padding_mask=None,
+         response_mask=None,
+         cache=None,
+         cache_update_index=0,
+     ):
+         normalized_x = self.pre_attention_norm(x)
+         attention_mask = self._compute_attention_mask(
+             normalized_x, padding_mask, cache, cache_update_index, response_mask
+         )
+         if cache is not None:
+             attention, new_cache = self.attention(
+                 normalized_x,
+                 attention_mask=attention_mask,
+                 cache=cache,
+                 cache_update_index=cache_update_index,
+             )
+         else:
+             attention = self.attention(
+                 normalized_x,
+                 attention_mask=attention_mask,
+             )
+
+         if self.dropout:
+             attention = self.attention_dropout(attention)
+
+         attention_x = x + attention
+         normalized_x = self.pre_ffw_norm(attention_x)
+
+         x1 = self.gating_ffw(normalized_x)
+         x2 = self.gating_ffw_2(normalized_x)
+         x = keras.activations.gelu(x1, approximate=True) * x2
+         x = self.ffw_linear(x)
+
+         x = x + attention_x
+
+         if cache is not None:
+             return x, new_cache
+         return x
+
+     def _compute_attention_mask(
+         self,
+         x,
+         padding_mask,
+         cache,
+         cache_update_index,
+         response_mask=None,
+     ):
+         batch_size = ops.shape(x)[0]
+         input_length = output_length = ops.shape(x)[1]
+         if cache is not None:
+             input_length = ops.shape(cache)[2]
+
+         causal_mask = compute_causal_mask(
+             batch_size=batch_size,
+             input_length=input_length,
+             output_length=output_length,
+             cache_index=cache_update_index,
+         )
+
+         if padding_mask is None:
+             # We should only hit this case during generative decoding.
+             # Just the causal mask is fine in this case.
+             return causal_mask
+
+         def token_to_attention_mask(mask, fill_value):
+             """Reshape token mask -> attention mask padding for image tokens."""
+             mask = ops.cast(mask, "int32")
+             pad = input_length - ops.shape(mask)[1]
+             mask = ops.pad(mask, ((0, 0), (pad, 0)), constant_values=fill_value)
+             return ops.expand_dims(mask, axis=1)
+
+         padding_mask = token_to_attention_mask(padding_mask, 1)
+         if response_mask is not None:
+             response_mask = token_to_attention_mask(response_mask, 0)
+             not_response_mask = ops.logical_not(response_mask)
+             # Only apply the causal mask to the response tokens.
+             causal_mask = ops.logical_and(causal_mask, response_mask)
+             # Only apply block attention to the non-response tokens.
+             padding_mask = ops.logical_and(padding_mask, not_response_mask)
+
+         # Use block attention for the padding mask,
+         # which marks all image and prompt tokens.
+         return ops.logical_or(padding_mask, causal_mask)
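The mask logic in `_compute_attention_mask()` above can be hard to picture, so here is a small standalone sketch of the same idea on a toy sequence, using only `keras.ops`. The variable names and the 2+2+2 token split are illustrative, not part of the library.

```python
import numpy as np
from keras import ops

# Toy sequence: 2 image tokens, 2 prompt tokens, 2 response tokens.
seq_len = 6
padding_mask = np.array([[1, 1, 1, 1, 1, 1]], dtype="int32")   # no padding
response_mask = np.array([[0, 0, 0, 0, 1, 1]], dtype="int32")  # last two tokens

# Lower-triangular causal mask over (query, key) positions.
causal_mask = ops.cast(
    ops.tril(ops.ones((1, seq_len, seq_len), dtype="int32")), "bool"
)

# Broadcast the per-token masks over the key axis, as the decoder block does.
padding = ops.expand_dims(ops.cast(padding_mask, "bool"), axis=1)
response = ops.expand_dims(ops.cast(response_mask, "bool"), axis=1)

# Causal attention only toward response keys; full ("block") attention
# toward the image and prompt keys.
causal_part = ops.logical_and(causal_mask, response)
block_part = ops.logical_and(padding, ops.logical_not(response))
mixed_mask = ops.logical_or(block_part, causal_part)

# Columns 0-3 (image + prompt) are fully visible; columns 4-5 are causal.
print(ops.convert_to_numpy(mixed_mask).astype(int)[0])
```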
@@ -0,0 +1,78 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PaliGemma model preset configurations."""
+
+ # Metadata for loading pretrained model weights.
+ backbone_presets = {
+     "pali_gemma_3b_mix_224": {
+         "metadata": {
+             "description": (
+                 "image size 224, mix fine tuned, text sequence " "length is 256"
+             ),
+             "params": 2923335408,
+             "official_name": "PaliGemma",
+             "path": "pali_gemma",
+             "model_card": "https://www.kaggle.com/models/google/paligemma",
+         },
+         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/1",
+     },
+     "pali_gemma_3b_mix_448": {
+         "metadata": {
+             "description": (
+                 "image size 448, mix fine tuned, text sequence length is 512"
+             ),
+             "params": 2924220144,
+             "official_name": "PaliGemma",
+             "path": "pali_gemma",
+             "model_card": "https://www.kaggle.com/models/google/paligemma",
+         },
+         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/1",
+     },
+     "pali_gemma_3b_224": {
+         "metadata": {
+             "description": (
+                 "image size 224, pre trained, text sequence length is 128"
+             ),
+             "params": 2923335408,
+             "official_name": "PaliGemma",
+             "path": "pali_gemma",
+             "model_card": "https://www.kaggle.com/models/google/paligemma",
+         },
+         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_224/1",
+     },
+     "pali_gemma_3b_448": {
+         "metadata": {
+             "description": (
+                 "image size 448, pre trained, text sequence length is 512"
+             ),
+             "params": 2924220144,
+             "official_name": "PaliGemma",
+             "path": "pali_gemma",
+             "model_card": "https://www.kaggle.com/models/google/paligemma",
+         },
+         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_448/1",
+     },
+     "pali_gemma_3b_896": {
+         "metadata": {
+             "description": (
+                 "image size 896, pre trained, text sequence length " "is 512"
+             ),
+             "params": 2927759088,
+             "official_name": "PaliGemma",
+             "path": "pali_gemma",
+             "model_card": "https://www.kaggle.com/models/google/paligemma",
+         },
+         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/1",
+     },
+ }
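To make the preset table above concrete, here is a minimal, hedged sketch of loading one of the listed presets by name; it assumes Kaggle access for the `kaggle://` handles and enough memory for a roughly 3B-parameter model.

```python
import keras_hub

# Any key of `backbone_presets` above is a valid `from_preset()` name.
backbone = keras_hub.models.PaliGemmaBackbone.from_preset("pali_gemma_3b_224")
print(f"{backbone.count_params():,} parameters")  # 2,923,335,408 per the metadata

# The same preset name also resolves for the full causal LM task.
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset("pali_gemma_3b_224")
```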