keras-hub-nightly 0.15.0.dev20240823171555-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/phi3/phi3_causal_lm.py
@@ -0,0 +1,218 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.causal_lm import CausalLM
+ from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone
+ from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import (
+     Phi3CausalLMPreprocessor,
+ )
+ from keras_hub.src.utils.python_utils import classproperty
+ from keras_hub.src.utils.tensor_utils import any_equal
+
+
+ @keras_hub_export("keras_hub.models.Phi3CausalLM")
+ class Phi3CausalLM(CausalLM):
+     """An end-to-end Phi3 model for causal language modeling.
+
+     A causal language model (LM) predicts the next token based on previous
+     tokens. This task setup can be used to train the model unsupervised on
+     plain text input, or to autoregressively generate plain text similar to
+     the data used for training. This task can be used for pre-training or
+     fine-tuning a Phi-3 model, simply by calling `fit()`.
+
+     This model has a `generate()` method, which generates text based on a
+     prompt. The generation strategy used is controlled by an additional
+     `sampler` argument on `compile()`. You can recompile the model with
+     different `keras_hub.samplers` objects to control the generation. By
+     default, `"top_k"` sampling will be used.
+
+     Args:
+         backbone: A `keras_hub.models.Phi3Backbone` instance.
+         preprocessor: A `keras_hub.models.Phi3CausalLMPreprocessor` or `None`.
+             If `None`, this model will not apply preprocessing, and inputs
+             should be preprocessed before calling the model.
+     """
+
+     def __init__(self, backbone, preprocessor=None, **kwargs):
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         # === Functional Model ===
+         inputs = backbone.inputs
+         hidden_states = backbone(inputs)
+         outputs = backbone.token_embedding(hidden_states, reverse=True)
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+     @classproperty
+     def backbone_cls(cls):
+         return Phi3Backbone
+
+     @classproperty
+     def preprocessor_cls(cls):
+         return Phi3CausalLMPreprocessor
+
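The output head above reuses the backbone's token embedding as a reversible projection: `token_embedding(hidden_states, reverse=True)` multiplies hidden states by the transpose of the embedding matrix (see `reversible_embedding.py` in the file list), so the LM head is weight-tied to the input embedding. A minimal sketch of that idea, assuming a plain `(vocab_size, hidden_dim)` matrix rather than the actual KerasHub layer:

```python
from keras import ops

def tied_lm_head(hidden_states, embedding_matrix):
    # hidden_states: (batch, seq_len, hidden_dim)
    # embedding_matrix: (vocab_size, hidden_dim), shared with the input
    # embedding lookup.
    # Returns per-position vocabulary logits: (batch, seq_len, vocab_size).
    return ops.matmul(hidden_states, ops.transpose(embedding_matrix))
```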
+     def call_with_cache(
+         self,
+         token_ids,
+         cache,
+         cache_update_index,
+     ):
+         """Forward pass of `Phi3CausalLM` with cache.
+
+         `call_with_cache` adds an additional forward pass for the model for
+         autoregressive inference. Unlike calling the model directly, this
+         method allows caching previous key/value tensors in the multi-head
+         attention layers, and avoids recomputing the outputs of seen tokens.
+
+         Args:
+             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+             cache: a dense float Tensor, the cache of key and value.
+             cache_update_index: int, or int Tensor. The index of current inputs
+                 in the whole sequence.
+
+         Returns:
+             A `(logits, hidden_states, cache)` tuple, where `logits` is the
+             language model logits for the input token_ids, `hidden_states` is
+             the final hidden representation of the input tokens, and `cache` is
+             the decoding cache.
+         """
+         x = self.backbone.token_embedding(token_ids)
+         # Each decoder layer has a cache; we update them separately.
+         updated_cache = []
+         for i in range(self.backbone.num_layers):
+             current_cache = cache[:, i, ...]
+             x, next_cache = self.backbone.transformer_layers[i](
+                 x,
+                 attention_cache=current_cache,
+                 attention_cache_update_index=cache_update_index,
+             )
+             updated_cache.append(next_cache)
+         cache = ops.stack(updated_cache, axis=1)
+         hidden_states = x = self.backbone.layer_norm(x)
+         logits = self.backbone.token_embedding(x, reverse=True)
+         return logits, hidden_states, cache
+
+     def _build_cache(self, token_ids):
+         """Build an empty cache for use with `call_with_cache()`."""
+         batch_size = ops.shape(token_ids)[0]
+         max_length = ops.shape(token_ids)[1]
+         num_layers = self.backbone.num_layers
+         num_key_value_heads = self.backbone.num_key_value_heads
+         head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads
+         shape = [
+             batch_size,
+             num_layers,
+             2,
+             max_length,
+             num_key_value_heads,
+             head_dim,
+         ]
+         cache = ops.zeros(shape, dtype=self.compute_dtype)
+         # Seed the cache with a single forward pass over the prompt.
+         _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
+         return hidden_states, cache
+
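The cache built here holds keys and values for every decoder layer in a single tensor of shape `(batch_size, num_layers, 2, max_length, num_key_value_heads, head_dim)`, with the `2` axis separating keys from values. A quick back-of-the-envelope check of what that costs, using illustrative Phi-3-mini-like dimensions (assumptions, not values read from a preset):

```python
# Hypothetical dimensions for illustration only.
batch_size = 1
num_layers = 32
max_length = 4096
num_key_value_heads = 32
head_dim = 3072 // 32  # hidden_dim // num_query_heads = 96

entries = (
    batch_size * num_layers * 2 * max_length * num_key_value_heads * head_dim
)
print(f"cache entries: {entries:,}")  # 805,306,368
print(f"~{entries * 2 / 1024**3:.1f} GiB at 2 bytes/entry")  # ~1.5 GiB
```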
+     def generate_step(
+         self,
+         inputs,
+         stop_token_ids=None,
+     ):
+         """A compilable generation function for a single batch of inputs.
+
+         This function represents the inner, XLA-compilable, generation function
+         for a single batch of inputs. Inputs should have the same structure as
+         model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
+
+         Args:
+             inputs: A dictionary with two keys `"token_ids"` and
+                 `"padding_mask"` and batched tensor values.
+             stop_token_ids: Tuple of ids of stop tokens. If all sequences have
+                 produced a new stop token, generation will stop.
+         """
+         token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
+         # Create and seed cache with a single forward pass.
+         hidden_states, cache = self._build_cache(token_ids)
+         # Compute the lengths of all user inputted token ids.
+         row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+         # Start at the first index that has no user inputted id.
+         index = ops.min(row_lengths)
+
+         def next(prompt, cache, index):
+             # The cache index is the index of our previous token.
+             cache_update_index = index - 1
+             batch_size = ops.shape(prompt)[0]
+             prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+             logits, hidden_states, cache = self.call_with_cache(
+                 prompt,
+                 cache,
+                 cache_update_index,
+             )
+             return (
+                 ops.squeeze(logits, axis=1),
+                 ops.squeeze(hidden_states, axis=1),
+                 cache,
+             )
+
+         token_ids = self.sampler(
+             next=next,
+             prompt=token_ids,
+             cache=cache,
+             index=index,
+             mask=padding_mask,
+             stop_token_ids=stop_token_ids,
+             hidden_states=hidden_states,
+             model=self,
+         )
+
+         # Compute an output padding mask with the token ids we updated.
+         if stop_token_ids is not None:
+             # Build a mask of stop token locations not in the original
+             # prompt (not in locations where `padding_mask` is True).
+             end_locations = any_equal(
+                 token_ids, stop_token_ids, ops.logical_not(padding_mask)
+             )
+             end_locations = ops.cast(end_locations, "int32")
+             # Use cumsum to get ones in all locations after end_locations.
+             cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+             overflow = cumsum - end_locations
+             # Our padding mask is the inverse of these overflow locations.
+             padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+         else:
+             # Without early stopping, all locations will have been updated.
+             padding_mask = ops.ones_like(token_ids, dtype="bool")
+         return {
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+         }
+
+     def generate(self, inputs, max_length=None, stop_token_ids="auto"):
+         if self.preprocessor and stop_token_ids == "auto":
+             # Stop at:
+             # `<|endoftext|>` (end of sequence token).
+             # `<|end|>` (end of turn token).
+             stop_token_ids = [self.preprocessor.tokenizer.end_token_id]
+             end_of_turn_id = self.preprocessor.tokenizer.token_to_id("<|end|>")
+             if end_of_turn_id != 0:
+                 # `<|end|>` exists in the vocabulary, so stop on it as well.
+                 stop_token_ids.append(end_of_turn_id)
+
+         return super().generate(inputs, max_length, stop_token_ids)
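Taken together, the task class supports the usual KerasHub workflow. A hedged usage sketch (the preset name is the one referenced in the preprocessor docstring below; the sampler choice is illustrative):

```python
import keras_hub

# Load the full task (backbone + preprocessor) from a preset.
causal_lm = keras_hub.models.Phi3CausalLM.from_preset(
    "phi3_mini_4k_instruct_en"
)
# The sampler is set at compile time; `"top_k"` is the default.
causal_lm.compile(sampler="greedy")
causal_lm.generate("Why is the sky blue?", max_length=64)
```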
keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py
@@ -0,0 +1,173 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+ from absl import logging
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.phi3.phi3_preprocessor import Phi3Preprocessor
+ from keras_hub.src.utils.keras_utils import (
+     convert_inputs_to_list_of_tensor_segments,
+ )
+ from keras_hub.src.utils.tensor_utils import strip_to_ragged
+
+
+ @keras_hub_export("keras_hub.models.Phi3CausalLMPreprocessor")
+ class Phi3CausalLMPreprocessor(Phi3Preprocessor):
+     """Phi3 Causal LM preprocessor.
+
+     This preprocessing layer is meant for use with
+     `keras_hub.models.Phi3CausalLM`. By default, it will take in batches of
+     strings, and return outputs in a `(x, y, sample_weight)` format, where the
+     `y` label is the next token id in the `x` sequence.
+
+     For use with generation, the layer also exposes two methods
+     `generate_preprocess()` and `generate_postprocess()`. When this
+     preprocessor is attached to a `keras_hub.models.Phi3CausalLM` instance,
+     these methods will be called implicitly in `generate()`. They can also be
+     called standalone (e.g. to precompute preprocessing inputs for generation
+     in a separate process).
+
+     Args:
+         tokenizer: A `keras_hub.models.Phi3Tokenizer` instance.
+         sequence_length: The length of the packed inputs.
+         add_start_token: If `True`, the preprocessor will prepend the tokenizer
+             start token to each input sequence. Default is `True`.
+         add_end_token: If `True`, the preprocessor will append the tokenizer
+             end token to each input sequence. Default is `False`.
+
+     Call arguments:
+         x: A string, `tf.Tensor` or list of python strings.
+         y: Label data. Should always be `None` as the layer generates labels.
+         sample_weight: Label weights. Should always be `None` as the layer
+             generates label weights.
+         sequence_length: Pass to override the configured `sequence_length` of
+             the layer.
+
+     Examples:
+     ```python
+     # Load the preprocessor from a preset.
+     preprocessor = keras_hub.models.Phi3CausalLMPreprocessor.from_preset(
+         "phi3_mini_4k_instruct_en"
+     )
+
+     # Tokenize and pack a single sentence.
+     sentence = tf.constant("League of legends")
+     preprocessor(sentence)
+     # Same output.
+     preprocessor("League of legends")
+
+     # Tokenize a batch of sentences.
+     sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+     preprocessor(sentences)
+     # Same output.
+     preprocessor(["Taco tuesday", "Fish taco please!"])
+
+     # Map a dataset to preprocess a single sentence.
+     features = tf.constant(
+         [
+             "Avatar 2 is amazing!",
+             "Well, I am not sure.",
+         ]
+     )
+     labels = tf.constant([1, 0])
+     ds = tf.data.Dataset.from_tensor_slices((features, labels))
+     ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+     # Map a dataset to preprocess unlabeled sentences.
+     ds = tf.data.Dataset.from_tensor_slices(features)
+     ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+     ```
+     """
+
+     def call(
+         self,
+         x,
+         y=None,
+         sample_weight=None,
+         sequence_length=None,
+     ):
+         if y is not None or sample_weight is not None:
+             logging.warning(
+                 "`Phi3CausalLMPreprocessor` generates `y` and "
+                 "`sample_weight` based on your input data, but your data "
+                 "already contains `y` or `sample_weight`. Your `y` and "
+                 "`sample_weight` will be ignored."
+             )
+         sequence_length = sequence_length or self.sequence_length
+
+         x = convert_inputs_to_list_of_tensor_segments(x)[0]
+         x = self.tokenizer(x)
+         # Pad with one extra token to account for the truncation below.
+         token_ids, padding_mask = self.packer(
+             x,
+             sequence_length=sequence_length + 1,
+             add_start_value=self.add_start_token,
+             add_end_value=self.add_end_token,
+         )
+         # The last token does not have a next token, so we truncate it out.
+         x = {
+             "token_ids": token_ids[..., :-1],
+             "padding_mask": padding_mask[..., :-1],
+         }
+         # Target `y` will be the next token.
+         y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
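The one-position shift in `call()` is easiest to see on a concrete sequence. A minimal sketch with made-up token ids (not produced by a real tokenizer):

```python
import numpy as np

# Ids packed to `sequence_length + 1` positions: start, 3 tokens, end, pad.
token_ids = np.array([[1, 37, 52, 99, 2, 0]])
x = token_ids[..., :-1]  # [[ 1, 37, 52, 99,  2]] -> model input
y = token_ids[..., 1:]   # [[37, 52, 99,  2,  0]] -> next-token targets
# Each position in `x` is trained to predict the aligned position in `y`.
```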
+     def generate_preprocess(
+         self,
+         x,
+         sequence_length=None,
+     ):
+         """Convert strings to integer token input for generation.
+
+         Similar to calling the layer for training, this method takes in
+         strings or tensor strings, tokenizes and packs the input, and computes
+         a padding mask masking all inputs not filled in with a padded value.
+
+         Unlike calling the layer for training, this method does not compute
+         labels and will never append a `tokenizer.end_token_id` to the end of
+         the sequence (as generation is expected to continue at the end of the
+         inputted prompt).
+         """
+         if not self.built:
+             self.build(None)
+
+         x = convert_inputs_to_list_of_tensor_segments(x)[0]
+         x = self.tokenizer(x)
+         token_ids, padding_mask = self.packer(
+             x, sequence_length=sequence_length, add_end_value=False
+         )
+         return {
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+         }
+
+     def generate_postprocess(
+         self,
+         x,
+     ):
+         """Convert integer token output to strings for generation.
+
+         This method reverses `generate_preprocess()`, by first removing all
+         padding and start/end tokens, and then converting the integer sequence
+         back to a string.
+         """
+         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+         ids_to_strip = (
+             self.tokenizer.start_token_id,
+             self.tokenizer.end_token_id,
+         )
+         token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
+         return self.tokenizer.detokenize(token_ids)
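Called standalone, the two generation hooks round-trip between strings and the `{"token_ids", "padding_mask"}` dict that `generate_step()` consumes. A hedged sketch (same preset name as above):

```python
import keras_hub

preprocessor = keras_hub.models.Phi3CausalLMPreprocessor.from_preset(
    "phi3_mini_4k_instruct_en", sequence_length=64
)
# Strings -> {"token_ids": ..., "padding_mask": ...} for the prompt only;
# no end token is appended, so generation can continue the prompt.
inputs = preprocessor.generate_preprocess(["Why is the sky blue?"])
# ... run generation to fill in more of `inputs["token_ids"]` ...
# Ids -> strings, with padding and start/end tokens stripped.
text = preprocessor.generate_postprocess(inputs)
```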
keras_hub/src/models/phi3/phi3_decoder.py
@@ -0,0 +1,260 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     compute_causal_mask,
+ )
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     merge_padding_and_attention_mask,
+ )
+ from keras_hub.src.models.phi3.phi3_attention import Phi3Attention
+ from keras_hub.src.models.phi3.phi3_layernorm import Phi3LayerNorm
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ class Phi3Decoder(keras.layers.Layer):
+     """A Transformer decoder layer for the Phi-3 backbone."""
+
+     def __init__(
+         self,
+         hidden_dim,
+         intermediate_dim,
+         num_query_heads,
+         num_key_value_heads,
+         activation="silu",
+         layer_norm_epsilon=1e-5,
+         kernel_initializer="glorot_uniform",
+         dropout=0,
+         max_sequence_length=4096,
+         pretraining_sequence_length=4096,
+         rope_max_wavelength=10000,
+         rope_scaling_type=None,
+         rope_scaling_short_factor=None,
+         rope_scaling_long_factor=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+
+         self.max_sequence_length = max_sequence_length
+         self.pretraining_sequence_length = pretraining_sequence_length
+         self.rope_max_wavelength = rope_max_wavelength
+         self.rope_scaling_type = rope_scaling_type
+         self.rope_scaling_short_factor = rope_scaling_short_factor
+         self.rope_scaling_long_factor = rope_scaling_long_factor
+
+         self.dropout = dropout
+
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.activation = keras.activations.get(activation)
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+     def build(self, decoder_sequence_shape):
+         # Pre-attention layernorm.
+         self.pre_attention_layernorm = Phi3LayerNorm(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="pre_attention_layernorm",
+         )
+         self.pre_attention_layernorm.build(decoder_sequence_shape)
+
+         # Self attention layer.
+         self.attention = Phi3Attention(
+             num_query_heads=self.num_query_heads,
+             num_key_value_heads=self.num_key_value_heads,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             dropout=self.dropout,
+             max_sequence_length=self.max_sequence_length,
+             pretraining_sequence_length=self.pretraining_sequence_length,
+             rope_max_wavelength=self.rope_max_wavelength,
+             rope_scaling_type=self.rope_scaling_type,
+             rope_scaling_short_factor=self.rope_scaling_short_factor,
+             rope_scaling_long_factor=self.rope_scaling_long_factor,
+             dtype=self.dtype_policy,
+             name="attention",
+         )
+         self.attention.build(decoder_sequence_shape)
+
+         # Post-attention layernorm.
+         self.post_attention_layernorm = Phi3LayerNorm(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="post_attention_layernorm",
+         )
+         self.post_attention_layernorm.build(decoder_sequence_shape)
+
+         # Feedforward layers.
+         self.feedforward_intermediate_dense = keras.layers.Dense(
+             self.intermediate_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="feedforward_intermediate_dense",
+         )
+         self.feedforward_intermediate_dense.build(decoder_sequence_shape)
+
+         self.feedforward_gate_dense = keras.layers.Dense(
+             self.intermediate_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="feedforward_gate_dense",
+         )
+         self.feedforward_gate_dense.build(decoder_sequence_shape)
+
+         self.feedforward_output_dense = keras.layers.Dense(
+             self.hidden_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="feedforward_output_dense",
+         )
+         self.feedforward_output_dense.build(
+             self.feedforward_gate_dense.compute_output_shape(
+                 decoder_sequence_shape
+             )
+         )
+
+         # Dropout.
+         self.attention_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="attention_dropout",
+         )
+         self.feedforward_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="feedforward_dropout",
+         )
+
+         self.built = True
+
+     def call(
+         self,
+         decoder_sequence,
+         decoder_padding_mask=None,
+         decoder_attention_mask=None,
+         attention_cache=None,
+         attention_cache_update_index=None,
+     ):
+         self_attention_mask = self._compute_self_attention_mask(
+             decoder_sequence=decoder_sequence,
+             decoder_padding_mask=decoder_padding_mask,
+             decoder_attention_mask=decoder_attention_mask,
+             attention_cache=attention_cache,
+             attention_cache_update_index=attention_cache_update_index,
+         )
+         residual = decoder_sequence
+         x = self.pre_attention_layernorm(decoder_sequence)
+         x = self.attention(
+             hidden_states=x,
+             attention_mask=self_attention_mask,
+             cache=attention_cache,
+             cache_update_index=attention_cache_update_index,
+         )
+         if attention_cache is not None:
+             x, attention_cache = x
+         x = self.attention_dropout(x)
+         x = x + residual
+
+         residual = x
+         x = self.post_attention_layernorm(x)
+         # Note that we run the activation function in full 32-bit precision
+         # since this is what `torch.nn.functional.silu` does. Internally,
+         # `torch.nn.functional.silu` converts the inputs to float32, computes
+         # SiLU, and converts the outputs back to compute dtype.
+         # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
+         # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
+         gate_output = self.feedforward_gate_dense(x)
+         gate_output = ops.cast(gate_output, "float32")
+         gate_output = self.activation(gate_output)
+         gate_output = ops.cast(gate_output, self.compute_dtype)
+         x = self.feedforward_intermediate_dense(x)
+         x = self.feedforward_output_dense(ops.multiply(x, gate_output))
+         x = self.feedforward_dropout(x)
+         decoder_output = x + residual
+
+         if attention_cache is not None:
+             return decoder_output, attention_cache
+         return decoder_output
+
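The feedforward block in `call()` is a gated MLP ("SwiGLU"-style): a `silu`-activated gate, computed in float32 per the comment above, elementwise-scales the intermediate projection before the output projection. A standalone sketch of the dataflow, with hypothetical kernel tensors and the dtype casting and dropout omitted:

```python
from keras import ops

def gated_feedforward(x, w_gate, w_intermediate, w_output):
    # x: (batch, seq, hidden_dim); kernels are bias-free Dense weights.
    gate = ops.silu(ops.matmul(x, w_gate))  # (batch, seq, intermediate_dim)
    up = ops.matmul(x, w_intermediate)      # (batch, seq, intermediate_dim)
    # Gate the intermediate activations, then project back to hidden_dim.
    return ops.matmul(ops.multiply(gate, up), w_output)
```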
+     def _compute_self_attention_mask(
+         self,
+         decoder_sequence,
+         decoder_padding_mask,
+         decoder_attention_mask,
+         attention_cache,
+         attention_cache_update_index,
+     ):
+         decoder_mask = merge_padding_and_attention_mask(
+             decoder_sequence, decoder_padding_mask, decoder_attention_mask
+         )
+         batch_size = ops.shape(decoder_sequence)[0]
+         input_length = output_length = ops.shape(decoder_sequence)[1]
+         # We need to handle a rectangular causal mask when doing cached
+         # decoding. For generative inference, `decoder_sequence` will
+         # generally be length 1, and `cache` will be the full generation
+         # length.
+         if attention_cache is not None:
+             input_length = ops.shape(attention_cache)[2]
+
+         cache_update_index = (
+             0
+             if attention_cache_update_index is None
+             else attention_cache_update_index
+         )
+
+         causal_mask = compute_causal_mask(
+             batch_size, input_length, output_length, cache_update_index
+         )
+
+         return (
+             ops.minimum(decoder_mask, causal_mask)
+             if decoder_mask is not None
+             else causal_mask
+         )
+
+     def compute_output_shape(self, decoder_sequence_shape):
+         return decoder_sequence_shape
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_query_heads": self.num_query_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "activation": keras.activations.serialize(self.activation),
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self.kernel_initializer
+                 ),
+                 "dropout": self.dropout,
+                 "max_sequence_length": self.max_sequence_length,
+                 "pretraining_sequence_length": self.pretraining_sequence_length,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "rope_scaling_type": self.rope_scaling_type,
+                 "rope_scaling_short_factor": self.rope_scaling_short_factor,
+                 "rope_scaling_long_factor": self.rope_scaling_long_factor,
+             }
+         )
+         return config
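The "rectangular" mask case is worth seeing concretely: with a cache, a step decodes `output_length = 1` new token against `input_length = cache_length` keys, so a query at position `cache_update_index` may attend to every cached position up to and including itself. A small NumPy sketch of that rule (not the actual `compute_causal_mask` helper):

```python
import numpy as np

def causal_mask(input_length, output_length, cache_update_index):
    # mask[i, j] == 1 iff query (cache_update_index + i) may attend to key j.
    rows = np.arange(output_length)[:, None] + cache_update_index
    cols = np.arange(input_length)[None, :]
    return (cols <= rows).astype("int32")

print(causal_mask(input_length=5, output_length=1, cache_update_index=3))
# [[1 1 1 1 0]] -> one query attending to the first four cached positions.
```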
keras_hub/src/models/phi3/phi3_layernorm.py
@@ -0,0 +1,48 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import keras
+ from keras import ops
+
+
+ # TODO: Deprecate this in favor of
+ # `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is
+ # removed.
+ class Phi3LayerNorm(keras.layers.Layer):
+     """A normalization layer for Phi-3 that implements RMS normalization."""
+
+     def __init__(self, epsilon=1e-6, **kwargs):
+         super().__init__(**kwargs)
+         self.epsilon = epsilon
+
+     def build(self, input_shape):
+         dim = input_shape[-1]
+         self.scale = self.add_weight(
+             name="scale",
+             trainable=True,
+             shape=(dim,),
+             initializer="ones",
+             dtype=self.variable_dtype,
+         )
+         self.built = True
+
+     def call(self, x):
+         x = ops.cast(x, "float32")
+         var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+         x = x * ops.rsqrt(var + self.epsilon)
+         return ops.cast(x * self.scale, self.compute_dtype)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update({"epsilon": self.epsilon})
+         return config
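`call()` above is RMS normalization: each feature vector is divided by the root mean square of its components (computed in float32 for stability), then scaled by the learned `scale` weight; unlike standard layer normalization, no mean is subtracted and no bias is added. A NumPy reference of the same formula (illustrative, not the layer's tests):

```python
import numpy as np

def rms_norm(x, scale, epsilon=1e-6):
    # Root-mean-square over the last axis; no mean subtraction, no bias.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + epsilon)
    return (x / rms) * scale

x = np.array([[1.0, 2.0, 3.0]])
print(rms_norm(x, scale=np.ones(3)))  # ~[[0.4629 0.9258 1.3887]]
```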