keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
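Before the per-file diffs, a minimal sketch of installing and importing this build (the version pin comes from the wheel name above; the `dir()` calls are illustrative, and the namespaces correspond to the `keras_hub/api/` modules in the list):

```python
# Shell: pip install keras-hub-nightly==0.15.0.dev20240823171555
import keras_hub

# Public namespaces mirror the keras_hub/api/ modules listed above.
print(dir(keras_hub.layers))  # e.g. RotaryEmbedding, SinePositionEncoding, ...
print(dir(keras_hub.models))  # e.g. BertClassifier, GemmaCausalLM, ...
```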
keras_hub/src/layers/modeling/rotary_embedding.py
@@ -0,0 +1,169 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export


@keras_hub_export("keras_hub.layers.RotaryEmbedding")
class RotaryEmbedding(keras.layers.Layer):
    """Rotary positional encoding layer.

    This layer encodes absolute positional information with a rotation
    matrix. It calculates the rotary encoding with a mix of sine and
    cosine functions with geometrically increasing wavelengths. Defined and
    formulated in [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
    The input must be a tensor with a sequence dimension and a feature
    dimension. Typically, this will either be an input with shape
    `(batch_size, sequence_length, feature_length)` or
    `(batch_size, sequence_length, num_heads, feature_length)`.
    This layer will return a new tensor with the rotary embedding applied to
    the input tensor.

    Args:
        max_wavelength: int. The maximum angular wavelength of the
            sine/cosine curves.
        scaling_factor: float. The scaling factor used to scale positions of
            the tokens.
        sequence_axis: int. Sequence axis in the input tensor.
        feature_axis: int. Feature axis in the input tensor.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Call arguments:
        inputs: The tensor inputs to apply the embedding to. This can have
            any shape, but must contain both a sequence and feature axis. The
            rotary embedding will be applied to `inputs` and returned.
        start_index: An integer or integer tensor. The starting position to
            compute the rotary embedding from. This is useful during cached
            decoding, where each position is predicted separately in a loop.

    Examples:

    ```python
    import numpy as np

    batch_size = 16
    feature_length = 18
    sequence_length = 256
    num_heads = 8

    # No multi-head dimension.
    tensor = np.ones((batch_size, sequence_length, feature_length))
    rot_emb_layer = RotaryEmbedding()
    tensor_rot = rot_emb_layer(tensor)

    # With multi-head dimension.
    tensor = np.ones((batch_size, sequence_length, num_heads, feature_length))
    tensor_rot = rot_emb_layer(tensor)
    ```

    References:
        - [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4)
    """

    def __init__(
        self,
        max_wavelength=10000,
        scaling_factor=1.0,
        sequence_axis=1,
        feature_axis=-1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.sequence_axis = sequence_axis
        self.feature_axis = feature_axis
        self.scaling_factor = scaling_factor
        self.built = True

    def call(self, inputs, start_index=0, positions=None):
        inputs = ops.moveaxis(
            inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
        )
        cos_emb, sin_emb = self._compute_cos_sin_embedding(
            inputs, start_index, positions
        )
        output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
        return ops.moveaxis(
            output, (-1, 1), (self.feature_axis, self.sequence_axis)
        )

    def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
        x1, x2 = ops.split(tensor, 2, axis=-1)
        # Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
        # compilation on jax. We should be able to remove this once the
        # following PR is in all jax releases we care about:
        # https://github.com/openxla/xla/pull/7875
        half_rot_tensor = ops.stack((-x2, x1), axis=-2)
        half_rot_tensor = ops.reshape(half_rot_tensor, ops.shape(tensor))
        return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

    def _compute_positions(self, inputs, start_index=0):
        seq_len = ops.shape(inputs)[1]
        positions = ops.arange(seq_len, dtype="float32")
        return positions + ops.cast(start_index, dtype="float32")

    def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
        feature_axis = len(inputs.shape) - 1
        sequence_axis = 1

        rotary_dim = ops.shape(inputs)[feature_axis]
        inverse_freq = self._get_inverse_freq(rotary_dim)

        if positions is None:
            positions = self._compute_positions(inputs, start_index)
        else:
            positions = ops.cast(positions, "float32")

        positions = positions / ops.cast(self.scaling_factor, "float32")
        freq = ops.einsum("i,j->ij", positions, inverse_freq)
        embedding = ops.stack((freq, freq), axis=-2)
        embedding = ops.reshape(
            embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2)
        )

        # Reshape the embedding to be broadcastable with input shape.
        if feature_axis < sequence_axis:
            embedding = ops.transpose(embedding)
        for axis in range(len(inputs.shape)):
            if axis != sequence_axis and axis != feature_axis:
                embedding = ops.expand_dims(embedding, axis)

        cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
        sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
        return cos_emb, sin_emb

    def _get_inverse_freq(self, rotary_dim):
        freq_range = ops.divide(
            ops.arange(0, rotary_dim, 2, dtype="float32"),
            ops.cast(rotary_dim, "float32"),
        )
        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
        return inverse_freq

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_wavelength": self.max_wavelength,
                "scaling_factor": self.scaling_factor,
                "sequence_axis": self.sequence_axis,
                "feature_axis": self.feature_axis,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
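A rotary embedding is typically applied to the query and key projections inside an attention block rather than added to token embeddings. A minimal usage sketch under that assumption (the export name `keras_hub.layers.RotaryEmbedding` is declared above; shapes follow the docstring's multi-head example):

```python
import numpy as np
import keras_hub

rope = keras_hub.layers.RotaryEmbedding(max_wavelength=10000)

# (batch, sequence, heads, head_dim), matching the docstring's second example.
query = np.ones((2, 16, 8, 64), dtype="float32")
key = np.ones((2, 16, 8, 64), dtype="float32")

# Rotate queries and keys in place of additive position embeddings.
query, key = rope(query), rope(key)

# During cached decoding, pass the position of the single new token.
next_query = rope(np.ones((2, 1, 8, 64), dtype="float32"), start_index=16)
```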
keras_hub/src/layers/modeling/sine_position_encoding.py
@@ -0,0 +1,108 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export


@keras_hub_export("keras_hub.layers.SinePositionEncoding")
class SinePositionEncoding(keras.layers.Layer):
    """Sinusoidal positional encoding layer.

    This layer calculates the position encoding as a mix of sine and cosine
    functions with geometrically increasing wavelengths. Defined and
    formulated in [Attention is All You Need](https://arxiv.org/abs/1706.03762).

    Takes as input an embedded token tensor. The input must have shape
    `(batch_size, sequence_length, feature_size)`. This layer will return a
    positional encoding the same size as the embedded token tensor, which
    can be added directly to the embedded token tensor.

    Args:
        max_wavelength: The maximum angular wavelength of the sine/cosine
            curves, as described in Attention is All You Need. Defaults to
            `10000`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Call arguments:
        inputs: The tensor inputs to compute an embedding for, with shape
            `(batch_size, sequence_length, hidden_dim)`.
        start_index: An integer or integer tensor. The starting position to
            compute the encoding from. This is useful during cached decoding,
            where each position is predicted separately in a loop.

    Example:
    ```python
    # Create a simple embedding layer with sinusoidal positional encoding.
    seq_len = 100
    vocab_size = 1000
    embedding_dim = 32
    inputs = keras.Input((seq_len,), dtype="float32")
    embedding = keras.layers.Embedding(
        input_dim=vocab_size, output_dim=embedding_dim
    )(inputs)
    positional_encoding = keras_hub.layers.SinePositionEncoding()(embedding)
    outputs = embedding + positional_encoding
    ```

    References:
        - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
    """

    def __init__(
        self,
        max_wavelength=10000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.built = True

    def call(self, inputs, start_index=0):
        shape = ops.shape(inputs)
        seq_length = shape[-2]
        hidden_size = shape[-1]
        positions = ops.arange(seq_length)
        positions = ops.cast(positions + start_index, self.compute_dtype)
        min_freq = ops.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
        timescales = ops.power(
            min_freq,
            ops.cast(2 * (ops.arange(hidden_size) // 2), self.compute_dtype)
            / ops.cast(hidden_size, self.compute_dtype),
        )
        angles = ops.expand_dims(positions, 1) * ops.expand_dims(timescales, 0)
        # Even indices are sine, odd are cosine.
        cos_mask = ops.cast(ops.arange(hidden_size) % 2, self.compute_dtype)
        sin_mask = 1 - cos_mask
        # Embedding shape is [seq_length, hidden_size].
        positional_encodings = (
            ops.sin(angles) * sin_mask + ops.cos(angles) * cos_mask
        )

        return ops.broadcast_to(positional_encodings, shape)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_wavelength": self.max_wavelength,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
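The `start_index` argument is what makes this layer usable with cached decoding; a short illustrative sketch (shapes per the docstring, not taken from the source):

```python
import numpy as np
import keras_hub

pos_enc = keras_hub.layers.SinePositionEncoding()

# Encode a full sequence of embeddings at once.
full = pos_enc(np.zeros((2, 10, 32), dtype="float32"))  # shape (2, 10, 32)

# At generation time, encode just the single token at position 10.
step = pos_enc(np.zeros((2, 1, 32), dtype="float32"), start_index=10)
```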
keras_hub/src/layers/modeling/token_and_position_embedding.py
@@ -0,0 +1,150 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.utils.keras_utils import clone_initializer


@keras_hub_export("keras_hub.layers.TokenAndPositionEmbedding")
class TokenAndPositionEmbedding(keras.layers.Layer):
    """A layer which sums a token and position embedding.

    Token and position embeddings are ways of representing words and their
    order in a sentence. This layer creates a `keras.layers.Embedding` token
    embedding and a `keras_hub.layers.PositionEmbedding` position embedding
    and sums their output when called. This layer assumes that the last
    dimension in the input corresponds to the sequence dimension.

    Args:
        vocabulary_size: The size of the vocabulary.
        sequence_length: The maximum length of the input sequence.
        embedding_dim: The output dimension of the embedding layer.
        tie_weights: Boolean, whether or not the matrix for embedding and
            the matrix for the `reverse` projection should share the same
            weights.
        embeddings_initializer: The initializer to use for the embedding
            layers.
        mask_zero: Boolean, whether or not the input value 0 is a special
            "padding" value that should be masked out.
            This is useful when using recurrent layers which may take variable
            length input. If this is True, then all subsequent layers in the
            model need to support masking or an exception will be raised.
            If `mask_zero` is set to True, as a consequence, index 0 cannot be
            used in the vocabulary
            (`input_dim` should equal size of vocabulary + 1).
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Example:
    ```python
    inputs = np.ones(shape=(1, 50), dtype="int32")
    embedding_layer = keras_hub.layers.TokenAndPositionEmbedding(
        vocabulary_size=10_000,
        sequence_length=50,
        embedding_dim=128,
    )
    outputs = embedding_layer(inputs)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        sequence_length,
        embedding_dim,
        tie_weights=True,
        embeddings_initializer="uniform",
        mask_zero=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        if vocabulary_size is None:
            raise ValueError(
                "`vocabulary_size` must be an Integer, received `None`."
            )
        if sequence_length is None:
            raise ValueError(
                "`sequence_length` must be an Integer, received `None`."
            )
        if embedding_dim is None:
            raise ValueError(
                "`embedding_dim` must be an Integer, received `None`."
            )
        self.vocabulary_size = int(vocabulary_size)
        self.sequence_length = int(sequence_length)
        self.embedding_dim = int(embedding_dim)
        self.embeddings_initializer = keras.initializers.get(
            embeddings_initializer
        )
        self.token_embedding = ReversibleEmbedding(
            vocabulary_size,
            embedding_dim,
            tie_weights=tie_weights,
            embeddings_initializer=clone_initializer(
                self.embeddings_initializer
            ),
            mask_zero=mask_zero,
            dtype=self.dtype_policy,
            name="token_embedding",
        )
        self.position_embedding = PositionEmbedding(
            sequence_length=sequence_length,
            initializer=clone_initializer(self.embeddings_initializer),
            dtype=self.dtype_policy,
            name="position_embedding",
        )
        self.supports_masking = self.token_embedding.supports_masking

    def build(self, input_shape):
        input_shape = tuple(input_shape)
        self.token_embedding.build(input_shape)
        self.position_embedding.build(input_shape + (self.embedding_dim,))
        self.built = True

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "sequence_length": self.sequence_length,
                "embedding_dim": self.embedding_dim,
                "embeddings_initializer": keras.initializers.serialize(
                    self.embeddings_initializer
                ),
                "tie_weights": self.token_embedding.tie_weights,
                "mask_zero": self.token_embedding.mask_zero,
            }
        )
        return config

    def call(self, inputs, start_index=0):
        embedded_tokens = self.token_embedding(inputs)
        embedded_positions = self.position_embedding(
            embedded_tokens,
            start_index=start_index,
        )
        outputs = embedded_tokens + embedded_positions
        return outputs

    def compute_mask(self, inputs, mask=None):
        return self.token_embedding.compute_mask(inputs, mask=mask)

    def compute_output_shape(self, input_shape):
        return tuple(input_shape) + (self.embedding_dim,)
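To see the layer in context, a hedged end-to-end sketch that feeds the summed embeddings into a `keras_hub.layers.TransformerEncoder` (exported by `transformer_encoder.py` in the file list above; the hyperparameters and the two-class head are arbitrary choices for illustration):

```python
import keras
import keras_hub

inputs = keras.Input(shape=(50,), dtype="int32")
x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=10_000,
    sequence_length=50,
    embedding_dim=128,
    mask_zero=True,  # treat token id 0 as padding
)(inputs)
x = keras_hub.layers.TransformerEncoder(num_heads=4, intermediate_dim=256)(x)
x = keras.layers.GlobalAveragePooling1D()(x)  # mask-aware pooling
outputs = keras.layers.Dense(2, activation="softmax")(x)
model = keras.Model(inputs, outputs)
model.summary()
```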