keras-hub-nightly 0.15.0.dev20240823171555 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
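
The two hunks below correspond to entries 123 and 124 in this list (`gemma_attention.py`, +250 lines, and `gemma_backbone.py`, +316 lines). As context for reading them: the code ships under `keras_hub/src`, while the `keras_hub/api` modules re-export the public symbols registered via `keras_hub.src.api_export`, so the package is normally used through the top-level `keras_hub` namespace. A minimal sketch, assuming the nightly wheel is installed; only the two Gemma symbols referenced in the hunks below are confirmed export names:

```python
# Minimal sketch, assuming the nightly wheel is installed:
#   pip install keras-hub-nightly
import keras_hub

# Code lives under keras_hub/src, but the keras_hub/api package re-exports
# the public symbols registered via keras_hub.src.api_export, e.g.:
print(keras_hub.models.GemmaBackbone)
print(keras_hub.models.GemmaCausalLM)
```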
keras_hub/src/models/gemma/gemma_attention.py
@@ -0,0 +1,250 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import keras
+ import numpy as np
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ class CachedGemmaAttention(keras.layers.Layer):
+     """A cached grouped query attention layer."""
+
+     def __init__(
+         self,
+         head_dim,
+         num_query_heads,
+         num_key_value_heads,
+         kernel_initializer="glorot_uniform",
+         logit_soft_cap=None,
+         use_sliding_window_attention=False,
+         sliding_window_size=4096,
+         query_head_dim_normalize=True,
+         dropout=0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.head_dim = head_dim
+         self.logit_soft_cap = logit_soft_cap
+         self.use_sliding_window_attention = use_sliding_window_attention
+         self.sliding_window_size = sliding_window_size
+         self.query_head_dim_normalize = query_head_dim_normalize
+         self.dropout = dropout
+
+         self._kernel_initializer = keras.initializers.get(
+             clone_initializer(kernel_initializer)
+         )
+         self.num_key_value_groups = num_query_heads // num_key_value_heads
+         self.query_head_dim_normalize = query_head_dim_normalize
+
+     def build(self, inputs_shape):
+         self.hidden_dim = inputs_shape[-1]
+
+         self.query_dense = keras.layers.EinsumDense(
+             "btd,ndh->btnh",
+             output_shape=(None, self.num_query_heads, self.head_dim),
+             kernel_initializer=self._kernel_initializer,
+             dtype=self.dtype_policy,
+             name="query",
+         )
+         self.query_dense.build(inputs_shape)
+
+         self.key_dense = keras.layers.EinsumDense(
+             "bsd,kdh->bskh",
+             output_shape=(None, self.num_key_value_heads, self.head_dim),
+             kernel_initializer=self._kernel_initializer,
+             dtype=self.dtype_policy,
+             name="key",
+         )
+         self.key_dense.build(inputs_shape)
+
+         self.value_dense = keras.layers.EinsumDense(
+             "bsd,kdh->bskh",
+             output_shape=(None, self.num_key_value_heads, self.head_dim),
+             kernel_initializer=self._kernel_initializer,
+             dtype=self.dtype_policy,
+             name="value",
+         )
+         self.value_dense.build(inputs_shape)
+
+         self.dropout_layer = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+         )
+
+         self.output_dense = keras.layers.EinsumDense(
+             equation="btnh,nhd->btd",
+             output_shape=(None, self.hidden_dim),
+             kernel_initializer=self._kernel_initializer,
+             dtype=self.dtype_policy,
+             name="attention_output",
+         )
+         self.output_dense.build(
+             (None, None, self.num_query_heads, self.head_dim)
+         )
+         self.softmax = keras.layers.Softmax(dtype="float32")
+
+         self.rope_layer = RotaryEmbedding(
+             max_wavelength=10_000.0, dtype=self.dtype_policy
+         )
+
+         self.built = True
+
+     def _apply_rope(self, x, start_index):
+         """Rope rotate q or k."""
+         x = self.rope_layer(x, start_index=start_index)
+         # Gemma uses a different layout for positional embeddings.
+         # The transformation below ensures the embeddings are numerically
+         # equivalent to the original gemma implementation.
+         x = ops.reshape(
+             ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
+         )
+         return x
+
+     def _compute_attention(
+         self,
+         q,
+         k,
+         v,
+         attention_mask,
+         training=False,
+         cache_update_index=0,
+     ):
+         if self.query_head_dim_normalize:
+             query_normalization = 1 / np.sqrt(self.head_dim)
+         else:
+             query_normalization = 1 / np.sqrt(
+                 self.hidden_dim // self.num_query_heads
+             )
+
+         q *= ops.cast(query_normalization, dtype=q.dtype)
+         q_shape = ops.shape(q)
+         q = ops.reshape(
+             q,
+             (
+                 *q_shape[:-2],
+                 self.num_key_value_heads,
+                 self.num_query_heads // self.num_key_value_heads,
+                 q_shape[-1],
+             ),
+         )
+         b, q_len, _, _, h = ops.shape(q)
+
+         attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
+
+         if self.logit_soft_cap is not None:
+             attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
+             attention_logits = ops.multiply(
+                 ops.tanh(attention_logits), self.logit_soft_cap
+             )
+
+         if self.use_sliding_window_attention:
+             attention_mask = self._mask_sliding_window(
+                 attention_mask,
+                 cache_update_index=cache_update_index,
+             )
+
+         attention_mask = attention_mask[:, None, None, :, :]
+         orig_dtype = attention_logits.dtype
+         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
+         attention_softmax = ops.cast(attention_softmax, orig_dtype)
+
+         if self.dropout:
+             attention_softmax = self.dropout_layer(
+                 attention_softmax, training=training
+             )
+
+         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
+         return ops.reshape(results, (b, q_len, self.num_query_heads, h))
+
+     def _mask_sliding_window(
+         self,
+         attention_mask,
+         cache_update_index=0,
+     ):
+         batch_size, query_len, key_len = ops.shape(attention_mask)
+         # Compute the sliding window for square attention.
+         all_ones = ops.ones((key_len, key_len), "bool")
+         if keras.config.backend() == "tensorflow":
+             # TODO: triu/tril has issues with dynamic shape on the tensorflow
+             # backend. We should fix, but use `band_part` for now.
+             import tensorflow as tf
+
+             band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+             band_size = ops.cast(band_size, "int32")
+             sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+         else:
+             sliding_mask = ops.triu(
+                 all_ones, -1 * self.sliding_window_size + 1
+             ) * ops.tril(all_ones, self.sliding_window_size - 1)
+         # Slice the window for short queries during generation.
+         start = (cache_update_index, 0)
+         sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+         sliding_mask = ops.expand_dims(sliding_mask, 0)
+         return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
+     def call(
+         self,
+         x,
+         attention_mask=None,
+         cache=None,
+         cache_update_index=0,
+         training=False,
+     ):
+         query = self.query_dense(x)
+         query = self._apply_rope(query, cache_update_index)
+
+         if cache is not None:
+             key_cache = cache[:, 0, ...]
+             value_cache = cache[:, 1, ...]
+             key_update = self.key_dense(x)
+             key_update = self._apply_rope(key_update, cache_update_index)
+             value_update = self.value_dense(x)
+             start = [0, cache_update_index, 0, 0]
+             key = ops.slice_update(key_cache, start, key_update)
+             value = ops.slice_update(value_cache, start, value_update)
+             cache = ops.stack((key, value), axis=1)
+         else:
+             key = self.key_dense(x)
+             key = self._apply_rope(key, cache_update_index)
+             value = self.value_dense(x)
+
+         attention_vec = self._compute_attention(
+             query,
+             key,
+             value,
+             attention_mask,
+             training=training,
+             cache_update_index=cache_update_index,
+         )
+
+         # Wipe attn vec if there are no attended tokens.
+         no_attended_tokens = ops.all(
+             ops.equal(attention_mask, 0), axis=-1, keepdims=True
+         )[..., None]
+         attention_vec = ops.where(
+             no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
+         )
+
+         attention_output = self.output_dense(attention_vec)
+
+         if cache is not None:
+             return attention_output, cache
+         return attention_output
+
+     def compute_output_shape(self, input_shape):
+         return input_shape
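
The layer above implements grouped-query attention with an optional KV cache: queries are reshaped into `num_key_value_heads` groups so that each key/value head serves `num_query_heads // num_key_value_heads` query heads, and `cache` carries keys and values stacked along axis 1 for incremental decoding. It is an internal layer (it lives under `keras_hub/src` and is normally built and called by `GemmaDecoderBlock`), but it can be exercised directly; a rough sketch under that assumption, with illustrative shapes and sizes:

```python
import numpy as np
from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention

batch, seq_len, hidden_dim = 1, 8, 32
attn = CachedGemmaAttention(head_dim=8, num_query_heads=4, num_key_value_heads=1)
attn.build((batch, seq_len, hidden_dim))  # normally done by GemmaDecoderBlock

x = np.random.uniform(size=(batch, seq_len, hidden_dim)).astype("float32")
# (batch, query_len, key_len) mask; here every token may attend to every token.
mask = np.ones((batch, seq_len, seq_len), dtype="bool")
out = attn(x, attention_mask=mask)  # (batch, seq_len, hidden_dim)

# For generation, the cache stacks keys and values on axis 1:
# (batch, 2, max_seq_len, num_key_value_heads, head_dim).
cache = np.zeros((batch, 2, seq_len, 1, 8), dtype="float32")
out, cache = attn(
    x[:, :1, :],                    # a single new token
    attention_mask=mask[:, :1, :],  # (batch, 1, key_len)
    cache=cache,
    cache_update_index=0,
)
```

In the released models this cache bookkeeping is driven by `GemmaCausalLM` during generation rather than by user code.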
keras_hub/src/models/gemma/gemma_backbone.py
@@ -0,0 +1,316 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.modeling.reversible_embedding import (
+     ReversibleEmbedding,
+ )
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
+ from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+
+
+ @keras_hub_export("keras_hub.models.GemmaBackbone")
+ class GemmaBackbone(Backbone):
+     """Gemma core network with hyperparameters.
+
+     This backbone implements the base Transformer network for the Gemma model.
+     It includes the embedding lookups and transformer layers. This backbone
+     will output the final hidden states for each token, not generative
+     predictions over the vocabulary space. For a higher-level object for text
+     generation, see `keras_hub.models.GemmaCausalLM`.
+
+     The default constructor gives a fully customizable, randomly initialized
+     Gemma model with any number of layers, heads, and embedding dimensions. To
+     load preset architectures and weights, use the `from_preset` constructor.
+
+     Args:
+         vocabulary_size: int. The size of the token vocabulary.
+         num_layers: int. The number of transformer layers.
+         num_query_heads: int. The number of heads for the query projections in
+             the attention layer.
+         num_key_value_heads: int. The number of heads for the key and value
+             projections in the attention layer.
+         hidden_dim: int. The size of the transformer hidden state at the end
+             of each transformer layer.
+         intermediate_dim: int. The output dimension of the first Dense layer in
+             a two-layer feedforward network for each transformer.
+         head_dim: int. The size of each attention head.
+         layer_norm_epsilon: float. The epsilon value used for every layer norm
+             in the transformer model.
+         dropout: float. Dropout probability for the Transformer encoder.
+         query_head_dim_normalize: boolean. If `True`, normalize the query before
+             attention with `head_dim`. If `False`, normalize the query with
+             `hidden_dim / num_query_heads`. Defaults to True.
+         use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+             block. Defaults to False.
+         use_post_attention_norm: boolean. Whether to normalize after the attention
+             block. Defaults to False.
+         attention_logit_soft_cap: None or int. Soft cap for the attention logits.
+             Defaults to None.
+         final_logit_soft_cap: None or int. Soft cap for the final logits.
+             Defaults to None.
+         use_sliding_window_attention: boolean. Whether to use sliding local
+             window attention. Defaults to False.
+         sliding_window_size: int. Size of the sliding local window. Defaults to
+             4096.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for the model's computations and weights. Note that some
+             computations, such as softmax and layer normalization, will always
+             be done in float32 precision regardless of dtype.
+
+     Example:
+     ```python
+     input_data = {
+         "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+         "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+     }
+
+     # Pretrained Gemma decoder.
+     model = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")
+     model(input_data)
+
+     # Randomly initialized Gemma decoder with custom config.
+     model = keras_hub.models.GemmaBackbone(
+         vocabulary_size=50257,
+         num_layers=12,
+         num_query_heads=12,
+         num_key_value_heads=1,
+         hidden_dim=768,
+         intermediate_dim=3072,
+         head_dim=64,
+     )
+     model(input_data)
+     ```
+     """
+
+     def __init__(
+         self,
+         vocabulary_size,
+         num_layers,
+         num_query_heads,
+         num_key_value_heads,
+         hidden_dim,
+         intermediate_dim,
+         head_dim,
+         query_head_dim_normalize=True,
+         use_post_ffw_norm=False,
+         use_post_attention_norm=False,
+         attention_logit_soft_cap=None,
+         final_logit_soft_cap=None,
+         use_sliding_window_attention=False,
+         sliding_window_size=4096,
+         layer_norm_epsilon=1e-6,
+         dropout=0,
+         dtype=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.token_embedding = ReversibleEmbedding(
+             input_dim=vocabulary_size,
+             output_dim=hidden_dim,
+             tie_weights=True,
+             embeddings_initializer=keras.initializers.VarianceScaling(
+                 scale=1.0,
+                 mode="fan_in",
+                 distribution="untruncated_normal",
+                 seed=None,
+             ),
+             dtype=dtype,
+             logit_soft_cap=final_logit_soft_cap,
+             name="token_embedding",
+         )
+         self.transformer_layers = []
+         for i in range(num_layers):
+             sliding_window = use_sliding_window_attention and (i % 2 == 0)
+             layer = GemmaDecoderBlock(
+                 intermediate_dim=intermediate_dim,
+                 hidden_dim=hidden_dim,
+                 num_query_heads=num_query_heads,
+                 head_dim=head_dim,
+                 num_key_value_heads=num_key_value_heads,
+                 query_head_dim_normalize=query_head_dim_normalize,
+                 use_post_ffw_norm=use_post_ffw_norm,
+                 use_post_attention_norm=use_post_attention_norm,
+                 logit_soft_cap=attention_logit_soft_cap,
+                 use_sliding_window_attention=sliding_window,
+                 sliding_window_size=sliding_window_size,
+                 dropout=dropout,
+                 dtype=dtype,
+                 name=f"decoder_block_{i}",
+             )
+             self.transformer_layers.append(layer)
+         self.layer_norm = RMSNormalization(
+             epsilon=layer_norm_epsilon,
+             dtype=dtype,
+             name="final_normalization",
+         )
+
+         # === Functional Model ===
+         token_id_input = keras.Input(
+             shape=(None,), dtype="float32", name="token_ids"
+         )
+         padding_mask_input = keras.Input(
+             shape=(None,), dtype="float32", name="padding_mask"
+         )
+         x = self.token_embedding(token_id_input)
+         x = x * ops.cast(ops.sqrt(hidden_dim), x.dtype)
+         for transformer_layer in self.transformer_layers:
+             x = transformer_layer(x, padding_mask=padding_mask_input)
+         sequence_output = self.layer_norm(x)
+         super().__init__(
+             inputs={
+                 "token_ids": token_id_input,
+                 "padding_mask": padding_mask_input,
+             },
+             outputs=sequence_output,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.vocabulary_size = vocabulary_size
+         self.num_layers = num_layers
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.head_dim = head_dim
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.dropout = dropout
+         self.query_head_dim_normalize = query_head_dim_normalize
+         self.use_post_ffw_norm = use_post_ffw_norm
+         self.use_post_attention_norm = use_post_attention_norm
+         self.attention_logit_soft_cap = attention_logit_soft_cap
+         self.final_logit_soft_cap = final_logit_soft_cap
+         self.sliding_window_size = sliding_window_size
+         self.use_sliding_window_attention = use_sliding_window_attention
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "num_layers": self.num_layers,
+                 "num_query_heads": self.num_query_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "head_dim": self.head_dim,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "dropout": self.dropout,
+                 "query_head_dim_normalize": self.query_head_dim_normalize,
+                 "use_post_ffw_norm": self.use_post_ffw_norm,
+                 "use_post_attention_norm": self.use_post_attention_norm,
+                 "final_logit_soft_cap": self.final_logit_soft_cap,
+                 "attention_logit_soft_cap": self.attention_logit_soft_cap,
+                 "sliding_window_size": self.sliding_window_size,
+                 "use_sliding_window_attention": self.use_sliding_window_attention,
+             }
+         )
+         return config
+
+     @staticmethod
+     def get_layout_map(
+         device_mesh,
+         model_parallel_dim_name="model",
+         data_parallel_dim_name="batch",
+     ):
+         """Get a `keras.distribution.LayoutMap` for model parallel distribution.
+
+         The returned `LayoutMap` contains the sharding spec for the gemma
+         backbone weights, so that you can use it to distribute weights across
+         the accelerators.
+
+         Example:
+         ```
+         # Feel free to change the mesh shape to balance data and model parallelism
+         mesh = keras.distribution.DeviceMesh(
+             shape=(1, 8), axis_names=('batch', 'model'),
+             devices=keras.distribution.list_devices())
+         layout_map = GemmaBackbone.get_layout_map(
+             mesh, model_parallel_dim_name="model")
+
+         distribution = keras.distribution.ModelParallel(
+             mesh, layout_map, batch_dim_name='batch')
+         with distribution.scope():
+             gemma_model = keras_hub.models.GemmaCausalLM.from_preset()
+         ```
+
+         Args:
+             device_mesh: The `keras.distribution.DeviceMesh` instance for
+                 distribution.
+             model_parallel_dim_name: The axis name of the device mesh, where
+                 the weights should be partitioned on.
+             data_parallel_dim_name: The axis name of the device mesh, where
+                 the data should be partitioned on.
+         Return:
+             `keras.distribution.LayoutMap` that contains the sharding spec
+             of all the model weights.
+         """
+         # The weight path and shape of the Gemma backbone is like below (for 2G)
+         # token_embedding/embeddings, (256128, 2048), 524550144
+         # repeat block for decoder
+         # ...
+         # decoder_block_17/pre_attention_norm/scale, (2048,), 2048
+         # decoder_block_17/attention/query/kernel, (8, 2048, 256), 4194304
+         # decoder_block_17/attention/key/kernel, (8, 2048, 256), 4194304
+         # decoder_block_17/attention/value/kernel, (8, 2048, 256), 4194304
+         # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048), 4194304
+         # decoder_block_17/pre_ffw_norm/scale, (2048,), 2048
+         # decoder_block_17/ffw_gating/kernel, (2048, 16384), 33554432
+         # decoder_block_17/ffw_gating_2/kernel, (2048, 16384), 33554432
+         # decoder_block_17/ffw_linear/kernel, (16384, 2048), 33554432
+         if not isinstance(device_mesh, keras.distribution.DeviceMesh):
+             raise ValueError(
+                 "Invalid device_mesh type. Expected `keras.distribution.DeviceMesh`,"
+                 f" got {type(device_mesh)}"
+             )
+         if model_parallel_dim_name not in device_mesh.axis_names:
+             raise ValueError(
+                 f"{model_parallel_dim_name} is not found in the "
+                 f"device_mesh.axis_names. {device_mesh.axis_names=}"
+             )
+         if data_parallel_dim_name not in device_mesh.axis_names:
+             raise ValueError(
+                 f"{data_parallel_dim_name} is not found in the "
+                 f"device_mesh.axis_names. {device_mesh.axis_names=}"
+             )
+         # Note that it is possible to further config the mesh to be 3D, eg
+         # (data, seq, model). We leave it as 2D for now for simplicity.
+         data_dim = data_parallel_dim_name
+         model_dim = model_parallel_dim_name
+         # The sharding config is based on the Gemma team training config.
+         # See https://arxiv.org/abs/2403.08295
+         layout_map = keras.distribution.LayoutMap(device_mesh)
+         layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
+         layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
+             model_dim,
+             data_dim,
+             None,
+         )
+         layout_map["decoder_block.*attention_output.kernel"] = (
+             model_dim,
+             None,
+             data_dim,
+         )
+         layout_map["decoder_block.*ffw_gating.kernel"] = (data_dim, model_dim)
+         layout_map["decoder_block.*ffw_gating_2.kernel"] = (data_dim, model_dim)
+         layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, data_dim)
+
+         return layout_map
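
Beyond the docstring example above, the constructor also wires the optional Gemma 2-style features into each `GemmaDecoderBlock`: post-attention and post-feedforward RMS norms, attention and final logit soft caps, and sliding-window attention on every other layer (`i % 2 == 0`). A rough sketch of a tiny, randomly initialized configuration that exercises those flags; the sizes are illustrative and do not correspond to any released preset:

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.GemmaBackbone(
    vocabulary_size=256,
    num_layers=4,  # sliding-window attention is enabled on layers 0 and 2
    num_query_heads=4,
    num_key_value_heads=1,
    hidden_dim=64,
    intermediate_dim=128,
    head_dim=16,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    attention_logit_soft_cap=50,
    final_logit_soft_cap=30,
    use_sliding_window_attention=True,
    sliding_window_size=8,
)
hidden_states = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "padding_mask": np.ones((1, 12), dtype="int32"),
    }
)  # (1, 12, 64): one final hidden state vector per token
```

For multi-accelerator runs, the same backbone weights can be sharded with `get_layout_map` and `keras.distribution.ModelParallel`, as shown in the method's docstring above.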