keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,225 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+ from keras import ops
16
+
17
+ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
18
+ from keras_hub.src.utils.keras_utils import clone_initializer
19
+
20
+
21
class LlamaAttention(keras.layers.Layer):
    """A cached grouped query attention (GQA) layer for Llama.

    Queries are projected with `num_query_heads` heads while keys and values
    use the (typically smaller) `num_key_value_heads`. Before the attention
    product, key/value heads are repeated `num_query_heads //
    num_key_value_heads` times so the einsum shapes line up. The layer
    optionally consumes and returns a stacked key/value cache for
    autoregressive decoding.

    Note: despite the historical docstring, this layer implements no sliding
    window — attention spans the full key length (subject only to
    `attention_mask`).

    Args:
        num_query_heads: int. Number of attention heads for the queries.
        num_key_value_heads: int. Number of heads for keys and values. Must
            evenly divide `num_query_heads`.
        rope_max_wavelength: int. Maximum angular wavelength for the rotary
            embedding. Defaults to `10000`.
        rope_scaling_factor: float. Scaling factor for the rotary embedding.
            Defaults to `1.0`.
        kernel_initializer: Initializer for the dense projection kernels.
            Defaults to `"glorot_uniform"`.
        dropout: float. Dropout rate applied to the attention output.
            Defaults to `0`.
    """

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        kernel_initializer="glorot_uniform",
        dropout=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.dropout = dropout

        # Each key/value head serves this many query heads (GQA fan-out).
        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self.rope_max_wavelength = rope_max_wavelength

        # Clone so that reusing the same initializer object across layers
        # does not share state.
        self.kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

        self.rope_scaling_factor = rope_scaling_factor

    def build(self, inputs_shape):
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        hidden_dim = inputs_shape[-1]
        head_dim = hidden_dim // self.num_query_heads
        # Scaled dot-product attention divides scores by sqrt(head_dim).
        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))

        self._query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self.num_query_heads, head_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self._query_dense.build(inputs_shape)

        self._key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self._key_dense.build(inputs_shape)

        self._value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self._value_dense.build(inputs_shape)

        # Softmax runs in float32 regardless of compute dtype for numerical
        # stability; scores are cast back afterwards.
        self._softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        self._dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self._output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, hidden_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self._output_dense.build((None, None, self.num_query_heads, head_dim))

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self.rope_max_wavelength,
            scaling_factor=self.rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        """Run attention over `hidden_states`.

        Args:
            hidden_states: input tensor, `(batch, seq_len, hidden_dim)`
                per the einsum equations above.
            attention_mask: optional mask broadcast over heads.
            cache: optional stacked key/value cache; axis 1 indexes
                key (0) vs. value (1).
            cache_update_index: position at which to write new key/value
                entries into `cache`. Must be `None` when `cache` is `None`.
            training: forwarded to the dropout layer.

        Returns:
            The attention output, or `(attention_output, cache)` when a
            cache was passed in.

        Raises:
            ValueError: if `cache_update_index` is set without a `cache`.
        """
        # RoPE needs the absolute position of the first query token.
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self._query_dense(hidden_states)

        # Compute RoPE for queries
        query = self.rotary_embedding_layer(query, start_index=start_index)

        def _compute_key_value(x):
            key, value = self._key_dense(x), self._value_dense(x)
            # Compute RoPE for keys
            key = self.rotary_embedding_layer(key, start_index=start_index)
            return key, value

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Pure cache read: reuse the stored keys/values as-is.
                key = key_cache
                value = value_cache
            else:
                # Write the new entries in place at the given position.
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query, key, value, attention_mask
        )

        attention_output = self._dropout_layer(
            attention_output, training=training
        )

        attention_output = self._output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        # Insert a singleton head axis so the mask broadcasts over heads.
        if attention_mask is not None:
            return self._softmax(
                attention_scores, attention_mask[:, None, :, :]
            )
        return self._softmax(attention_scores)

    def _compute_attention(self, query, key, value, attention_mask=None):
        # Scaled dot-product attention: softmax(QK^T / sqrt(d)) V.
        attention_scores = ops.einsum(self._dot_product_equation, query, key)

        attention_scores = attention_scores / self._norm_factor
        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        # Softmax ran in float32; cast back to the compute dtype.
        attention_scores = ops.cast(attention_scores, self.compute_dtype)
        attention_output = ops.einsum(
            self._combine_equation, attention_scores, value
        )

        return attention_output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "dropout": self.dropout,
            }
        )
        return config
