keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,143 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import keras
17
+ from keras import ops
18
+
19
+ from keras_hub.src.api_export import keras_hub_export
20
+
21
+
22
@keras_hub_export("keras_hub.layers.AlibiBias")
class AlibiBias(keras.layers.Layer):
    """A layer that adds the alibi bias to attention scores.

    This layer adds the alibi bias to the attention scores. Alibi bias is a
    linear, non-learned bias. Defined and formalized in
    [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409).

    This layer takes as input the attention scores and returns the attention
    scores after adding the alibi bias to them. The output will have the same
    shape as the input.

    Args:
        alibi_bias_max: int. This value will be used to compute the slope of
            each head. The heads' slopes are a geometric sequence that starts
            at `2**(-alibi_bias_max/num_heads)` and uses that same value as
            its ratio. Defaults to 8.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Call arguments:
        attention_scores: The result of multiplying the query and the key of
            the multi-head attention layer of the transformer to add alibi
            bias to it. With shape
            `(batch_size, num_heads, query_length, key_length)`.

    Example:
    ```python
    query_length = 10
    key_length = 10
    num_heads = 4
    batch_size = 2
    hidden_dim = 8

    # Create new alibi layer.
    alibi_layer = keras_hub.layers.AlibiBias()

    query = np.zeros((batch_size, num_heads, query_length, hidden_dim))
    key = np.zeros((batch_size, num_heads, hidden_dim, key_length))

    attention_scores = keras.ops.matmul(query, key)

    # Add alibi bias to attention scores.
    attention_scores = alibi_layer(attention_scores)
    ```

    References:
     - [Press et al., 2021](https://arxiv.org/abs/2108.12409)
    """

    def __init__(
        self,
        alibi_bias_max=8,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.alibi_bias_max = alibi_bias_max

    def call(self, attention_scores):
        """Return `attention_scores` with the alibi bias added.

        Raises:
            ValueError: if `attention_scores` is not a rank-4 tensor.
        """
        shape = ops.shape(attention_scores)
        if len(shape) != 4:
            raise ValueError(
                "Expected `attention_scores` shape to be "
                "`(batch_size, num_heads, query_length, key_length)`."
                f" Received shape={shape}"
            )

        key_length = shape[-1]
        num_heads = shape[-3]

        alibi_bias = self._get_alibi_bias(num_heads, key_length)

        # The bias broadcasts over the batch and query_length axes.
        return ops.add(attention_scores, alibi_bias)

    def _get_alibi_bias(self, num_heads, key_length):
        """Build the bias tensor of shape `(1, num_heads, 1, key_length)`."""
        # One slope per head, as a column: `(num_heads, 1)`.
        slopes = ops.convert_to_tensor(
            self._get_slopes(num_heads), dtype=self.compute_dtype
        )
        slopes = ops.expand_dims(slopes, 1)

        # Relative key offsets `1 - key_length, ..., -1, 0` as a row:
        # `(1, key_length)`.
        seq_range = ops.expand_dims(
            ops.arange(1 - key_length, 1, dtype="int32"), 0
        )
        seq_range = ops.cast(seq_range, dtype=self.compute_dtype)

        # Outer product gives `(num_heads, key_length)`; then insert the
        # query axis to get `(num_heads, 1, key_length)`.
        alibi_bias = ops.multiply(slopes, seq_range)
        alibi_bias = ops.expand_dims(alibi_bias, 1)

        # return shape is `(1, num_heads, 1, key_length)`
        return ops.expand_dims(alibi_bias, 0)

    def _get_slopes(self, num_heads):
        """Return a Python list with one slope per attention head."""
        # this function is adapted from Alibi original implementation.
        # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
        def get_slopes_power_of_2(n):
            # Equivalent to `2 ** (-alibi_bias_max / n)`; the first slope is
            # also the ratio of the geometric sequence.
            start = 2 ** (
                -(2 ** -(math.log2(n) - math.log2(self.alibi_bias_max)))
            )
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(num_heads).is_integer():
            return get_slopes_power_of_2(num_heads)
        else:
            # For a head count that is not a power of two, take the slopes of
            # the closest smaller power of two and fill the remaining heads
            # with every other slope of the next power of two.
            closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + self._get_slopes(2 * closest_power_of_2)[0::2][
                    : num_heads - closest_power_of_2
                ]
            )

    def compute_output_shape(self, input_shape):
        """The bias is added elementwise, so the shape is unchanged."""
        return input_shape

    def get_config(self):
        """Return the layer config, including `alibi_bias_max`."""
        config = super().get_config()
        config.update(
            {
                "alibi_bias_max": self.alibi_bias_max,
            }
        )
        return config
