keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,238 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+ from keras import ops
16
+
17
+ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
18
+ from keras_hub.src.utils.keras_utils import clone_initializer
19
+
20
+
21
+ # This is just a self-attention layer in Mistral. But it can be generalized
22
+ # to use the `keras_hub.layers.CachedMultiHeadAttention` API. Since this layer
23
+ # implements grouped-query attention and sliding window attention, it might be
24
+ # useful outside of Mistral itself.
25
+ # TODO(tirthasheshpatel): Generalize the attention layer
26
+ # TODO(tirthasheshpatel): Merge `LlamaAttention` with this layer
27
+ # TODO(tirthasheshpatel): Use flash attention
28
class CachedMistralAttention(keras.layers.Layer):
    """A cached grouped-query attention layer with sliding window.

    Computes multi-head attention where `num_query_heads` query heads share
    `num_key_value_heads` key/value heads (grouped-query attention). Rotary
    position embeddings (RoPE) are applied to queries and keys, and an
    optional key/value cache supports autoregressive decoding.
    """

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        kernel_initializer="glorot_uniform",
        sliding_window=512,
        dropout=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._num_query_heads = num_query_heads
        self._num_key_value_heads = num_key_value_heads
        # NOTE(review): `sliding_window` is stored here and surfaced in
        # `get_config`, but is never read in `build`/`call` — the windowed
        # attention mask is presumably constructed by the caller (the
        # transformer decoder); confirm against the decoder layer.
        self._sliding_window = sliding_window
        self._dropout = dropout

        # Number of query heads that share each key/value head.
        self._num_key_value_groups = num_query_heads // num_key_value_heads
        self._rope_max_wavelength = rope_max_wavelength

        # Clone the initializer so reusing one instance across layers does
        # not share state between them.
        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

        self._rope_scaling_factor = rope_scaling_factor

    def build(self, inputs_shape):
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        self._hidden_dim = inputs_shape[-1]
        self._head_dim = self._hidden_dim // self._num_query_heads

        # Query projection: [b, q, m] -> [b, q, u, h].
        self._query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self._num_query_heads, self._head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self._query_dense.build(inputs_shape)

        # Key projection: [b, k, m] -> [b, k, v, h] (v key/value heads).
        self._key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self._num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self._key_dense.build(inputs_shape)

        # Value projection: same equation and shape as the key projection.
        self._value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self._num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self._value_dense.build(inputs_shape)

        # Softmax is pinned to float32 regardless of the layer's compute
        # dtype, for numerical stability.
        self._softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        self._dropout_layer = keras.layers.Dropout(
            rate=self._dropout,
            dtype=self.dtype_policy,
        )

        # Output projection: [b, q, u, h] -> [b, q, m].
        self._output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, self._hidden_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self._output_dense.build(
            (None, None, self._num_query_heads, self._head_dim)
        )

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self._rope_max_wavelength,
            scaling_factor=self._rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        # Scores: [b, q, u, h] x [b, k, u, h] -> [b, u, q, k].
        self._dot_product_equation = "bquh,bkuh->buqk"
        # Weighted sum: [b, u, q, k] x [b, k, u, h] -> [b, q, u, h].
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        # Sequence offset for RoPE when decoding incrementally with a cache.
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self._query_dense(hidden_states)

        # Compute RoPE for queries
        query = self.rotary_embedding_layer(query, start_index=start_index)

        def _compute_key_value(x):
            key, value = self._key_dense(x), self._value_dense(x)
            # Compute RoPE for keys
            key = self.rotary_embedding_layer(key, start_index=start_index)
            return key, value

        if cache is not None:
            # Cache layout: axis 1 stacks [key, value], so cache[:, 0] is
            # the key cache and cache[:, 1] is the value cache.
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Read-only cache use: reuse stored keys/values without
                # projecting any new tokens.
                key = key_cache
                value = value_cache
            else:
                # Project only the new tokens and splice them into the cache
                # at `cache_update_index` along the sequence axis.
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        # Duplicate each key/value head across its query-head group so the
        # einsum below can treat all heads uniformly.
        key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query, key, value, attention_mask
        )

        attention_output = self._dropout_layer(
            attention_output, training=training
        )

        attention_output = self._output_dense(attention_output)

        # When caching, return the (possibly updated) cache alongside the
        # attention output.
        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        # Broadcast the [b, q, k] mask over the head axis of the
        # [b, u, q, k] score tensor.
        if attention_mask is not None:
            return self._softmax(
                attention_scores, attention_mask[:, None, :, :]
            )
        return self._softmax(attention_scores)

    def _compute_attention(self, query, key, value, attention_mask=None):
        attention_scores = ops.einsum(self._dot_product_equation, query, key)

        # Scaled dot-product attention: divide scores by sqrt(head_dim).
        norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))

        attention_scores = attention_scores / norm_factor
        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        # The softmax runs in float32 (see `build`); cast back to the
        # compute dtype before the weighted sum over values.
        attention_scores = ops.cast(attention_scores, self.compute_dtype)
        attention_output = ops.einsum(
            self._combine_equation, attention_scores, value
        )

        return attention_output

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self._num_query_heads,
                "num_key_value_heads": self._num_key_value_heads,
                "rope_max_wavelength": self._rope_max_wavelength,
                "rope_scaling_factor": self._rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self._kernel_initializer
                ),
                "sliding_window": self._sliding_window,
                "dropout": self._dropout,
            }
        )
        return config
