keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,133 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+
19
class ContentAndQueryEmbedding(keras.layers.Layer):
    """Content and Query Embedding.

    Creates the word (content) embeddings and the sinusoidal relative
    positional encodings for the XLNet model, which are later used by the
    XLNet encoder.

    Args:
        vocabulary_size: int, number of tokens in the vocabulary.
        hidden_dim: int, the size of the hidden states.
        dropout: float, the dropout rate applied to both the word embeddings
            and the positional encodings.
        name: string, defaults to None. The name of the layer.
        **kwargs: other keyword arguments, forwarded to
            `keras.layers.Layer`.

    References:
     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
     (https://arxiv.org/abs/1906.08237)
    """

    def __init__(
        self, vocabulary_size, hidden_dim, dropout, name=None, **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.dropout = dropout

    def positional_embedding(self, pos_seq, inv_freq, bsz=None):
        """Build sinusoidal encodings for `pos_seq`.

        Args:
            pos_seq: 1D tensor of (relative) positions.
            inv_freq: 1D tensor of inverse frequencies.
            bsz: optional batch size. When given, the encoding is tiled
                along axis 1 so the result has `bsz` entries on that axis.

        Returns:
            A tensor with `len(pos_seq)` entries on axis 0, `bsz` (or 1 when
            `bsz` is None) on axis 1 and `2 * len(inv_freq)` on axis 2.
        """
        # Outer product: one row of angles per position.
        sinusoid_inp = ops.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = ops.concatenate(
            [ops.sin(sinusoid_inp), ops.cos(sinusoid_inp)], axis=-1
        )
        pos_emb = ops.expand_dims(pos_emb, 1)

        # `bsz` is optional in the signature; without it there is nothing to
        # tile (previously this path failed multiplying a shape by None).
        if bsz is None:
            return pos_emb

        # Tile along the batch axis by broadcasting against a ones tensor.
        pos_emb = (
            ops.ones(
                [
                    ops.shape(pos_emb)[0],
                    ops.shape(pos_emb)[1] * bsz,
                    ops.shape(pos_emb)[2],
                ],
                dtype=self.compute_dtype,
            )
            * pos_emb
        )

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None, clamp_len=-1):
        """Create the relative positional encoding.

        Positions run from `klen` down to `-qlen + 1` (step -1), i.e.
        `klen + qlen` entries. When `clamp_len > 0` they are clipped to
        `[-clamp_len, clamp_len]` before being encoded.
        """
        freq_seq = ops.arange(0, self.hidden_dim, 2.0, dtype="float32")
        freq_seq = ops.cast(freq_seq, self.compute_dtype)
        inv_freq = 1 / (10000 ** (freq_seq / self.hidden_dim))

        beg, end = klen, -qlen

        fwd_pos_seq = ops.arange(beg, end, -1.0, dtype="float32")
        fwd_pos_seq = ops.cast(fwd_pos_seq, self.compute_dtype)
        if clamp_len > 0:
            fwd_pos_seq = ops.clip(
                fwd_pos_seq, x_min=-clamp_len, x_max=clamp_len
            )
        pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        return pos_emb

    def build(self, input_shape):
        self.word_embed = keras.layers.Embedding(
            input_dim=self.vocabulary_size,
            output_dim=self.hidden_dim,
            dtype=self.dtype_policy,
            name="word_embedding",
        )
        self.word_embed.build(input_shape)
        # A single dropout layer is shared by the word and positional paths.
        self.dropout_layer = keras.layers.Dropout(
            self.dropout,
            dtype=self.dtype_policy,
        )
        super().build(input_shape)

    def call(
        self,
        token_id_input,
        mlen=None,
    ):
        """Embed token ids and build the matching positional encodings.

        Args:
            token_id_input: tensor of token ids; axis 0 is read as the batch
                size and axis 1 as the query length.
            mlen: optional int, memory length added to the key length.
                `None` is treated as 0.

        Returns:
            A `(word_emb, pos_emb)` tuple, `pos_emb` being batch-major.
        """
        mlen = 0 if mlen is None else mlen

        bsz, qlen = ops.shape(token_id_input)[0], ops.shape(token_id_input)[1]
        klen = mlen + qlen

        # Word embeddings and prepare h & g hidden states
        word_emb = self.word_embed(token_id_input)
        word_emb = self.dropout_layer(word_emb)

        # Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout_layer(pos_emb)
        # Swap the (length, batch) axes. A transpose is required here: a
        # plain reshape keeps the flat element order and therefore
        # interleaves positions across batch rows whenever bsz > 1.
        pos_emb = ops.transpose(pos_emb, (1, 0, 2))

        return word_emb, pos_emb

    def compute_output_shape(self, token_id_input_shape):
        # tuple() so that list-typed shapes can be concatenated as well.
        token_id_input_shape = tuple(token_id_input_shape)
        # NOTE(review): `call` produces `klen + qlen` entries on axis 1 of
        # the positional encoding, while the shape declared here uses 1.
        # Kept as-is to avoid changing downstream symbolic shapes - confirm.
        return [
            token_id_input_shape + (self.hidden_dim,),
            (token_id_input_shape[0], 1, self.hidden_dim),
        ]