@@ -0,0 +1,137 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+
20
+
21
@keras_hub_export("keras_hub.layers.CachedMultiHeadAttention")
class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
    """MultiHeadAttention layer with cache support.

    Intended for autoregressive decoding, where it can cache the key/value
    projections of decoder self-attention and cross-attention. A forward pass
    runs in one of three modes:

    - No cache (`cache` is `None`): behaves like regular multi-head
      attention.
    - Static cache (`cache` is set, `cache_update_index` is `None`): the
      cached key/value projections are used as-is and the `key`/`value`
      inputs are ignored.
    - Updated cache (`cache` and `cache_update_index` are both set): new
      key/value projections are computed from the inputs and written into
      the cache starting at `cache_update_index`.

    Caching is only useful at inference time and should not be used during
    training.

    Notation: `B` is the batch dimension, `T` is the target sequence length
    and `S` is the source sequence length. During generative decoding `T` is
    usually 1 (a single target position is processed to predict the next
    token).

    Call arguments:
        query: Query `Tensor` of shape `(B, T, dim)`.
        value: Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `None`,
            `S*` must equal `S` and match the shape of `attention_mask`. If
            `cache` is not `None`, `S*` can be any length less than `S`, and
            the computed value will be spliced into `cache` at
            `cache_update_index`.
        key: Optional key `Tensor` of shape `(B, S*, dim)`. Defaults to
            `value` when not given. The same `S*` rules as for `value`
            apply.
        attention_mask: a boolean mask of shape `(B, T, S)` that prevents
            attention to certain positions. It specifies which query
            elements can attend to which key elements; 1 indicates attention
            and 0 indicates no attention. Broadcasting can happen for the
            missing batch dimensions and the head dimension.
        cache: a dense float Tensor. The key/value cache, of shape
            `[B, 2, S, num_heads, key_dims]`, where `S` must agree with the
            `attention_mask` shape. Used during generation to avoid
            recomputing intermediate state.
        cache_update_index: an int or int Tensor, the index at which to
            update `cache` (usually the index of the current token being
            processed when running generation). If `cache_update_index=None`
            while `cache` is set, the cache will not be updated.
        training: a boolean indicating whether the layer should behave in
            training mode or in inference mode.

    Returns:
        When `cache` is given, an `(attention_output, cache)` tuple, where
        `cache` is the (possibly updated) cache. When `cache` is `None`,
        only `attention_output`. `attention_output` has shape `(B, T, dim)`,
        where `dim` is the query input last dimension if `output_shape` is
        `None`; otherwise the multi-head outputs are projected to the shape
        specified by `output_shape`.
    """

    def call(
        self,
        query,
        value,
        key=None,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        key = value if key is None else key

        query = self._query_dense(query)

        if cache is None:
            # Training / uncached path: an update index makes no sense here.
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key = self._key_dense(key)
            value = self._value_dense(value)
        elif cache_update_index is None:
            # Static cache: reuse the stored projections, ignore the inputs.
            key = cache[:, 0, ...]
            value = cache[:, 1, ...]
        else:
            # Updated cache: project the new inputs, splice them into the
            # stored projections along the sequence axis, and rebuild the
            # cache tensor from the result.
            cached_key = cache[:, 0, ...]
            cached_value = cache[:, 1, ...]
            new_key = self._key_dense(key)
            new_value = self._value_dense(value)
            offset = [0, cache_update_index, 0, 0]
            key = ops.slice_update(cached_key, offset, new_key)
            value = ops.slice_update(cached_value, offset, new_value)
            cache = ops.stack((key, value), axis=1)

        # The attention scores are not needed by callers of this layer.
        attention_output, _ = self._compute_attention(
            query=query,
            key=key,
            value=value,
            attention_mask=attention_mask,
            training=training,
        )
        attention_output = self._output_dense(attention_output)

        if cache is None:
            return attention_output
        return attention_output, cache
@@ -0,0 +1,200 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+ from keras_hub.src.utils.keras_utils import clone_initializer
20
+
21
+
22
@keras_hub_export("keras_hub.layers.FNetEncoder")
class FNetEncoder(keras.layers.Layer):
    """FNet encoder.

    This class follows the architecture of the FNet encoder layer in the
    [FNet paper](https://arxiv.org/abs/2105.03824). Users can instantiate
    multiple instances of this class to stack up the encoder.

    Note on masking: In the official FNet code, padding tokens are added to
    the input. However, the padding masks are deleted, i.e., mixing of
    all tokens is done. This is because certain frequencies will be zeroed
    out if we apply padding masks in every encoder layer. Hence, we don't
    take a padding mask as input in the `call()` function.

    Args:
        intermediate_dim: int. The hidden size of the feedforward network.
        dropout: float. The dropout value, applied in the
            feedforward network. Defaults to `0`.
        activation: string or `keras.activations`. The
            activation function of the feedforward network.
            Defaults to `"relu"`.
        layer_norm_epsilon: float. The epsilon value in the layer
            normalization components. Defaults to `1e-5`.
        kernel_initializer: string or `keras.initializers` initializer.
            The kernel initializer for the dense layers.
            Defaults to `"glorot_uniform"`.
        bias_initializer: string or `keras.initializers` initializer.
            The bias initializer for the dense layers.
            Defaults to `"zeros"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Example:

    ```python
    # Create a single FNet encoder layer.
    encoder = keras_hub.layers.FNetEncoder(
        intermediate_dim=64)

    # Create a simple model containing the encoder.
    inputs = keras.Input(shape=(10, 64))
    outputs = encoder(inputs)
    model = keras.Model(inputs=inputs, outputs=outputs)

    # Call encoder on the inputs.
    input_data = np.random.uniform(size=(1, 10, 64))
    output = model(input_data)
    ```

    References:
     - [Lee-Thorp et al., 2021](https://arxiv.org/abs/2105.03824)
    """

    def __init__(
        self,
        intermediate_dim,
        dropout=0,
        activation="relu",
        layer_norm_epsilon=1e-5,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Plain hyperparameters are stored verbatim; the activation and the
        # initializers are resolved to objects through the Keras getters.
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)

    def build(self, inputs_shape):
        """Create and build the sublayers for inputs of `inputs_shape`."""
        # The feedforward network projects back to the input feature size.
        hidden_dim = inputs_shape[-1]

        # Both layer norms share the same configuration apart from the name.
        norm_options = {
            "epsilon": self.layer_norm_epsilon,
            "dtype": self.dtype_policy,
        }
        self._mixing_layer_norm = keras.layers.LayerNormalization(
            name="mixing_layer_norm", **norm_options
        )
        self._mixing_layer_norm.build(inputs_shape)
        self._output_layer_norm = keras.layers.LayerNormalization(
            name="output_layer_norm", **norm_options
        )
        self._output_layer_norm.build(inputs_shape)

        # Feedforward network: expand to `intermediate_dim`, project back.
        # Each dense layer receives its own clone of the initializers.
        self._intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            activation=self.activation,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="intermediate_dense",
        )
        self._intermediate_dense.build(inputs_shape)
        intermediate_shape = self._intermediate_dense.compute_output_shape(
            inputs_shape
        )
        self._output_dense = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="output_dense",
        )
        self._output_dense.build(intermediate_shape)

        self._output_dropout = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
            name="output_dropout",
        )
        self.built = True

    def call(self, inputs, training=None):
        """Forward pass of the FNetEncoder.

        Args:
            inputs: a Tensor. The input data to FNetEncoder, should be
                of shape [batch_size, sequence_length, feature_dim].
            training: a boolean indicating whether the layer should behave in
                training mode or in inference mode.

        Returns:
            A Tensor of the same shape as the `inputs`.
        """
        # Fourier mixing: apply the FFT to the input (with a zero imaginary
        # part) and keep only the real part of the result. FFT transforms do
        # not support float16, so compute in float32 and cast back.
        original_dtype = inputs.dtype
        real_part = ops.cast(inputs, "float32")
        mixed, _ = ops.fft2((real_part, ops.zeros_like(real_part)))
        mixed = ops.cast(mixed, original_dtype)

        # Residual connection + layer norm around the mixing sublayer.
        mixed = self._mixing_layer_norm(inputs + mixed)

        # Feedforward sublayer; dropout is applied to its output.
        hidden = self._intermediate_dense(mixed)
        hidden = self._output_dense(hidden)
        hidden = self._output_dropout(hidden, training=training)

        # Residual connection + layer norm around the feedforward sublayer.
        return self._output_layer_norm(mixed + hidden)

    def get_config(self):
        """Return the layer config, extended with this layer's arguments."""
        config = super().get_config()
        config["intermediate_dim"] = self.intermediate_dim
        config["dropout"] = self.dropout
        config["activation"] = keras.activations.serialize(self.activation)
        config["layer_norm_epsilon"] = self.layer_norm_epsilon
        config["kernel_initializer"] = keras.initializers.serialize(
            self.kernel_initializer
        )
        config["bias_initializer"] = keras.initializers.serialize(
            self.bias_initializer
        )
        return config

    def compute_output_shape(self, inputs_shape):
        """Return `inputs_shape` unchanged."""
        return inputs_shape