keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/layers/modeling/transformer_decoder.py
@@ -0,0 +1,496 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.modeling.cached_multi_head_attention import (
+     CachedMultiHeadAttention,
+ )
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (  # isort:skip
+     compute_causal_mask,
+     merge_padding_and_attention_mask,
+ )
+
+
+ @keras_hub_export("keras_hub.layers.TransformerDecoder")
+ class TransformerDecoder(keras.layers.Layer):
+     """Transformer decoder.
+
+     This class follows the architecture of the transformer decoder layer in
+     the paper [Attention is All You Need](https://arxiv.org/abs/1706.03762).
+     Users can instantiate multiple instances of this class to stack up a
+     decoder.
+
+     By default, this layer will apply a causal mask to the decoder attention
+     layer. You can also pass padding or attention masks directly to the layer
+     during call, e.g. with `decoder_padding_mask` or `decoder_attention_mask`.
+
+     This layer can be called with either one or two inputs. The number of
+     inputs must be consistent across all calls. The options are as follows:
+         `layer(decoder_sequence)`: no cross-attention will be built into the
+             decoder block. This is useful when building a "decoder-only"
+             transformer such as GPT-2.
+         `layer(decoder_sequence, encoder_sequence)`: cross-attention will be
+             built into the decoder block. This is useful when building an
+             "encoder-decoder" transformer, such as the original transformer
+             model described in Attention is All You Need.
+
+     Args:
+         intermediate_dim: int, the hidden size of feedforward network.
+         num_heads: int, the number of heads in MultiHeadAttention.
+         dropout: float. The dropout value, shared by
+             MultiHeadAttention and feedforward network. Defaults to `0.`.
+         activation: string or `keras.activations`. The
+             activation function of feedforward network.
+             Defaults to `"relu"`.
+         layer_norm_epsilon: float. The eps value in layer
+             normalization components. Defaults to `1e-5`.
+         kernel_initializer: string or `keras.initializers` initializer.
+             The kernel initializer for the dense and multiheaded
+             attention layers. Defaults to `"glorot_uniform"`.
+         bias_initializer: string or `keras.initializers` initializer.
+             The bias initializer for the dense and multiheaded
+             attention layers. Defaults to `"zeros"`.
+         normalize_first: bool. If True, the inputs to the
+             attention layer(s) and the intermediate dense layer are normalized
+             (similar to GPT-2). If set to False, outputs of attention layer
+             and intermediate dense layer are normalized (similar to BERT).
+             Defaults to `False`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `trainable`, `dtype` etc.
+
+     Example:
+     ```python
+     # Create a single transformer decoder layer.
+     decoder = keras_hub.layers.TransformerDecoder(
+         intermediate_dim=64, num_heads=8)
+
+     # Create a simple model containing the decoder.
+     decoder_input = keras.Input(shape=(10, 64))
+     encoder_input = keras.Input(shape=(10, 64))
+     output = decoder(decoder_input, encoder_input)
+     model = keras.Model(
+         inputs=(decoder_input, encoder_input),
+         outputs=output,
+     )
+
+     # Call decoder on the inputs.
+     decoder_input_data = np.random.uniform(size=(2, 10, 64))
+     encoder_input_data = np.random.uniform(size=(2, 10, 64))
+     decoder_output = model((decoder_input_data, encoder_input_data))
+     ```
+
+     References:
+         - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
+
+     """
+
+     def __init__(
+         self,
+         intermediate_dim,
+         num_heads,
+         dropout=0,
+         activation="relu",
+         layer_norm_epsilon=1e-05,
+         kernel_initializer="glorot_uniform",
+         bias_initializer="zeros",
+         normalize_first=False,
+         **kwargs,
+     ):
+         # Workaround for model saving: we need to ensure our model is built
+         # immediately after restoring from config.
+         decoder_sequence_shape = kwargs.pop("decoder_sequence_shape", None)
+         encoder_sequence_shape = kwargs.pop("encoder_sequence_shape", None)
+
+         super().__init__(**kwargs)
+         self.intermediate_dim = intermediate_dim
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.activation = keras.activations.get(activation)
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+         self.bias_initializer = keras.initializers.get(bias_initializer)
+         self.normalize_first = normalize_first
+         self.supports_masking = True
+         self._decoder_sequence_shape = None
+         self._encoder_sequence_shape = None
+
+         if decoder_sequence_shape:
+             self.build(decoder_sequence_shape, encoder_sequence_shape)
+
+     def build(
+         self,
+         decoder_sequence_shape,
+         encoder_sequence_shape=None,
+     ):
+         self._decoder_sequence_shape = decoder_sequence_shape
+         self._encoder_sequence_shape = encoder_sequence_shape
+         # Infer the dimension of our hidden feature size from the build shape.
+         hidden_dim = decoder_sequence_shape[-1]
+         # Attention head size is `hidden_dim` over the number of heads.
+         head_dim = int(hidden_dim // self.num_heads)
+         if head_dim == 0:
+             raise ValueError(
+                 "Attention `head_dim` computed cannot be zero. "
+                 f"The `hidden_dim` value of {hidden_dim} has to be equal to "
+                 f"or greater than `num_heads` value of {self.num_heads}."
+             )
+
+         # Self attention layers.
+         self._self_attention_layer = CachedMultiHeadAttention(
+             num_heads=self.num_heads,
+             key_dim=head_dim,
+             dropout=self.dropout,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="self_attention",
+         )
+         if hasattr(self._self_attention_layer, "_build_from_signature"):
+             self._self_attention_layer._build_from_signature(
+                 query=decoder_sequence_shape,
+                 value=decoder_sequence_shape,
+             )
+         else:
+             self._self_attention_layer.build(
+                 query_shape=decoder_sequence_shape,
+                 value_shape=decoder_sequence_shape,
+             )
+         self._self_attention_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="self_attention_layer_norm",
+         )
+         self._self_attention_layer_norm.build(decoder_sequence_shape)
+         self._self_attention_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="self_attention_dropout",
+         )
+
+         # Cross attention layers are optional.
+         self._cross_attention_layer = None
+         if encoder_sequence_shape:
+             self._cross_attention_layer = CachedMultiHeadAttention(
+                 num_heads=self.num_heads,
+                 key_dim=head_dim,
+                 value_dim=head_dim,
+                 dropout=self.dropout,
+                 kernel_initializer=clone_initializer(self.kernel_initializer),
+                 bias_initializer=clone_initializer(self.bias_initializer),
+                 dtype=self.dtype_policy,
+                 name="cross_attention",
+             )
+             if hasattr(self._cross_attention_layer, "_build_from_signature"):
+                 self._cross_attention_layer._build_from_signature(
+                     query=decoder_sequence_shape,
+                     value=encoder_sequence_shape,
+                 )
+             else:
+                 self._cross_attention_layer.build(
+                     query_shape=decoder_sequence_shape,
+                     value_shape=encoder_sequence_shape,
+                 )
+             self._cross_attention_layer_norm = keras.layers.LayerNormalization(
+                 epsilon=self.layer_norm_epsilon,
+                 dtype=self.dtype_policy,
+                 name="cross_attention_layer_norm",
+             )
+             self._cross_attention_layer_norm.build(decoder_sequence_shape)
+             self._cross_attention_dropout = keras.layers.Dropout(
+                 rate=self.dropout,
+                 dtype=self.dtype_policy,
+                 name="cross_attention_dropout",
+             )
+
+         # Feedforward layers.
+         self._feedforward_intermediate_dense = keras.layers.Dense(
+             self.intermediate_dim,
+             activation=self.activation,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="feedforward_intermediate_dense",
+         )
+         self._feedforward_intermediate_dense.build(decoder_sequence_shape)
+         self._feedforward_output_dense = keras.layers.Dense(
+             hidden_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="feedforward_output_dense",
+         )
+         intermediate_shape = list(decoder_sequence_shape)
+         intermediate_shape[-1] = self.intermediate_dim
+         self._feedforward_output_dense.build(tuple(intermediate_shape))
+         self._feedforward_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="feedforward_layer_norm",
+         )
+         self._feedforward_layer_norm.build(decoder_sequence_shape)
+         self._feedforward_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="feedforward_dropout",
+         )
+         # All layers have been created based on the input shape.
+         self.built = True
+
+     def call(
+         self,
+         decoder_sequence,
+         encoder_sequence=None,
+         decoder_padding_mask=None,
+         decoder_attention_mask=None,
+         encoder_padding_mask=None,
+         encoder_attention_mask=None,
+         self_attention_cache=None,
+         self_attention_cache_update_index=None,
+         cross_attention_cache=None,
+         cross_attention_cache_update_index=None,
+         use_causal_mask=True,
+         training=None,
+     ):
+         """Forward pass of the TransformerDecoder.
+
+         Args:
+             decoder_sequence: a Tensor. The decoder input sequence.
+             encoder_sequence: a Tensor. The encoder input sequence. For
+                 decoder-only models (like GPT-2), this should be left `None`.
+                 Once the model is called without an `encoder_sequence`, you
+                 cannot call it again with an `encoder_sequence`.
+             decoder_padding_mask: a boolean Tensor, the padding mask of the
+                 decoder sequence, must be of shape
+                 `[batch_size, decoder_sequence_length]`.
+             decoder_attention_mask: a boolean Tensor. A customized decoder
+                 sequence mask, must be of shape
+                 `[batch_size, decoder_sequence_length, decoder_sequence_length]`.
+             encoder_padding_mask: a boolean Tensor, the padding mask of the
+                 encoder sequence, must be of shape
+                 `[batch_size, encoder_sequence_length]`.
+             encoder_attention_mask: a boolean Tensor. A customized encoder
+                 sequence mask, must be of shape
+                 `[batch_size, encoder_sequence_length, encoder_sequence_length]`.
+             self_attention_cache: a dense float Tensor. The cache of key/value
+                 pairs in the self-attention layer. Has shape
+                 `[batch_size, 2, max_seq_len, num_heads, key_dims]`.
+             self_attention_cache_update_index: an int or int Tensor, the index
+                 at which to update the `self_attention_cache`. Usually, this
+                 is the index of the current token being processed during
+                 decoding.
+             cross_attention_cache: a dense float Tensor. The cache of
+                 key/value pairs in the cross-attention layer. Has shape
+                 `[batch_size, 2, S, num_heads, key_dims]`.
+             cross_attention_cache_update_index: an int or int Tensor, the
+                 index at which to update the `cross_attention_cache`. Usually,
+                 this is either `0` (compute the entire
+                 `cross_attention_cache`), or `None` (reuse a previously
+                 computed `cross_attention_cache`).
+             use_causal_mask: bool, defaults to `True`. If true, a causal mask
+                 (masking out future input) is applied to the decoder sequence.
+             training: a boolean indicating whether the layer should behave in
+                 training mode or in inference mode.
+
+         Returns:
+             One of three things, depending on call arguments:
+             - `outputs`, if `self_attention_cache` is `None`.
+             - `(outputs, self_attention_cache)`, if `self_attention_cache` is
+                 set and the layer has no cross-attention.
+             - `(outputs, self_attention_cache, cross_attention_cache)`, if
+                 `self_attention_cache` and `cross_attention_cache` are set and
+                 the layer has cross-attention.
+         """
+
+         has_encoder_sequence = encoder_sequence is not None
+
+         has_cross_attention = self._cross_attention_layer is not None
+         if not has_cross_attention and has_encoder_sequence:
+             raise ValueError(
+                 "The number of call arguments to "
+                 "`keras_hub.layers.TransformerDecoder` should not change. "
+                 "Use `layer(decoder_sequence, encoder_sequence)` to "
+                 "build a layer with cross attention, or "
+                 "`layer(decoder_sequence)` to build a layer without. "
+                 "This layer has been built without cross attention, but "
+                 "you are trying to call it with encoder_sequence."
+             )
+         elif has_cross_attention and not has_encoder_sequence:
+             raise ValueError(
+                 "The number of call arguments to "
+                 "`keras_hub.layers.TransformerDecoder` should not change. "
+                 "Use `layer(decoder_sequence, encoder_sequence)` to "
+                 "build a layer with cross attention, or "
+                 "`layer(decoder_sequence)` to build a layer without. "
+                 "This layer has been built with cross attention, but "
+                 "you did not provide encoder_sequence."
+             )
+
+         has_self_attention_cache = self_attention_cache is not None
+         has_cross_attention_cache = cross_attention_cache is not None
+         if has_cross_attention and (
+             has_self_attention_cache != has_cross_attention_cache
+         ):
+             raise ValueError(
+                 "When calling `keras_hub.layers.TransformerDecoder` with "
+                 "cross-attention (with both `encoder_sequence` and "
+                 "`decoder_sequence`), `self_attention_cache` and "
+                 "`cross_attention_cache` should both be set or both be "
+                 "`None`. One cannot be `None` while the other is not. "
+                 f"Received: self_attention_cache={self_attention_cache}, "
+                 f"cross_attention_cache={cross_attention_cache}."
+             )
+
+         self_attention_mask = self._compute_self_attention_mask(
+             decoder_sequence=decoder_sequence,
+             decoder_padding_mask=decoder_padding_mask,
+             decoder_attention_mask=decoder_attention_mask,
+             use_causal_mask=use_causal_mask,
+             self_attention_cache=self_attention_cache,
+             self_attention_cache_update_index=self_attention_cache_update_index,
+         )
+
+         x = decoder_sequence  # Intermediate result.
+
+         # Self attention block.
+         residual = x
+         if self.normalize_first:
+             x = self._self_attention_layer_norm(x)
+         attention_output = self._self_attention_layer(
+             query=x,
+             value=x,
+             attention_mask=self_attention_mask,
+             cache=self_attention_cache,
+             cache_update_index=self_attention_cache_update_index,
+             training=training,
+         )
+         if self_attention_cache is None:
+             x = attention_output
+         else:
+             x, self_attention_cache = attention_output
+         x = self._self_attention_dropout(x, training=training)
+         x = x + residual
+         if not self.normalize_first:
+             x = self._self_attention_layer_norm(x)
+
+         # Cross attention is optional.
+         if has_cross_attention:
+             # Compute cross attention mask.
+             cross_attention_mask = merge_padding_and_attention_mask(
+                 encoder_sequence, encoder_padding_mask, encoder_attention_mask
+             )
+
+             # Cross attention block.
+             residual = x
+             if self.normalize_first:
+                 x = self._cross_attention_layer_norm(x)
+             attention_output = self._cross_attention_layer(
+                 query=x,
+                 value=encoder_sequence,
+                 attention_mask=cross_attention_mask,
+                 cache=cross_attention_cache,
+                 cache_update_index=cross_attention_cache_update_index,
+                 training=training,
+             )
+             if cross_attention_cache is None:
+                 x = attention_output
+             else:
+                 x, cross_attention_cache = attention_output
+             x = self._cross_attention_dropout(x, training=training)
+             x = x + residual
+             if not self.normalize_first:
+                 x = self._cross_attention_layer_norm(x)
+
+         # Feedforward block.
+         residual = x
+         if self.normalize_first:
+             x = self._feedforward_layer_norm(x)
+         x = self._feedforward_intermediate_dense(x)
+         x = self._feedforward_output_dense(x)
+         x = self._feedforward_dropout(x, training=training)
+         x = x + residual
+         if not self.normalize_first:
+             x = self._feedforward_layer_norm(x)
+
+         if self_attention_cache is not None:
+             if has_cross_attention:
+                 return (x, self_attention_cache, cross_attention_cache)
+             else:
+                 return (x, self_attention_cache)
+         else:
+             return x
+
+     def _compute_self_attention_mask(
+         self,
+         decoder_sequence,
+         decoder_padding_mask,
+         decoder_attention_mask,
+         use_causal_mask,
+         self_attention_cache,
+         self_attention_cache_update_index,
+     ):
+         decoder_mask = merge_padding_and_attention_mask(
+             decoder_sequence, decoder_padding_mask, decoder_attention_mask
+         )
+         if use_causal_mask:
+             batch_size = ops.shape(decoder_sequence)[0]
+             input_length = output_length = ops.shape(decoder_sequence)[1]
+             # We need to handle a rectangular causal mask when doing cached
+             # decoding. For generative inference, `decoder_sequence` will
+             # generally be length 1, and `cache` will be the full generation
+             # length.
+             if self_attention_cache is not None:
+                 input_length = ops.shape(self_attention_cache)[2]
+
+             causal_mask = compute_causal_mask(
+                 batch_size,
+                 input_length,
+                 output_length,
+                 (
+                     0
+                     if self_attention_cache_update_index is None
+                     else self_attention_cache_update_index
+                 ),
+             )
+             return (
+                 ops.minimum(decoder_mask, causal_mask)
+                 if decoder_mask is not None
+                 else causal_mask
+             )
+         return decoder_mask
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "dropout": self.dropout,
+                 "activation": keras.activations.serialize(self.activation),
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self.kernel_initializer
+                 ),
+                 "bias_initializer": keras.initializers.serialize(
+                     self.bias_initializer
+                 ),
+                 "normalize_first": self.normalize_first,
+                 "decoder_sequence_shape": self._decoder_sequence_shape,
+                 "encoder_sequence_shape": self._encoder_sequence_shape,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, decoder_sequence_shape):
+         return decoder_sequence_shape
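
For context, the `call` docstring above describes a decoder-only call mode and key/value caching for generative inference. The following is a minimal sketch of both, not part of the package diff; the shapes, hyperparameters, and variable names are illustrative assumptions.

```python
import keras
import numpy as np

import keras_hub

batch_size, seq_len, hidden_dim, num_heads = 2, 10, 64, 8
head_dim = hidden_dim // num_heads  # Must be nonzero, per `build()`.

# Decoder-only mode: a single input, causal masking applied by default.
decoder = keras_hub.layers.TransformerDecoder(
    intermediate_dim=128, num_heads=num_heads
)
x = np.random.uniform(size=(batch_size, seq_len, hidden_dim)).astype("float32")
outputs = decoder(x)  # Shape: (2, 10, 64).

# Cached decoding: process one token at a time, updating a key/value cache of
# shape [batch_size, 2, max_seq_len, num_heads, key_dim] (see the docstring).
cache = keras.ops.zeros((batch_size, 2, seq_len, num_heads, head_dim))
for index in range(seq_len):
    token = x[:, index : index + 1, :]  # One timestep per step.
    outputs, cache = decoder(
        token,
        self_attention_cache=cache,
        self_attention_cache_update_index=index,
    )
```

Per the `Returns` section, the cached call returns `(outputs, self_attention_cache)`; a layer built with cross-attention would also take and return a `cross_attention_cache`.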