@@ -0,0 +1,378 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+ from keras_hub.src.models.xlnet.relative_attention import (
19
+ TwoStreamRelativeAttention,
20
+ )
21
+
22
+
23
def xlnet_kernel_initializer(stddev=0.02):
    """Return a truncated-normal initializer with the given `stddev`."""
    initializer = keras.initializers.TruncatedNormal(stddev=stddev)
    return initializer
25
+
26
+
27
class XLNetEncoder(keras.layers.Layer):
    """XLNet Encoder.

    This class follows the architecture of the transformer encoder layer in the
    paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
    can instantiate multiple instances of this class to stack up an encoder.

    Contrary to the single hidden state used in the paper mentioned above, this
    encoder uses two hidden states, the content state and the query state, and
    computes two-stream relative attention over both of them. To know more,
    please check the reference.

    Args:
        num_heads: int, the number of heads in the
            `TwoStreamRelativeAttention` layer.
        hidden_dim: int, the size of the hidden states.
        head_dim: int, the size of each attention head.
        intermediate_dim: int, the hidden size of the feedforward network.
        dropout: float, defaults to 0.0. The dropout value, used after the
            attention layer and inside the feedforward network.
        activation: string or `keras.activations`, defaults to "gelu". The
            activation function of the feedforward network.
        layer_norm_epsilon: float, defaults to 1e-12. The epsilon value in
            layer normalization components.
        kernel_initializer_range: float, defaults to 0.02. The standard
            deviation of the truncated-normal kernel initializer used for
            the dense and relative attention layers.
        bias_initializer: string or `keras.initializers` initializer,
            defaults to "zeros". The bias initializer for the relative
            attention layer and the attention bias weights.
        name: string, defaults to None. The name of the layer.
        **kwargs: other keyword arguments.

    References:
     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
     (https://arxiv.org/abs/1906.08237)
    """

    def __init__(
        self,
        num_heads,
        hidden_dim,
        head_dim,
        intermediate_dim,
        dropout=0.0,
        activation="gelu",
        layer_norm_epsilon=1e-12,
        kernel_initializer_range=0.02,
        bias_initializer="zeros",
        name=None,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = head_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer_range = kernel_initializer_range
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_initializer = xlnet_kernel_initializer(
            self.kernel_initializer_range
        )

    def build(self, input_shape):
        # Attention Part
        self.relative_attention = TwoStreamRelativeAttention(
            num_heads=self.num_heads,
            key_dim=self.head_dim,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            dtype=self.dtype_policy,
            name="rel_attn",
        )
        self.relative_attention.build(input_shape)

        self.layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm_rel_attn",
        )
        self.layer_norm.build(input_shape)

        self.dropout_attn = keras.layers.Dropout(
            self.dropout,
            dtype=self.dtype_policy,
        )

        # Feed-Forward Part
        self.layer_norm_ff = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm_ff",
        )
        self.layer_norm_ff.build(input_shape)

        self.feedforward_intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="feedforward_intermediate_dense",
        )
        self.feedforward_intermediate_dense.build(input_shape)

        self.feedforward_output_dense = keras.layers.Dense(
            self.hidden_dim,
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="feedforward_output_dense",
        )
        # The output projection consumes the intermediate projection's output.
        self.feedforward_output_dense.build(
            self.feedforward_intermediate_dense.compute_output_shape(
                input_shape
            )
        )

        self.dropout_ff = keras.layers.Dropout(
            self.dropout,
            dtype=self.dtype_policy,
        )

        self.activation_function_ff = keras.activations.get(self.activation)

        # Per-head bias weights handed to the relative attention layer.
        self.content_attention_bias = self.add_weight(
            shape=(self.num_heads, self.head_dim),
            initializer=self.bias_initializer,
            trainable=True,
            name="content_attention_bias",
        )

        self.positional_attention_bias = self.add_weight(
            shape=(self.num_heads, self.head_dim),
            initializer=self.bias_initializer,
            trainable=True,
            name="positional_attention_bias",
        )

        self.segment_attention_bias = self.add_weight(
            shape=(self.num_heads, self.head_dim),
            initializer=self.bias_initializer,
            trainable=True,
            name="segment_attention_bias",
        )

        self.segment_encoding = self.add_weight(
            shape=(2, self.num_heads, self.head_dim),
            initializer=self.kernel_initializer,
            trainable=True,
            name="segment_encoding",
        )

        super().build(input_shape)

    def _attention_residual(self, attention_output, stream):
        """Apply dropout, the residual connection and layer norm."""
        attention_output = self.dropout_attn(attention_output)
        return self.layer_norm(attention_output + stream)

    def _feedforward(self, stream):
        """Run the feedforward block with its residual and layer norm."""
        x = self.feedforward_intermediate_dense(stream)
        x = self.activation_function_ff(x)
        x = self.dropout_ff(x)
        x = self.feedforward_output_dense(x)
        x = self.dropout_ff(x)
        return self.layer_norm_ff(x + stream)

    def call(
        self,
        output_content,
        attn_mask_content,
        attn_mask_query,
        pos_emb,
        seg_mat,
        output_query=None,
        mems=None,
        target_mapping=None,
    ):
        """Run one encoder block over the content (and query) stream.

        Args:
            output_content: content stream hidden states.
            attn_mask_content: attention mask for the content stream.
            attn_mask_query: attention mask for the query stream.
            pos_emb: relative position encoding.
            seg_mat: segment matrix.
            output_query: optional query stream hidden states.
            mems: optional state forwarded to the attention layer as
                `state`.
            target_mapping: optional target mapping forwarded to the
                attention layer.

        Returns:
            A `(content, query)` tuple. `query` is `None` when the attention
            layer returns no query stream output.
        """
        # rel_attn
        attn_out_content, attn_out_query = self.relative_attention(
            content_stream=output_content,
            query_stream=output_query,
            content_attention_mask=attn_mask_content,
            query_attention_mask=attn_mask_query,
            relative_position_encoding=pos_emb,
            content_attention_bias=self.content_attention_bias,
            positional_attention_bias=self.positional_attention_bias,
            segment_attention_bias=self.segment_attention_bias,
            segment_matrix=seg_mat,
            segment_encoding=self.segment_encoding,
            target_mapping=target_mapping,
            state=mems,
        )

        # The sub-layer call order (content before query, attention before
        # feedforward) is kept so dropout draws happen in the same sequence.
        attn_out_content = self._attention_residual(
            attn_out_content, output_content
        )
        if attn_out_query is not None:
            attn_out_query = self._attention_residual(
                attn_out_query, output_query
            )

        # feed-forward
        ff_out_content = self._feedforward(attn_out_content)
        if attn_out_query is None:
            return ff_out_content, None

        ff_out_query = self._feedforward(attn_out_query)
        return ff_out_content, ff_out_query

    def compute_output_shape(
        self,
        output_content_shape,
        pos_emb_shape,
        attn_mask_content_shape,
        attn_mask_query_shape,
        seg_mat_shape,
        output_query_shape=None,
    ):
        # Both streams are reported with the content stream's shape.
        return [output_content_shape, output_content_shape]
