keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/xlnet/relative_attention.py
@@ -0,0 +1,459 @@
+ # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ import string
+
+ import keras
+ from keras import ops
+
+ _CHR_IDX = string.ascii_lowercase
+
+
+ def _build_proj_equation(free_dims, bound_dims, output_dims):
+     """
+     Builds an einsum equation for projections inside multi-head attention.
+     """
+     input_str = ""
+     kernel_str = ""
+     output_str = ""
+     bias_axes = ""
+     letter_offset = 0
+     for i in range(free_dims):
+         char = _CHR_IDX[i + letter_offset]
+         input_str += char
+         output_str += char
+
+     letter_offset += free_dims
+     for i in range(bound_dims):
+         char = _CHR_IDX[i + letter_offset]
+         input_str += char
+         kernel_str += char
+
+     letter_offset += bound_dims
+     for i in range(output_dims):
+         char = _CHR_IDX[i + letter_offset]
+         kernel_str += char
+         output_str += char
+         bias_axes += char
+     equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
+
+     return equation, bias_axes, len(output_str)
+
+
+ def _get_output_shape(output_rank, known_last_dims):
+     return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
+
+
+ def _rel_shift(x, klen=-1):
+     """
+     Performs relative shift to form the relative attention score.
+     """
+
+     x = ops.transpose(x, [2, 3, 0, 1])
+     x_size = ops.shape(x)
+     x = ops.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
+     x = ops.slice(
+         x, [1, 0, 0, 0], [x_size[1] - 1, x_size[0], x_size[2], x_size[3]]
+     )
+     x = ops.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
+     x = ops.slice(x, [0, 0, 0, 0], [x_size[0], klen, x_size[2], x_size[3]])
+
+     x = ops.transpose(x, [2, 3, 0, 1])
+
+     return x
+
+
+ class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
+     """Two-stream relative self-attention for XLNet.
+
+     In XLNet, each token has two associated vectors at each self-attention
+     layer, the content stream (h) and the query stream (g). The content
+     stream is the self-attention stream as in Transformer XL and represents
+     the context and content (the token itself). The query stream only has
+     access to contextual information and the position, but not the content.
+
+     This layer shares the same build signature as
+     `keras.layers.MultiHeadAttention` but has different input/output
+     projections.
+
+     We use the notations `B`, `T`, `S`, `M`, `L`, `E`, `P`, `dim`,
+     `num_heads` below, where `B` is the batch dimension, `T` is the target
+     sequence length, `S` is the source sequence length, `M` is the length
+     of the state or memory, `L` is the length of relative positional
+     encoding, `E` is the last dimension of query input, `P` is the number
+     of predictions, `dim` is the dimensionality of the encoder layers, and
+     `num_heads` is the number of attention heads.
+
+     Args:
+         content_stream: `Tensor` of shape `[B, T, dim]`.
+         content_attention_bias: Bias `Tensor` for content based attention
+             of shape `[num_heads, dim]`.
+         positional_attention_bias: Bias `Tensor` for position based
+             attention of shape `[num_heads, dim]`.
+         query_stream: `Tensor` of shape `[B, P, dim]`.
+         target_mapping: `Tensor` of shape `[B, P, S]`.
+         relative_position_encoding: Relative positional encoding `Tensor`
+             of shape `[B, L, dim]`.
+         segment_matrix: Optional `Tensor` representing segmentation IDs
+             used in XLNet of shape `[B, S, S + M]`.
+         segment_encoding: Optional `Tensor` representing the segmentation
+             encoding as used in XLNet of shape `[2, num_heads, dim]`.
+         segment_attention_bias: Optional trainable bias parameter added to
+             the query head when calculating the segment-based attention
+             score used in XLNet of shape `[num_heads, dim]`.
+         state: Optional `Tensor` of shape `[B, M, E]`. If passed, this is
+             also attended over as in Transformer XL.
+         content_attention_mask: a boolean mask of shape `[B, T, S]` that
+             prevents attention to certain positions for content attention
+             computation.
+         query_attention_mask: a boolean mask of shape `[B, T, S]` that
+             prevents attention to certain positions for query attention
+             computation.
+     """
+
+     def __init__(self, kernel_initializer="glorot_uniform", **kwargs):
+         super().__init__(kernel_initializer=kernel_initializer, **kwargs)
+
+     def _get_common_kwargs_for_sublayer(self):
+         common_kwargs = dict(
+             kernel_initializer=self._kernel_initializer,
+             bias_initializer=self._bias_initializer,
+             kernel_regularizer=self._kernel_regularizer,
+             bias_regularizer=self._bias_regularizer,
+             activity_regularizer=self._activity_regularizer,
+             kernel_constraint=self._kernel_constraint,
+             bias_constraint=self._bias_constraint,
+         )
+         return common_kwargs
+
+     def build(self, content_stream_shape):
+         self._use_bias = False
+
+         self._query_shape = content_stream_shape
+         self._key_shape = content_stream_shape
+         self._value_shape = content_stream_shape
+
+         free_dims = len(self._query_shape) - 1
+         einsum_equation, bias_axes, output_rank = _build_proj_equation(
+             free_dims, bound_dims=1, output_dims=2
+         )
+         self._query_dense = keras.layers.EinsumDense(
+             einsum_equation,
+             output_shape=_get_output_shape(
+                 output_rank - 1, [self._num_heads, self._key_dim]
+             ),
+             bias_axes=bias_axes if self._use_bias else None,
+             dtype=self.dtype_policy,
+             name="query",
+             **self._get_common_kwargs_for_sublayer(),
+         )
+         self._query_dense.build(self._query_shape)
+
+         einsum_equation, bias_axes, output_rank = _build_proj_equation(
+             len(self._key_shape) - 1, bound_dims=1, output_dims=2
+         )
+         self._key_dense = keras.layers.EinsumDense(
+             einsum_equation,
+             output_shape=_get_output_shape(
+                 output_rank - 1, [self._num_heads, self._key_dim]
+             ),
+             bias_axes=bias_axes if self._use_bias else None,
+             dtype=self.dtype_policy,
+             name="key",
+             **self._get_common_kwargs_for_sublayer(),
+         )
+         self._key_dense.build(self._key_shape)
+
+         einsum_equation, bias_axes, output_rank = _build_proj_equation(
+             len(self._value_shape) - 1, bound_dims=1, output_dims=2
+         )
+         self._value_dense = keras.layers.EinsumDense(
+             einsum_equation,
+             output_shape=_get_output_shape(
+                 output_rank - 1, [self._num_heads, self._value_dim]
+             ),
+             bias_axes=bias_axes if self._use_bias else None,
+             dtype=self.dtype_policy,
+             name="value",
+             **self._get_common_kwargs_for_sublayer(),
+         )
+         self._value_dense.build(self._value_shape)
+
+         free_dims = len(self._query_shape) - 1
+         _, _, output_rank = _build_proj_equation(
+             free_dims, bound_dims=2, output_dims=1
+         )
+         self._output_dense = keras.layers.EinsumDense(
+             "ibnd,hnd->ibh",
+             output_shape=_get_output_shape(
+                 output_rank - 1, [self._query_shape[-1]]
+             ),
+             bias_axes=None,
+             dtype=self.dtype_policy,
+             name="attention_output",
+             **self._get_common_kwargs_for_sublayer(),
+         )
+         self._output_dense.build(
+             self._value_dense.compute_output_shape(self._value_shape)
+         )
+
+         einsum_equation, _, output_rank = _build_proj_equation(
+             len(self._key_shape) - 1, bound_dims=1, output_dims=2
+         )
+         self._encoding_dense = keras.layers.EinsumDense(
+             einsum_equation,
+             output_shape=_get_output_shape(
+                 output_rank - 1, [self._num_heads, self._key_dim]
+             ),
+             bias_axes=None,
+             dtype=self.dtype_policy,
+             name="encoding",
+             **self._get_common_kwargs_for_sublayer(),
+         )
+         self._encoding_dense.build(self._key_shape)
+
+         self._build_attention(output_rank)
+         self.built = True
+
+     def compute_attention(
+         self,
+         query,
+         key,
+         value,
+         position,
+         content_attention_bias,
+         positional_attention_bias,
+         segment_matrix=None,
+         segment_encoding=None,
+         segment_attention_bias=None,
+         attention_mask=None,
+     ):
+         """Computes the attention.
+
+         This function defines the computation inside `call` with projected
+         multihead Q, K, V, R inputs.
+
+         We use the notations `B`, `T`, `S`, `M`, `L`, `num_heads`, `key_dim`
+         below, where `B` is the batch dimension, `T` is the target sequence
+         length, `S` is the source sequence length, `M` is the length of the
+         state, `L` is the length of relative positional encoding,
+         `num_heads` is the number of attention heads and `key_dim` is the
+         size of each attention head for query and key.
+
+         Args:
+             query: Projected query `Tensor` of shape
+                 `[B, T, num_heads, key_dim]`.
+             key: Projected key `Tensor` of shape
+                 `[B, S + M, num_heads, key_dim]`.
+             value: Projected value `Tensor` of shape
+                 `[B, S + M, num_heads, key_dim]`.
+             position: Projected position `Tensor` of shape
+                 `[B, L, num_heads, key_dim]`.
+             content_attention_bias: Trainable bias parameter added to the
+                 query head when calculating the content-based attention
+                 score.
+             positional_attention_bias: Trainable bias parameter added to
+                 the query head when calculating the position-based
+                 attention score.
+             segment_matrix: Optional `Tensor` representing segmentation IDs
+                 used in XLNet.
+             segment_encoding: Optional trainable `Tensor` representing the
+                 segmentation encoding as used in XLNet.
+             segment_attention_bias: Optional trainable bias parameter added
+                 to the query head when calculating the segment-based
+                 attention score used in XLNet.
+             attention_mask: (default None) Optional mask that is added to
+                 attention logits. If state is not None, the mask source
+                 sequence dimension should extend to cover `M`.
+
+         Returns:
+             attention_output: Multi-headed output of attention computation
+                 of shape `[B, S, num_heads, key_dim]`.
+         """
+         content_attention = ops.einsum(
+             self._dot_product_equation, key, query + content_attention_bias
+         )
+         positional_attention = ops.einsum(
+             self._dot_product_equation,
+             position,
+             query + positional_attention_bias,
+         )
+         positional_attention = _rel_shift(
+             positional_attention, klen=ops.shape(content_attention)[3]
+         )
+
+         if segment_matrix is not None:
+             segment_attention = ops.einsum(
+                 "bind,snd->bnis",
+                 query + segment_attention_bias,
+                 segment_encoding,
+             )
+             target_shape = ops.shape(positional_attention)
+             segment_attention = ops.where(
+                 ops.broadcast_to(
+                     ops.expand_dims(segment_matrix, 1), target_shape
+                 ),
+                 ops.broadcast_to(segment_attention[:, :, :, 1:], target_shape),
+                 ops.broadcast_to(segment_attention[:, :, :, :1], target_shape),
+             )
+             attention_sum = (
+                 content_attention + positional_attention + segment_attention
+             )
+         else:
+             attention_sum = content_attention + positional_attention
+
+         attention_scores = ops.multiply(
+             attention_sum, 1.0 / math.sqrt(float(self._key_dim))
+         )
+
+         attention_scores = self._masked_softmax(
+             attention_scores, attention_mask
+         )
+
+         attention_output = self._dropout_layer(attention_scores)
+
+         attention_output = ops.einsum(
+             self._combine_equation, attention_output, value
+         )
+
+         return attention_output
+
+     def call(
+         self,
+         content_stream,
+         content_attention_bias,
+         positional_attention_bias,
+         relative_position_encoding,
+         query_stream=None,
+         target_mapping=None,
+         segment_matrix=None,
+         segment_encoding=None,
+         segment_attention_bias=None,
+         state=None,
+         content_attention_mask=None,
+         query_attention_mask=None,
+     ):
+         """Compute multi-head relative attention over inputs.
+
+         We use the notations `B`, `T`, `M`, `E` below, where `B` is the
+         batch dimension, `T` is the target sequence length, `M` is the
+         length of the state or memory and `E` is the last dimension of
+         query input.
+
+         Args:
+             content_stream: The content representation, commonly referred
+                 to as h. This serves a similar role to the standard hidden
+                 states in Transformer-XL.
+             content_attention_bias: A trainable bias parameter added to the
+                 query head when calculating the content-based attention
+                 score.
+             positional_attention_bias: A trainable bias parameter added to
+                 the query head when calculating the position-based
+                 attention score.
+             query_stream: The query representation, commonly referred to as
+                 g. This only has access to contextual information and
+                 position, but not content. If not provided, this falls back
+                 to multi-head relative self-attention over the content
+                 stream.
+             relative_position_encoding: relative positional encoding for
+                 key and value.
+             target_mapping: Optional `Tensor` representing the target
+                 mapping used in partial prediction.
+             segment_matrix: Optional `Tensor` representing segmentation IDs
+                 used in XLNet.
+             segment_encoding: Optional `Tensor` representing the
+                 segmentation encoding as used in XLNet.
+             segment_attention_bias: Optional trainable bias parameter added
+                 to the query head when calculating the segment-based
+                 attention score.
+             state: (default None) optional state. If passed, this is also
+                 attended over as in TransformerXL and XLNet.
+             content_attention_mask: (default None) Optional mask that is
+                 added to content attention logits. If state is not None,
+                 the mask source sequence dimension should extend to cover
+                 `M`.
+             query_attention_mask: (default None) Optional mask that is
+                 added to query attention logits. If state is not None, the
+                 mask source sequence dimension should extend to cover `M`.
+
+         Returns:
+             content_attention_output, query_attention_output: the results
+                 of the computation, both of shape `[B, T, E]`.
+         """
+         if state is not None and len(state.shape) > 1:
+             content_and_memory_stream = ops.concatenate(
+                 [state, content_stream], 1
+             )
+         else:
+             content_and_memory_stream = content_stream
+
+         # `query` = [B, T, N, H]
+         query = self._query_dense(content_stream)
+
+         # `key` = [B, S + M, N, H]
+         key = self._key_dense(content_and_memory_stream)
+
+         # `value` = [B, S + M, N, H]
+         value = self._value_dense(content_and_memory_stream)
+
+         # `position` = [B, L, N, H]
+         position = self._encoding_dense(relative_position_encoding)
+
+         content_attention_output = self.compute_attention(
+             query=query,
+             key=key,
+             value=value,
+             position=position,
+             content_attention_bias=content_attention_bias,
+             positional_attention_bias=positional_attention_bias,
+             segment_matrix=segment_matrix,
+             segment_encoding=segment_encoding,
+             segment_attention_bias=segment_attention_bias,
+             attention_mask=content_attention_mask,
+         )
+
+         # `content_attention_output` = [B, S, N, H]
+         content_attention_output = self._output_dense(content_attention_output)
+
+         query_attention_output = None
+         if query_stream is not None:
+             query = self._query_dense(query_stream)
+             if target_mapping is not None:
+                 query = ops.einsum("bmnd,bml->blnd", query, target_mapping)
+                 query_attention_output = self.compute_attention(
+                     query=query,
+                     key=key,
+                     value=value,
+                     position=position,
+                     content_attention_bias=content_attention_bias,
+                     positional_attention_bias=positional_attention_bias,
+                     segment_matrix=segment_matrix,
+                     segment_encoding=segment_encoding,
+                     segment_attention_bias=segment_attention_bias,
+                     attention_mask=query_attention_mask,
+                 )
+                 query_attention_output = ops.einsum(
+                     "blnd,bml->bmnd", query_attention_output, target_mapping
+                 )
+             else:
+                 query_attention_output = self.compute_attention(
+                     query=query,
+                     key=key,
+                     value=value,
+                     position=position,
+                     content_attention_bias=content_attention_bias,
+                     positional_attention_bias=positional_attention_bias,
+                     segment_matrix=segment_matrix,
+                     segment_encoding=segment_encoding,
+                     segment_attention_bias=segment_attention_bias,
+                     attention_mask=query_attention_mask,
+                 )
+             query_attention_output = self._output_dense(query_attention_output)
+
+         return content_attention_output, query_attention_output
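
To make the two module-private helpers above concrete, here is a minimal sketch (not part of the package; it assumes `_build_proj_equation` and `_rel_shift` from the hunk above are in scope) showing the einsum equation generated for the Q/K/V projections and the shape effect of the relative shift:

```python
from keras import ops

# For a [batch, seq, hidden] input (free_dims=2), contracting the hidden
# axis (bound_dims=1) and producing [num_heads, key_dim] (output_dims=2):
equation, bias_axes, output_rank = _build_proj_equation(
    free_dims=2, bound_dims=1, output_dims=2
)
print(equation)     # "abc,cde->abde", i.e. [B, T, E] -> [B, T, N, H]
print(bias_axes)    # "de", the bias axes if bias were enabled
print(output_rank)  # 4

# `_rel_shift` realigns the relative-position axis of a [B, N, T, L]
# score tensor and truncates it to `klen` key positions.
scores = ops.reshape(ops.arange(8, dtype="float32"), (1, 1, 2, 4))
print(ops.shape(_rel_shift(scores, klen=2)))  # (1, 1, 2, 2)
```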
keras_hub/src/models/xlnet/xlnet_backbone.py
@@ -0,0 +1,222 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.xlnet.xlnet_content_and_query_embedding import (
+     ContentAndQueryEmbedding,
+ )
+ from keras_hub.src.models.xlnet.xlnet_encoder import XLNetAttentionMaskLayer
+ from keras_hub.src.models.xlnet.xlnet_encoder import XLNetEncoder
+ from keras_hub.src.models.xlnet.xlnet_encoder import XLNetSegmentMatrixLayer
+
+
+ @keras_hub_export("keras_hub.models.XLNetBackbone")
+ class XLNetBackbone(Backbone):
+     """XLNet encoder network.
+
+     This class implements an XLNet Transformer.
+
+     The default constructor gives a fully customizable, randomly
+     initialized XLNet encoder with any number of layers, heads, and
+     embedding dimensions. To load preset architectures and weights, use
+     the `from_preset` constructor.
+
+     Disclaimer: Pre-trained models are provided on an "as is" basis,
+     without warranties or conditions of any kind.
+
+     Attributes:
+         vocabulary_size: int. The size of the token vocabulary.
+         num_layers: int. The number of transformer encoder layers.
+         num_heads: int. The number of heads in the
+             `TwoStreamRelativeAttention` layers.
+         hidden_dim: int. The size of the hidden states.
+         intermediate_dim: int. The hidden size of the feedforward network.
+         dropout: float, defaults to 0.0. The dropout value, shared by
+             `TwoStreamRelativeAttention` and the feedforward network.
+         activation: string or `keras.activations`, defaults to "gelu".
+             The activation function of the feedforward network.
+         kernel_initializer_range: float, defaults to 0.02. The kernel
+             initializer range for the dense and relative attention layers.
+         bias_initializer: string or `keras.initializers` initializer,
+             defaults to "zeros". The bias initializer for the dense and
+             multi-headed relative attention layers.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to
+             use for model computations and weights. Note that some
+             computations, such as softmax and layer normalization, will
+             always be done at float32 precision regardless of dtype.
+
+     Call arguments:
+         token_ids: Indices of input sequence tokens in the vocabulary of
+             shape `[batch_size, sequence_length]`.
+         segment_ids: Segment token indices to indicate first and second
+             portions of the inputs of shape
+             `[batch_size, sequence_length]`.
+         padding_mask: Mask to avoid performing attention on padding token
+             indices of shape `[batch_size, sequence_length]`.
+
+     Example:
+     ```python
+     import numpy as np
+     import keras_hub
+
+     input_data = {
+         "token_ids": np.array([[460, 5272, 1758, 4905, 9, 4, 3]]),
+         "segment_ids": np.array([[0, 0, 0, 0, 0, 0, 2]]),
+         "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1]]),
+     }
+
+     # Randomly initialized XLNet encoder with a custom config.
+     model = keras_hub.models.XLNetBackbone(
+         vocabulary_size=32000,
+         num_layers=12,
+         num_heads=12,
+         hidden_dim=768,
+         intermediate_dim=3072,
+     )
+     output = model(input_data)
+     ```
+     """
+
+     def __init__(
+         self,
+         vocabulary_size,
+         num_layers,
+         num_heads,
+         hidden_dim,
+         intermediate_dim,
+         dropout=0.0,
+         activation="gelu",
+         kernel_initializer_range=0.02,
+         bias_initializer="zeros",
+         dtype=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.content_query_embedding = ContentAndQueryEmbedding(
+             vocabulary_size=vocabulary_size,
+             hidden_dim=hidden_dim,
+             dropout=dropout,
+             dtype=dtype,
+             name="content_query_embedding",
+         )
+         self.attn_mask_layer = XLNetAttentionMaskLayer(
+             hidden_dim=hidden_dim,
+             kernel_initializer_range=kernel_initializer_range,
+             dtype=dtype,
+             name="encoder_block_attn_mask_layer",
+         )
+         self.seg_mat_layer = XLNetSegmentMatrixLayer(
+             dtype=dtype,
+             name="encoder_block_seg_mat_layer",
+         )
+         head_dim = hidden_dim // num_heads
+         self.transformer_layers = []
+         for i in range(num_layers):
+             layer = XLNetEncoder(
+                 num_heads=num_heads,
+                 hidden_dim=hidden_dim,
+                 head_dim=head_dim,
+                 intermediate_dim=intermediate_dim,
+                 dropout=dropout,
+                 activation=activation,
+                 layer_norm_epsilon=1e-12,
+                 kernel_initializer_range=kernel_initializer_range,
+                 bias_initializer=bias_initializer,
+                 dtype=dtype,
+                 name=f"xlnet_encoder_{i}",
+             )
+             self.transformer_layers.append(layer)
+         self.dropout = keras.layers.Dropout(
+             dropout,
+             dtype=dtype,
+             name="dropout",
+         )
+
+         # === Functional Model ===
+         token_id_input = keras.Input(
+             shape=(None,), dtype="int32", name="token_ids"
+         )
+         padding_mask_input = keras.Input(
+             shape=(None,), dtype="int32", name="padding_mask"
+         )
+         segment_id_input = keras.Input(
+             shape=(None,), dtype="int32", name="segment_ids"
+         )
+         # Content and Query Embedding
+         word_emb, pos_emb = self.content_query_embedding(token_id_input)
+         # Apply XLNetAttentionMaskLayer and XLNetSegmentMatrixLayer
+         # to get the processed attention masks and segment matrix.
+         attn_mask_content, attn_mask_query = self.attn_mask_layer(
+             padding_mask_input
+         )
+         seg_mat = self.seg_mat_layer(segment_id_input)
+         output_content = word_emb
+         for transformer_layer in self.transformer_layers:
+             output_content, output_query = transformer_layer(
+                 output_content=output_content,
+                 attn_mask_content=attn_mask_content,
+                 attn_mask_query=attn_mask_query,
+                 pos_emb=pos_emb,
+                 seg_mat=seg_mat,
+             )
+         output = self.dropout(output_content)
+         super().__init__(
+             inputs={
+                 "token_ids": token_id_input,
+                 "padding_mask": padding_mask_input,
+                 "segment_ids": segment_id_input,
+             },
+             outputs=output,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.vocabulary_size = vocabulary_size
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.dropout = dropout
+         self.activation = activation
+         self.kernel_initializer_range = kernel_initializer_range
+         self.bias_initializer = bias_initializer
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "num_layers": self.num_layers,
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "dropout": self.dropout,
+                 "activation": self.activation,
+                 "kernel_initializer_range": self.kernel_initializer_range,
+                 "bias_initializer": self.bias_initializer,
+             }
+         )
+         return config
+
+     @property
+     def token_embedding(self):
+         return self.get_layer("content_query_embedding").word_embed
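
As a quick sanity check of the backbone's interface, here is a hedged usage sketch (assuming the nightly wheel above is installed; the tiny sizes are arbitrary and chosen only to keep the forward pass cheap):

```python
import numpy as np
import keras_hub

# Small, randomly initialized backbone.
backbone = keras_hub.models.XLNetBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
)
batch = {
    "token_ids": np.random.randint(0, 1000, size=(2, 7)).astype("int32"),
    "segment_ids": np.zeros((2, 7), dtype="int32"),
    "padding_mask": np.ones((2, 7), dtype="int32"),
}
# Output is the final content stream: [batch_size, sequence_length, hidden_dim].
hidden_states = backbone(batch)
print(hidden_states.shape)  # (2, 7, 32)
```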