@@ -0,0 +1,203 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.layers.modeling.reversible_embedding import (
20
+ ReversibleEmbedding,
21
+ )
22
+ from keras_hub.src.models.backbone import Backbone
23
+ from keras_hub.src.models.mistral.mistral_layer_norm import (
24
+ MistralLayerNormalization,
25
+ )
26
+ from keras_hub.src.models.mistral.mistral_transformer_decoder import (
27
+ MistralTransformerDecoder,
28
+ )
29
+
30
+
31
def _mistral_kernel_initializer(stddev=0.02):
    """Return a `RandomNormal` kernel initializer with the given stddev."""
    initializer = keras.initializers.RandomNormal(stddev=stddev)
    return initializer
33
+
34
+
35
@keras_hub_export("keras_hub.models.MistralBackbone")
class MistralBackbone(Backbone):
    """
    The Mistral Transformer core architecture with hyperparameters.

    This network implements a Transformer-based decoder network,
    Mistral, as described in
    ["Mistral 7B"](https://arxiv.org/pdf/2310.06825.pdf).
    It includes the embedding lookups and transformer layers.

    The default constructor gives a fully customizable, randomly initialized
    Mistral model with any number of layers, heads, and embedding
    dimensions. To load preset architectures and weights, use the `from_preset`
    constructor.

    Args:
        vocabulary_size (int): The size of the token vocabulary.
        num_layers (int): The number of transformer layers.
        num_query_heads (int): The number of query attention heads for
            each transformer.
        hidden_dim (int): The size of the transformer encoding and pooling layers.
        intermediate_dim (int): The output dimension of the first Dense layer in a
            three-layer feedforward network for each transformer.
        num_key_value_heads (int): The number of key and value attention heads for
            each transformer.
        rope_max_wavelength (int, optional): The maximum angular wavelength of the
            sine/cosine curves, for rotary embeddings. Defaults to `10000`.
        rope_scaling_factor (float, optional): The scaling factor for calculation
            of rotary embedding. Defaults to `1.0`.
        layer_norm_epsilon (float, optional): Epsilon for the layer normalization
            layers in the transformer decoder. Defaults to `1e-6`.
        sliding_window (int, optional): The sliding window for the mistral
            attention layers. This controls the maximum cache size for the attention
            layers in each transformer decoder. Only `sliding_window` number of tokens
            are saved in the cache and used to generate the next token.
            Defaults to `512`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.

    Examples:

    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Pretrained Mistral decoder.
    model = keras_hub.models.MistralBackbone.from_preset("mistral7b_base_en")
    model(input_data)

    # Randomly initialized Mistral decoder with custom config.
    model = keras_hub.models.MistralBackbone(
        vocabulary_size=10,
        hidden_dim=512,
        num_layers=2,
        num_query_heads=32,
        num_key_value_heads=8,
        intermediate_dim=1024,
        sliding_window=512,
        layer_norm_epsilon=1e-6,
        dtype="float32"
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_query_heads,
        hidden_dim,
        intermediate_dim,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        layer_norm_epsilon=1e-6,
        sliding_window=512,
        dropout=0,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        # Layers are created before `super().__init__` because the functional
        # model below wires them together and passes the result to `super()`.
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            # Input and output embedding matrices are kept separate
            # (untied) for this architecture.
            tie_weights=False,
            embeddings_initializer=_mistral_kernel_initializer(stddev=0.01),
            dtype=dtype,
            name="token_embedding",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            layer = MistralTransformerDecoder(
                intermediate_dim=intermediate_dim,
                num_query_heads=num_query_heads,
                num_key_value_heads=num_key_value_heads,
                rope_max_wavelength=rope_max_wavelength,
                rope_scaling_factor=rope_scaling_factor,
                layer_norm_epsilon=layer_norm_epsilon,
                activation=ops.silu,
                kernel_initializer=_mistral_kernel_initializer(stddev=0.02),
                sliding_window=sliding_window,
                dropout=dropout,
                dtype=dtype,
                name=f"transformer_layer_{i}",
            )
            self.transformer_layers.append(layer)
        # Final RMS-style normalization applied to the decoder stack output.
        self.layer_norm = MistralLayerNormalization(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="sequence_output_layernorm",
        )

        # === Functional Model ===
        # Embed tokens, run the decoder stack with the padding mask, and
        # normalize; the result is the backbone's sequence output.
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
        sequence_output = self.layer_norm(x)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        # Stored verbatim so `get_config` can round-trip the constructor.
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_query_heads = num_query_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.rope_max_wavelength = rope_max_wavelength
        self.num_key_value_heads = num_key_value_heads
        self.rope_scaling_factor = rope_scaling_factor
        self.sliding_window = sliding_window
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout

    def get_config(self):
        """Return the constructor arguments needed to re-create this model."""
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_query_heads": self.num_query_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "num_key_value_heads": self.num_key_value_heads,
                "sliding_window": self.sliding_window,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
            }
        )
        return config