250
+
251
+
252
class XLNetAttentionMaskLayer(keras.layers.Layer):
    """
    Attention Mask Layer for XLNet Encoder Block.

    This layer turns a padding mask into the two attention masks used during
    the forward pass: one for the content state and one for the query state.

    Args:
        hidden_dim: int, the size of the hidden states.
        kernel_initializer_range: float, the range handed to
            `xlnet_kernel_initializer` to build the initializer of the
            `mask_emb` weight. No default is provided by this layer.
        **kwargs: other keyword arguments forwarded to `keras.layers.Layer`.
    """

    def __init__(self, hidden_dim, kernel_initializer_range, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.kernel_initializer_range = kernel_initializer_range
        self.kernel_initializer = xlnet_kernel_initializer(
            self.kernel_initializer_range
        )

    def build(self, inputs_shape):
        """Create the `mask_emb` weight of shape `(1, 1, hidden_dim)`."""
        # NOTE(review): `mask_emb` is not read by `call` in this class; it is
        # presumably kept for checkpoint compatibility -- confirm before
        # removing.
        self.mask_emb = self.add_weight(
            shape=(1, 1, self.hidden_dim),
            initializer=self.kernel_initializer,
            trainable=True,
            name="mask_emb",
        )
        self.built = True

    def call(self, inputs, mlen=None):
        """Build the content and query attention masks.

        Args:
            inputs: 2D padding mask, indexed as `[bsz, qlen]`. Presumably
                `1` marks a real token and `0` marks padding -- confirm
                against the caller.
            mlen: optional int, number of memory positions prepended to the
                key axis. `None` is treated as `0`.

        Returns:
            A tuple `(attn_mask_content, attn_mask_query)` of float32
            tensors shaped `[bsz, qlen, klen]` and `[bsz, 1, klen]`, where
            `klen = mlen + qlen`.
        """
        bsz, qlen = ops.shape(inputs)[0], ops.shape(inputs)[1]
        mlen = 0 if mlen is None else mlen

        # Invert the mask and move it to `[qlen, bsz]`. This has to be a
        # transpose: a reshape keeps the flat element order and would mix
        # mask entries of different batch rows whenever `bsz > 1`.
        inputs = ops.transpose(1 - inputs, [1, 0])

        data_mask = ops.expand_dims(inputs, 0)

        if mlen > 0:
            # Memory positions are never masked out.
            mems_mask = ops.zeros([ops.shape(data_mask)[0], mlen, bsz])
            data_mask = ops.concatenate(
                [ops.cast(mems_mask, dtype="int32"), data_mask], axis=1
            )
        attn_mask_query = ops.expand_dims(data_mask, -1)

        attn_mask_query = ops.cast(
            attn_mask_query > 0, dtype=attn_mask_query.dtype
        )

        # Since ops.eye doesn't support tensorflow Tensor as input.
        # we need to create custom function here.
        n = ops.expand_dims(ops.arange(qlen), -1)
        m = ops.arange(qlen)
        # Negative identity: subtracting it below lets every position attend
        # to itself in the content stream even when it is masked.
        attn_mask_content = -ops.cast(
            ops.where(n == m, 1, 0), attn_mask_query.dtype
        )

        if mlen > 0:
            attn_mask_content = ops.concatenate(
                [
                    ops.zeros([qlen, mlen], dtype=attn_mask_content.dtype),
                    attn_mask_content,
                ],
                axis=-1,
            )

        attn_mask_content = ops.cast(
            (
                attn_mask_query
                + ops.expand_dims(ops.expand_dims(attn_mask_content, -1), -1)
            )
            > 0,
            dtype=attn_mask_content.dtype,
        )

        # to make sure inputs suitable for TwoStreamRelativeAttention
        attn_mask_content = 1.0 - ops.cast(
            ops.transpose(ops.squeeze(attn_mask_content, -1), [2, 0, 1]),
            "float32",
        )
        attn_mask_query = 1.0 - ops.cast(
            ops.transpose(ops.squeeze(attn_mask_query, -1), [2, 0, 1]),
            "float32",
        )

        return attn_mask_content, attn_mask_query

    def compute_output_shape(self, padding_mask_shape):
        # NOTE(review): `call` returns rank-3 masks (`[bsz, qlen, klen]` and
        # `[bsz, 1, klen]`), but `klen` depends on the `mlen` call argument,
        # which is not available here. The historical rank-2 result is kept
        # to avoid changing symbolic shapes -- confirm whether callers rely
        # on it before tightening.
        return [padding_mask_shape, padding_mask_shape]
346
+
347
+
348
class XLNetSegmentMatrixLayer(keras.layers.Layer):
    """
    Builds the segment matrix consumed by the XLNet encoder.

    For every (query position, key position) pair, the matrix tells whether
    the two positions carry different segment ids.
    """

    def call(self, segment_ids, mlen=None):
        """Compute the boolean segment matrix.

        Args:
            segment_ids: 2D tensor of segment ids, indexed as `[bsz, qlen]`.
            mlen: optional int, number of memory positions prepended to the
                key axis with segment id `0`. `None` is treated as `0`.

        Returns:
            A bool tensor shaped `[bsz, qlen, klen]` with `klen = mlen + qlen`;
            `True` means the two positions are not in the same segment.
        """
        batch_size = ops.shape(segment_ids)[0]
        memory_length = mlen if mlen is not None else 0

        # Work in `[qlen, bsz]` layout while building the matrix.
        query_ids = ops.transpose(segment_ids, [1, 0])

        key_ids = query_ids
        if memory_length > 0:
            memory_ids = ops.zeros(
                [memory_length, batch_size], dtype=query_ids.dtype
            )
            key_ids = ops.concatenate([memory_ids, query_ids], 0)

        # `1` indicates not in the same segment [qlen x klen x bsz]
        same_segment = ops.equal(
            ops.expand_dims(query_ids, 1), ops.expand_dims(key_ids, 0)
        )
        differs = ops.cast(
            ops.logical_not(same_segment), dtype=query_ids.dtype
        )

        # Batch-first boolean layout, as TwoStreamRelativeAttention expects.
        return ops.cast(ops.transpose(differs, [2, 0, 1]), dtype="bool")

    def compute_output_shape(self, segment_ids_shape):
        # NOTE(review): `call` returns a rank-3 tensor, whereas this reports
        # the rank-2 input shape; kept as-is to preserve behaviour.
        return segment_ids_shape
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.