@@ -0,0 +1,188 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+ from keras import ops
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.layers.modeling.reversible_embedding import (
19
+ ReversibleEmbedding,
20
+ )
21
+ from keras_hub.src.models.backbone import Backbone
22
+ from keras_hub.src.models.llama.llama_decoder import LlamaTransformerDecoder
23
+ from keras_hub.src.models.llama.llama_layernorm import LlamaLayerNorm
24
+
25
+
26
def _llama_kernel_initializer(stddev=0.02):
    """Return the default Llama kernel initializer: a zero-mean Gaussian."""
    return keras.initializers.RandomNormal(stddev=stddev)
28
+
29
+
30
@keras_hub_export("keras_hub.models.LlamaBackbone")
class LlamaBackbone(Backbone):
    """
    The Llama Transformer core architecture with hyperparameters.

    This network implements a Transformer-based decoder network,
    Llama, as described in
    ["LLaMA: Open and Efficient Foundation Language Models"](https://arxiv.org/abs/2302.13971).
    It includes the embedding lookups and transformer layers.

    The default constructor gives a fully customizable, randomly initialized
    Llama model with any number of layers, heads, and embedding
    dimensions. To load preset architectures and weights, use the `from_preset`
    constructor.

    Args:
        vocabulary_size (int): The size of the token vocabulary.
        num_layers (int): The number of transformer layers.
        num_query_heads (int): The number of query attention heads for
            each transformer.
        hidden_dim (int): The size of the transformer encoding and pooling layers.
        intermediate_dim (int): The output dimension of the first Dense layer in a
            three-layer feedforward network for each transformer.
        num_key_value_heads (int): The number of key and value attention heads for
            each transformer.
        rope_max_wavelength (int, optional): The maximum angular wavelength of the
            sine/cosine curves, for rotary embeddings. Defaults to `10000`.
        rope_scaling_factor (float, optional): The scaling factor for calculation
            of rotary embedding. Defaults to `1.0`.
        layer_norm_epsilon (float, optional): Epsilon for the layer normalization
            layers in the transformer decoder. Defaults to `1e-6`.
        dropout (float, optional): Dropout rate for the transformer decoder
            layers. Defaults to `0`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.

    Examples:

    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Pretrained Llama decoder.
    model = keras_hub.models.LlamaBackbone.from_preset("llama7b_base_en")
    model(input_data)

    # Randomly initialized Llama decoder with custom config.
    model = keras_hub.models.LlamaBackbone(
        vocabulary_size=10,
        hidden_dim=512,
        num_layers=2,
        num_query_heads=32,
        num_key_value_heads=8,
        intermediate_dim=1024,
        layer_norm_epsilon=1e-6,
        dtype="float32"
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_query_heads,
        hidden_dim,
        intermediate_dim,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        layer_norm_epsilon=1e-6,
        dropout=0,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        # `tie_weights=False`: Llama uses separate input and output
        # embedding matrices.
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            tie_weights=False,
            embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
            dtype=dtype,
            name="token_embedding",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            layer = LlamaTransformerDecoder(
                intermediate_dim=intermediate_dim,
                num_query_heads=num_query_heads,
                num_key_value_heads=num_key_value_heads,
                rope_max_wavelength=rope_max_wavelength,
                rope_scaling_factor=rope_scaling_factor,
                layer_norm_epsilon=layer_norm_epsilon,
                activation=ops.silu,
                kernel_initializer=_llama_kernel_initializer(stddev=0.02),
                dropout=dropout,
                dtype=dtype,
                name=f"transformer_layer_{i}",
            )
            self.transformer_layers.append(layer)
        # Final RMS-style norm applied to the decoder output.
        self.layer_norm = LlamaLayerNorm(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="sequence_output_layernorm",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
        sequence_output = self.layer_norm(x)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_query_heads = num_query_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.rope_max_wavelength = rope_max_wavelength
        self.num_key_value_heads = num_key_value_heads
        self.rope_scaling_factor = rope_scaling_factor
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_query_heads": self.num_query_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "num_key_value_heads": self.num_key_value_heads,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
            }
        )
        return config