keras-hub-nightly 0.15.0.dev20240823171555 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
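All 297 entries above are additions-only, i.e. this diff is the initial release of the wheel, so it can be smoke-tested directly. A minimal check, assuming a standard pip setup and assuming the `version()` helper defined in keras_hub/src/version_utils.py (item 293 above) is exported at the top level, as it is in KerasNLP:

# pip install keras-hub-nightly==0.15.0.dev20240823171555
import keras_hub

print(keras_hub.version())  # expected: 0.15.0.dev20240823171555

The diff below reproduces the two DeBERTaV3 attention files (items 85 and 86 in the list above).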
keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py
@@ -0,0 +1,227 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+
+ from keras_hub.src.models.deberta_v3.disentangled_self_attention import (
+     DisentangledSelfAttention,
+ )
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (  # isort:skip
+     merge_padding_and_attention_mask,
+ )
+
+
+ class DisentangledAttentionEncoder(keras.layers.Layer):
+     """Disentangled attention encoder.
+
+     This class follows the architecture of the disentangled attention encoder
+     layer in the paper
+     ["DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing"](https://arxiv.org/abs/2111.09543).
+     Users can instantiate multiple instances of this class to stack up an
+     encoder model with disentangled self-attention.
+
+     `DisentangledAttentionEncoder` is similar to
+     `keras_hub.layers.TransformerEncoder`, except for the attention layer - it
+     uses disentangled self-attention instead of multi-head attention.
+
+     Args:
+         intermediate_dim: int, the hidden size of the feedforward network.
+         num_heads: int, the number of heads in the attention layer.
+         max_position_embeddings: int. The maximum input
+             sequence length. Defaults to `512`.
+         bucket_size: int. The size of the relative position
+             buckets. Generally equal to `max_sequence_length // 2`.
+             Defaults to `256`.
+         dropout: float. The dropout value, shared by
+             the attention layer and feedforward network.
+             Defaults to `0.0`.
+         activation: string or `keras.activations`. The
+             activation function of the feedforward network.
+             Defaults to `"relu"`.
+         layer_norm_epsilon: float. The epsilon value in layer
+             normalization components. Defaults to `1e-5`.
+         kernel_initializer: string or `keras.initializers` initializer.
+             The kernel initializer for the dense and disentangled
+             self-attention layers. Defaults to `"glorot_uniform"`.
+         bias_initializer: string or `keras.initializers` initializer.
+             The bias initializer for the dense and disentangled
+             self-attention layers. Defaults to `"zeros"`.
+     """
+
+     def __init__(
+         self,
+         intermediate_dim,
+         num_heads,
+         max_position_embeddings=512,
+         bucket_size=256,
+         dropout=0,
+         activation="relu",
+         layer_norm_epsilon=1e-05,
+         kernel_initializer="glorot_uniform",
+         bias_initializer="zeros",
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.intermediate_dim = intermediate_dim
+         self.num_heads = num_heads
+         self.max_position_embeddings = max_position_embeddings
+         self.bucket_size = bucket_size
+         self.dropout = dropout
+         self.activation = keras.activations.get(activation)
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+         self.bias_initializer = keras.initializers.get(bias_initializer)
+         self._built = False
+         self.supports_masking = True
+
+     def build(self, inputs_shape):
+         # Infer the dimension of our hidden feature size from the build shape.
+         hidden_dim = inputs_shape[-1]
+
+         # Self attention layers.
+         self._self_attention_layer = DisentangledSelfAttention(
+             num_heads=self.num_heads,
+             hidden_dim=hidden_dim,
+             max_position_embeddings=self.max_position_embeddings,
+             bucket_size=self.bucket_size,
+             dropout=self.dropout,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="self_attention_layer",
+         )
+         self._self_attention_layer.build(inputs_shape)
+         self._self_attention_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="self_attention_layer_norm",
+         )
+         self._self_attention_layer_norm.build(inputs_shape)
+         self._self_attention_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="self_attention_dropout",
+         )
+
+         # Feedforward layers.
+         self._feedforward_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="feedforward_layer_norm",
+         )
+         self._feedforward_layer_norm.build(inputs_shape)
+         self._feedforward_intermediate_dense = keras.layers.Dense(
+             self.intermediate_dim,
+             activation=self.activation,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="feedforward_intermediate_dense",
+         )
+         self._feedforward_intermediate_dense.build(inputs_shape)
+         self._feedforward_output_dense = keras.layers.Dense(
+             hidden_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="feedforward_output_dense",
+         )
+         intermediate_shape = list(inputs_shape)
+         intermediate_shape[-1] = self.intermediate_dim
+         self._feedforward_output_dense.build(tuple(intermediate_shape))
+         self._feedforward_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="feedforward_dropout",
+         )
+         self.built = True
+
+     def call(
+         self,
+         inputs,
+         rel_embeddings,
+         padding_mask=None,
+         attention_mask=None,
+     ):
+         """Forward pass of `DisentangledAttentionEncoder`.
+
+         Args:
+             inputs: a Tensor. The input data to `DisentangledAttentionEncoder`,
+                 should be of shape `[batch_size, sequence_length, hidden_dim]`.
+             rel_embeddings: a Tensor. The relative position embedding matrix,
+                 should be of shape `[batch_size, 2 * bucket_size, hidden_dim]`.
+             padding_mask: a boolean Tensor. It indicates if the token should be
+                 masked because the token is introduced due to padding.
+                 `padding_mask` should have shape `[batch_size, sequence_length]`.
+                 `False` means the token is masked out.
+             attention_mask: a boolean Tensor. Customized mask used to mask out
+                 certain tokens. `attention_mask` should have shape
+                 `[batch_size, sequence_length, sequence_length]`.
+
+         Returns:
+             A Tensor of the same shape as the `inputs`.
+         """
+         x = inputs
+
+         # Compute self attention mask.
+         self_attention_mask = merge_padding_and_attention_mask(
+             inputs, padding_mask, attention_mask
+         )
+
+         # Self attention block.
+         residual = x
+         x = self._self_attention_layer(
+             x,
+             rel_embeddings=rel_embeddings,
+             attention_mask=self_attention_mask,
+         )
+         x = self._self_attention_dropout(x)
+         x = x + residual
+         x = self._self_attention_layer_norm(x)
+
+         # Feedforward block.
+         residual = x
+         x = self._feedforward_intermediate_dense(x)
+         x = self._feedforward_output_dense(x)
+         x = self._feedforward_dropout(x)
+         x = x + residual
+         x = self._feedforward_layer_norm(x)
+
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "max_position_embeddings": self.max_position_embeddings,
+                 "bucket_size": self.bucket_size,
+                 "dropout": self.dropout,
+                 "activation": keras.activations.serialize(self.activation),
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self.kernel_initializer
+                 ),
+                 "bias_initializer": keras.initializers.serialize(
+                     self.bias_initializer
+                 ),
+             }
+         )
+         return config
+
+     def compute_output_shape(self, inputs_shape):
+         return inputs_shape
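To make the encoder's contract concrete, here is a minimal, hypothetical usage sketch. The dimensions are arbitrary, and the random `rel_embeddings` stand in for the output of the package's `RelativeEmbedding` layer (relative_embedding.py, item 87 above); per the docstring, it must have shape `[batch_size, 2 * bucket_size, hidden_dim]`:

import numpy as np

from keras_hub.src.models.deberta_v3.disentangled_attention_encoder import (
    DisentangledAttentionEncoder,
)

batch_size, seq_len, hidden_dim, bucket_size = 2, 128, 768, 256

x = np.random.uniform(size=(batch_size, seq_len, hidden_dim)).astype("float32")
# Stand-in for the relative position embeddings produced by RelativeEmbedding.
rel_embeddings = np.random.uniform(
    size=(batch_size, 2 * bucket_size, hidden_dim)
).astype("float32")
# Mark every token as real (no padding).
padding_mask = np.ones((batch_size, seq_len), dtype="int32")

# Stack two encoder layers, as the class docstring suggests.
for _ in range(2):
    x = DisentangledAttentionEncoder(
        intermediate_dim=3072,
        num_heads=12,
        bucket_size=bucket_size,
    )(x, rel_embeddings=rel_embeddings, padding_mask=padding_mask)

print(x.shape)  # Same shape as the input: (2, 128, 768)

The disentangled self-attention layer the encoder wraps is the second file in this diff.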
keras_hub/src/models/deberta_v3/disentangled_self_attention.py
@@ -0,0 +1,412 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ class DisentangledSelfAttention(keras.layers.Layer):
+     """DisentangledSelfAttention layer.
+
+     This is an implementation of disentangled self-attention as described in
+     the paper ["DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing"](https://arxiv.org/abs/2111.09543).
+     Effectively, this layer implements Multi-Head Self Attention with relative
+     attention, i.e., to get the final attention score, we compute the
+     content-to-position and position-to-content attention scores, and add
+     these scores to the vanilla multi-head self-attention scores.
+
+     Args:
+         num_heads: int. Number of attention heads.
+         hidden_dim: int. Hidden dimension of the input, i.e., `hidden_states`.
+         max_position_embeddings: int. The maximum input
+             sequence length. Defaults to `512`.
+         bucket_size: int. The size of the relative position
+             buckets. Generally equal to `max_sequence_length // 2`.
+             Defaults to `256`.
+         dropout: float. Dropout probability. Defaults to `0.1`.
+         kernel_initializer: string or `keras.initializers` initializer.
+             The kernel initializer for the dense layers.
+             Defaults to `"glorot_uniform"`.
+         bias_initializer: string or `keras.initializers` initializer.
+             The bias initializer for the dense layers.
+             Defaults to `"zeros"`.
+     """
+
+     def __init__(
+         self,
+         num_heads,
+         hidden_dim,
+         max_position_embeddings=512,
+         bucket_size=256,
+         dropout=0.1,
+         kernel_initializer="glorot_uniform",
+         bias_initializer="zeros",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # Passed args.
+         self.num_heads = num_heads
+         self.hidden_dim = hidden_dim
+         self.max_position_embeddings = max_position_embeddings
+         self.bucket_size = bucket_size
+         self.dropout = dropout
+
+         # Initializers.
+         self._kernel_initializer = keras.initializers.get(kernel_initializer)
+         self._bias_initializer = keras.initializers.get(bias_initializer)
+
+         # Derived args.
+         self.attn_head_size = hidden_dim // num_heads
+
+         # We have three types of attention - MHA, p2c and c2p.
+         num_type_attn = 3
+         self.scale_factor = 1.0 / math.sqrt(
+             float(num_type_attn * self.attn_head_size)
+         )
+
+     def build(self, inputs_shape, rel_embeddings_shape=None):
+         # Q, K, V linear layers.
+         self._query_dense = keras.layers.EinsumDense(
+             equation="abc,cde->abde",
+             output_shape=(None, self.num_heads, self.attn_head_size),
+             bias_axes="de",
+             **self._get_common_kwargs_for_sublayer(use_bias=True),
+             dtype=self.dtype_policy,
+             name="query",
+         )
+         self._query_dense.build(inputs_shape)
+         self._key_dense = keras.layers.EinsumDense(
+             equation="abc,cde->abde",
+             output_shape=(None, self.num_heads, self.attn_head_size),
+             bias_axes="de",
+             **self._get_common_kwargs_for_sublayer(use_bias=True),
+             dtype=self.dtype_policy,
+             name="key",
+         )
+         self._key_dense.build(inputs_shape)
+         self._value_dense = keras.layers.EinsumDense(
+             equation="abc,cde->abde",
+             output_shape=(None, self.num_heads, self.attn_head_size),
+             bias_axes="de",
+             **self._get_common_kwargs_for_sublayer(use_bias=True),
+             dtype=self.dtype_policy,
+             name="value",
+         )
+         self._value_dense.build(inputs_shape)
+
+         # Relative attention.
+         self._position_dropout_layer = keras.layers.Dropout(
+             self.dropout,
+             dtype=self.dtype_policy,
+         )
+
+         self._attn_dropout_layer = keras.layers.Dropout(
+             self.dropout,
+             dtype=self.dtype_policy,
+             name="attention_dropout",
+         )
+         self._softmax = keras.layers.Softmax(
+             axis=-1,
+             dtype="float32",
+             name="attention_softmax",
+         )
+
+         # Output.
+         self._output_dense = keras.layers.EinsumDense(
+             equation="abc,cd->abd",
+             output_shape=(None, self.hidden_dim),
+             bias_axes="d",
+             **self._get_common_kwargs_for_sublayer(use_bias=True),
+             dtype=self.dtype_policy,
+             name="attention_output",
+         )
+         self._output_dense.build(inputs_shape)
+         self.built = True
+
+     def _get_common_kwargs_for_sublayer(self, use_bias=True):
+         common_kwargs = {}
+
+         kernel_initializer = clone_initializer(self._kernel_initializer)
+         bias_initializer = clone_initializer(self._bias_initializer)
+
+         common_kwargs["kernel_initializer"] = kernel_initializer
+         if use_bias:
+             common_kwargs["bias_initializer"] = bias_initializer
+
+         return common_kwargs
+
+     def _masked_softmax(self, attention_scores, attention_mask=None):
+         """Normalizes the attention scores to probabilities using softmax.
+
+         This implementation is similar to the one present in
+         `keras.layers.MultiHeadAttention`.
+         """
+
+         if attention_mask is not None:
+             mask_expansion_axis = -3
+             for _ in range(
+                 len(attention_scores.shape) - len(attention_mask.shape)
+             ):
+                 attention_mask = ops.expand_dims(
+                     attention_mask, axis=mask_expansion_axis
+                 )
+         return self._softmax(attention_scores, attention_mask)
+
+     def _compute_attention(
+         self,
+         query,
+         key,
+         value,
+         rel_embeddings,
+         attention_mask=None,
+         training=None,
+     ):
+         """Computes the attention score and returns the attended outputs.
+
+         This function computes the vanilla MHA score, and the relative
+         attention scores (p2c and c2p). It then sums them up to get the
+         final attention score, which is used to compute the attended outputs.
+         """
+
+         attention_scores = ops.einsum(
+             "aecd,abcd->acbe",
+             key,
+             query,
+         )
+         attention_scores = ops.multiply(attention_scores, self.scale_factor)
+
+         rel_embeddings = self._position_dropout_layer(
+             rel_embeddings,
+             training=training,
+         )
+
+         rel_attn_scores = self._compute_disentangled_attention(
+             query=query,
+             key=key,
+             rel_embeddings=rel_embeddings,
+         )
+
+         if rel_attn_scores is not None:
+             attention_scores += rel_attn_scores
+
+         attention_scores = self._masked_softmax(
+             attention_scores, attention_mask
+         )
+         attention_scores = self._attn_dropout_layer(
+             attention_scores, training=training
+         )
+         attention_output = ops.einsum(
+             "acbe,aecd->abcd", attention_scores, value
+         )
+
+         return attention_output, attention_scores
+
+     def _make_log_bucket_position(self, rel_pos):
+         dtype = rel_pos.dtype
+         sign = ops.sign(rel_pos)
+         mid = self.bucket_size // 2
+         mid = ops.cast(mid, dtype=dtype)
+
+         # Give in-range positions a placeholder magnitude of `mid - 1` (they
+         # keep their exact value below); out-of-range positions keep their
+         # absolute value for the log compression.
+         abs_pos = ops.where(
+             condition=(rel_pos < mid) & (rel_pos > -mid),
+             x1=mid - 1,
+             x2=ops.abs(rel_pos),
+         )
+
+         def _get_log_pos(abs_pos, mid):
+             numerator = ops.log(abs_pos / mid)
+             numerator = numerator * ops.cast(mid - 1, dtype=numerator.dtype)
+             denominator = ops.log((self.max_position_embeddings - 1) / mid)
+             val = ops.ceil(numerator / denominator)
+             val = ops.cast(val, dtype=mid.dtype)
+             val = val + mid
+             return val
+
+         log_pos = _get_log_pos(abs_pos, mid)
+
+         bucket_pos = ops.where(
+             condition=abs_pos <= mid,
+             x1=rel_pos,
+             x2=log_pos * sign,
+         )
+         bucket_pos = ops.cast(bucket_pos, dtype="int")
+
+         return bucket_pos
+
+     def _get_rel_pos(self, num_positions):
+         ids = ops.arange(num_positions)
+         ids = ops.cast(ids, dtype="int")
+         query_ids = ops.expand_dims(ids, axis=-1)
+         key_ids = ops.expand_dims(ids, axis=0)
+         key_ids = ops.repeat(key_ids, repeats=num_positions, axis=0)
+
+         rel_pos = query_ids - key_ids
+         rel_pos = self._make_log_bucket_position(rel_pos)
+
+         rel_pos = ops.expand_dims(ops.expand_dims(rel_pos, axis=0), axis=0)
+         return rel_pos
+
+     def _compute_disentangled_attention(
+         self,
+         query,
+         key,
+         rel_embeddings,
+     ):
+         """Computes relative attention scores (p2c and c2p)."""
+
+         batch_size = ops.shape(query)[0]
+         num_positions = ops.shape(query)[1]
+
+         rel_pos = self._get_rel_pos(num_positions)
+
+         rel_attn_span = self.bucket_size
+         score = 0
+
+         pos_query = self._query_dense(rel_embeddings)
+         pos_key = self._key_dense(rel_embeddings)
+
+         # c2p
+         c2p_attn_scores = ops.einsum(
+             "aecd,abcd->acbe",
+             pos_key,
+             query,
+         )
+         c2p_pos = ops.clip(rel_pos + rel_attn_span, 0, rel_attn_span * 2 - 1)
+         c2p_pos = ops.broadcast_to(
+             c2p_pos,
+             shape=(
+                 batch_size,
+                 self.num_heads,
+                 num_positions,
+                 num_positions,
+             ),
+         )
+
+         if keras.config.backend() == "tensorflow":
+             # Work around dynamic shape bug on tensorflow backend.
+             import tensorflow as tf
+
+             c2p_attn_scores = tf.gather(
+                 c2p_attn_scores,
+                 indices=c2p_pos,
+                 batch_dims=3,
+             )
+         else:
+             c2p_attn_scores = ops.take_along_axis(
+                 c2p_attn_scores,
+                 indices=c2p_pos,
+                 axis=3,
+             )
+         c2p_attn_scores = ops.multiply(c2p_attn_scores, self.scale_factor)
+         score += c2p_attn_scores
+
+         # p2c
+         p2c_attn_scores = ops.einsum(
+             "aecd,abcd->acbe",
+             pos_query,
+             key,
+         )
+         p2c_pos = ops.clip(-rel_pos + rel_attn_span, 0, rel_attn_span * 2 - 1)
+         p2c_pos = ops.broadcast_to(
+             p2c_pos,
+             shape=(
+                 batch_size,
+                 self.num_heads,
+                 num_positions,
+                 num_positions,
+             ),
+         )
+         if keras.config.backend() == "tensorflow":
+             # Work around dynamic shape bug on tensorflow backend.
+             import tensorflow as tf
+
+             p2c_attn_scores = tf.gather(
+                 p2c_attn_scores,
+                 indices=p2c_pos,
+                 batch_dims=3,
+             )
+         else:
+             p2c_attn_scores = ops.take_along_axis(
+                 p2c_attn_scores,
+                 indices=p2c_pos,
+                 axis=3,
+             )
+         p2c_attn_scores = ops.transpose(p2c_attn_scores, [0, 1, 3, 2])
+         p2c_attn_scores = ops.multiply(p2c_attn_scores, self.scale_factor)
+         score += p2c_attn_scores
+
+         return score
+
+     def call(
+         self,
+         inputs,
+         rel_embeddings,
+         attention_mask=None,
+         return_attention_scores=False,
+         training=None,
+     ):
+         # `query`, `key`, `value` shape:
+         # `(batch_size, sequence_length, num_heads, attn_head_size)`.
+         query = self._query_dense(inputs)
+         key = self._key_dense(inputs)
+         value = self._value_dense(inputs)
+
+         attention_output, attention_scores = self._compute_attention(
+             query=query,
+             key=key,
+             value=value,
+             rel_embeddings=rel_embeddings,
+             attention_mask=attention_mask,
+             training=training,
+         )
+
+         # Reshape `attention_output` to
+         # `(batch_size, sequence_length, hidden_dim)`.
+         attention_output = ops.reshape(
+             attention_output,
+             [
+                 ops.shape(attention_output)[0],
+                 ops.shape(attention_output)[1],
+                 self.hidden_dim,
+             ],
+         )
+         attention_output = self._output_dense(attention_output)
+
+         if return_attention_scores:
+             return attention_output, attention_scores
+         return attention_output
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "max_position_embeddings": self.max_position_embeddings,
+                 "bucket_size": self.bucket_size,
+                 "dropout": self.dropout,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self._kernel_initializer
+                 ),
+                 "bias_initializer": keras.initializers.serialize(
+                     self._bias_initializer
+                 ),
+             }
+         )
+         return config
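Two details of this layer are worth spelling out. First, because the final score sums three attention types (content-to-content, c2p, and p2c), the scale factor is 1 / sqrt(3 * attn_head_size) rather than the usual 1 / sqrt(attn_head_size). Second, `_make_log_bucket_position` keeps relative distances up to about `bucket_size // 2` exact and compresses larger distances logarithmically, so nearby tokens are positioned precisely while distant tokens share coarse buckets. A standalone NumPy sketch of that mapping, written here purely for illustration (it mirrors `_make_log_bucket_position` above but is not part of the package):

import numpy as np

def log_bucket_position(rel_pos, bucket_size=256, max_position=512):
    mid = bucket_size // 2
    sign = np.sign(rel_pos)
    # In-range positions get a dummy magnitude of `mid - 1`; out-of-range
    # positions keep their true magnitude for the log compression below.
    abs_pos = np.where(
        (rel_pos < mid) & (rel_pos > -mid), mid - 1, np.abs(rel_pos)
    )
    log_pos = (
        np.ceil(
            np.log(abs_pos / mid) * (mid - 1) / np.log((max_position - 1) / mid)
        )
        + mid
    )
    # Exact bucket for small distances, signed log bucket for large ones.
    return np.where(abs_pos <= mid, rel_pos, log_pos * sign).astype("int32")

print(log_bucket_position(np.array([0, 5, 130, 200, 500])))
# -> [0, 5, 130, 169, 254]: small distances keep (nearly) exact buckets,
# while larger ones are squeezed into the remaining range.

The `clip(rel_pos + bucket_size, 0, 2 * bucket_size - 1)` steps in `_compute_disentangled_attention` then shift these signed buckets into valid indices over the `2 * bucket_size` relative embedding rows.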