keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,261 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.layers.modeling.reversible_embedding import (
19
+ ReversibleEmbedding,
20
+ )
21
+ from keras_hub.src.models.backbone import Backbone
22
+ from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
23
+ from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
24
+
25
+
26
@keras_hub_export("keras_hub.models.T5Backbone")
class T5Backbone(Backbone):
    """T5 encoder-decoder backbone model.

    T5 is an LLM pretrained on a mix of unsupervised and supervised tasks,
    where each task is converted to a sequence-to-sequence format.
    T5 works well on a variety of tasks out-of-the-box by prepending
    various prefixes to the input sequence, e.g., for translation:
    `"translate English to German: ..."`, for summarization:
    `"summarize: ..."`.

    T5 was introduced in
    [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683)

    The default constructor gives a fully customizable, randomly initialized T5
    model with any number of layers, heads, and embedding dimensions. To load
    preset architectures and weights, use the `from_preset` constructor.

    The model takes a dict with the keys `"encoder_token_ids"`,
    `"encoder_padding_mask"`, `"decoder_token_ids"` and
    `"decoder_padding_mask"` (all int32, variable sequence length) and returns
    a dict with the keys `"encoder_sequence_output"` and
    `"decoder_sequence_output"`.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of Transformer layers. The same number of
            layers is created for the encoder and for the decoder.
        num_heads: int. The number of attention heads for each Transformer.
            The hidden size must be divisible by the number of attention heads.
        hidden_dim: int. The hidden size of the Transformer layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each Transformer layer.
        key_value_dim: int. The dimension of each head of the key/value
            projections in the multi-head attention layers. Defaults to
            `hidden_dim // num_heads` when left as `None`.
        dropout: float. Dropout probability for the Transformer layers.
            Defaults to `0.1`.
        activation: activation function (or activation string name). The
            activation to be used in the inner dense blocks of the
            Transformer layers. Defaults to `"relu"`.
        use_gated_activation: boolean. Whether to use activation gating in
            the inner dense blocks of the Transformer layers.
            The original T5 architecture didn't use gating, but more
            recent versions do. Defaults to `True`.
        layer_norm_epsilon: float. Epsilon factor to be used in the
            layer normalization layers in the Transformer layers.
            Defaults to `1e-06`.
        tie_embedding_weights: boolean. If `True`, the weights of the token
            embedding and the weights projecting language model outputs from
            `hidden_dim` are tied (the flag is forwarded as `tie_weights` to
            the shared `ReversibleEmbedding`). Defaults to `True`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        key_value_dim=None,
        dropout=0.1,
        activation="relu",
        use_gated_activation=True,
        layer_norm_epsilon=1e-06,
        tie_embedding_weights=True,
        dtype=None,
        **kwargs,
    ):
        # Token embedding layer. This layer is shared by encoder and decoder.
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            tie_weights=tie_embedding_weights,
            embeddings_initializer=keras.initializers.TruncatedNormal(1.0),
            dtype=dtype,
            name="token_embedding",
        )
        self.encoder_embedding_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="encoder_embedding_dropout",
        )
        self.encoder_transformer_layers = []
        for i in range(num_layers):
            layer = T5TransformerLayer(
                is_decoder=False,
                hidden_dim=hidden_dim,
                intermediate_dim=intermediate_dim,
                # Fall back to an even split of the hidden size across heads
                # when no explicit per-head dimension was given.
                key_value_dim=key_value_dim or hidden_dim // num_heads,
                dropout=dropout,
                activation=activation,
                layer_norm_epsilon=layer_norm_epsilon,
                num_heads=num_heads,
                use_gated_activation=use_gated_activation,
                # Only the first layer of the stack is asked to use the
                # relative attention bias.
                use_relative_attention_bias=bool(i == 0),
                dtype=dtype,
                name=f"transformer_encoder_layer_{i}",
            )
            self.encoder_transformer_layers.append(layer)
        self.encoder_layer_norm = T5LayerNorm(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="encoder_output_layer_norm",
        )
        self.encoder_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="encoder_output_dropout",
        )
        self.decoder_embedding_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="decoder_embedding_dropout",
        )
        self.decoder_transformer_layers = []
        for i in range(num_layers):
            layer = T5TransformerLayer(
                is_decoder=True,
                hidden_dim=hidden_dim,
                intermediate_dim=intermediate_dim,
                # Same per-head fallback as in the encoder stack.
                key_value_dim=key_value_dim or hidden_dim // num_heads,
                dropout=dropout,
                activation=activation,
                layer_norm_epsilon=layer_norm_epsilon,
                num_heads=num_heads,
                use_gated_activation=use_gated_activation,
                # As in the encoder, only the first decoder layer is asked
                # to use the relative attention bias.
                use_relative_attention_bias=bool(i == 0),
                dtype=dtype,
                name=f"transformer_decoder_layer_{i}",
            )
            self.decoder_transformer_layers.append(layer)
        self.decoder_layer_norm = T5LayerNorm(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="decoder_output_layer_norm",
        )
        self.decoder_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="decoder_output_dropout",
        )

        # === Functional Model ===
        encoder_token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="encoder_token_ids"
        )
        encoder_padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="encoder_padding_mask"
        )
        decoder_token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_token_ids"
        )
        decoder_padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_padding_mask"
        )
        # Encoder.
        x = self.token_embedding(encoder_token_id_input)
        x = self.encoder_embedding_dropout(x)
        # Insert a singleton axis between the batch and sequence axes of the
        # padding mask before handing it to the attention layers.
        encoder_attention_mask = encoder_padding_mask_input[:, None, :]
        position_bias = None
        for transformer_layer in self.encoder_transformer_layers:
            output = transformer_layer(
                x,
                attention_mask=encoder_attention_mask,
                position_bias=position_bias,
                use_causal_mask=False,
            )
            # A tuple result carries `(hidden_states, position_bias)`; the
            # bias is then passed on to the following layers.
            # NOTE(review): a non-tuple result leaves `x` unchanged here --
            # presumably layers always return a tuple in this stack; confirm
            # against `T5TransformerLayer`.
            if isinstance(output, tuple):
                x, position_bias = output
        x = self.encoder_layer_norm(x)
        x = self.encoder_dropout(x)
        encoder_output = x
        # Decoder.
        x = self.token_embedding(decoder_token_id_input)
        x = self.decoder_embedding_dropout(x)
        decoder_attention_mask = decoder_padding_mask_input[:, None, :]
        # The decoder starts with its own position bias; nothing is carried
        # over from the encoder stack.
        position_bias = None
        for transformer_layer in self.decoder_transformer_layers:
            output = transformer_layer(
                x,
                attention_mask=decoder_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_output,
                encoder_attention_mask=encoder_attention_mask,
                use_causal_mask=True,
            )
            # Same tuple handling as in the encoder loop.
            if isinstance(output, tuple):
                x, position_bias = output
        x = self.decoder_layer_norm(x)
        x = self.decoder_dropout(x)
        decoder_output = x
        super().__init__(
            {
                "encoder_token_ids": encoder_token_id_input,
                "encoder_padding_mask": encoder_padding_mask_input,
                "decoder_token_ids": decoder_token_id_input,
                "decoder_padding_mask": decoder_padding_mask_input,
            },
            outputs={
                "encoder_sequence_output": encoder_output,
                "decoder_sequence_output": decoder_output,
            },
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Stored as a resolved activation; `get_config` serializes it again.
        self.activation = keras.activations.get(activation)
        # Kept exactly as passed (possibly `None`), not the resolved fallback.
        self.key_value_dim = key_value_dim
        self.dropout = dropout
        self.use_gated_activation = use_gated_activation
        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_embedding_weights = tie_embedding_weights

    def get_config(self):
        """Return the parent config extended with this model's constructor arguments.

        The activation is stored in serialized form via
        `keras.activations.serialize`.
        """
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "activation": keras.activations.serialize(self.activation),
                "key_value_dim": self.key_value_dim,
                "dropout": self.dropout,
                "use_gated_activation": self.use_gated_activation,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "tie_embedding_weights": self.tie_embedding_weights,
            }
        )
        return config
@@ -0,0 +1,35 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+
19
class T5LayerNorm(keras.layers.Layer):
    """Scale-only normalization over the last axis.

    The input is multiplied by the reciprocal square root of the mean of its
    squared values along the last axis (plus `epsilon`), and the result is
    multiplied by a learned per-feature `weight`. No mean is subtracted and
    no bias is added.

    Args:
        epsilon: float. Value added to the mean of squares before taking the
            reciprocal square root. Defaults to `1e-6`.
        **kwargs: forwarded to `keras.layers.Layer`.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create the per-feature scale, initialized to ones."""
        feature_dim = input_shape[-1]
        self.weight = self.add_weight(
            name="weight",
            shape=(feature_dim,),
            initializer="ones",
        )
        self.built = True

    def call(self, hidden_states):
        """Normalize `hidden_states` over the last axis and apply the scale."""
        mean_square = ops.mean(
            ops.square(hidden_states), axis=-1, keepdims=True
        )
        inverse_rms = ops.rsqrt(mean_square + self.epsilon)
        normalized = hidden_states * inverse_rms
        return self.weight * normalized
@@ -0,0 +1,324 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ import numpy as np
17
+ from keras import ops
18
+
19
+
20
class T5MultiHeadAttention(keras.layers.Layer):
    """Multi-head attention with an optional learned relative position bias.

    The layer handles self-attention (`key_value_states=None`) and
    cross-attention (`key_value_states` given), optionally reusing cached
    key/value states. All projections are bias-free, and the query/key dot
    products are used as scores without any additional scaling in this
    layer. A position bias (either passed in, all zeros, or looked up from
    the learned bucket table) is added to the scores before the softmax.

    Args:
        is_decoder: bool. When true, relative positions are bucketed
            unidirectionally (`bidirectional=not is_decoder`).
        hidden_dim: int. Size of the output projection.
        key_value_dim: int. Per-head size of the query/key/value
            projections.
        num_heads: int. Number of attention heads.
        dropout: float. Rate of the dropout applied to the attention
            weights.
        use_relative_attention_bias: bool. Whether this layer creates and
            uses its own relative attention bias table. Defaults to
            `False`.
        relative_attention_buckets: int. Number of relative position
            buckets, i.e. rows of the bias table. Defaults to `32`.
        relative_attention_max_distance: int. Distance at and beyond which
            all relative positions share a bucket. Defaults to `128`.
        **kwargs: forwarded to `keras.layers.Layer`.
    """

    # This layer is adapted from Hugging Face
    # Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_tf_t5.py
    def __init__(
        self,
        is_decoder,
        hidden_dim,
        key_value_dim,
        num_heads,
        dropout,
        use_relative_attention_bias=False,
        relative_attention_buckets=32,
        relative_attention_max_distance=128,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.is_decoder = is_decoder
        self.hidden_dim = hidden_dim
        self.key_value_dim = key_value_dim
        self.num_heads = num_heads
        self.use_relative_attention_bias = use_relative_attention_bias

        self.inner_dim = self.num_heads * self.key_value_dim
        # Previously fixed at 32 and 128; the defaults keep that behavior.
        self.relative_attention_buckets = relative_attention_buckets
        self.relative_attention_max_distance = relative_attention_max_distance

        self.query_projector = keras.layers.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0, stddev=(self.inner_dim * self.key_value_dim) ** -0.5
            ),
            dtype=self.dtype_policy,
            name="query_projector",
        )
        self.key_projector = keras.layers.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0, stddev=self.inner_dim**-0.5
            ),
            dtype=self.dtype_policy,
            name="key_projector",
        )
        self.value_projector = keras.layers.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0, stddev=self.inner_dim**-0.5
            ),
            dtype=self.dtype_policy,
            name="value_projector",
        )
        self.output_projector = keras.layers.Dense(
            self.hidden_dim,
            use_bias=False,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0, stddev=self.inner_dim**-0.5
            ),
            dtype=self.dtype_policy,
            name="output_projector",
        )
        self.dropout_layer = keras.layers.Dropout(
            dropout,
            dtype=self.dtype_policy,
        )

        if self.use_relative_attention_bias:
            # One learned bias per (bucket, head) pair.
            self.relative_attention_bias = self.add_weight(
                name="embeddings",
                shape=[self.relative_attention_buckets, self.num_heads],
                initializer=keras.initializers.RandomNormal(
                    mean=0, stddev=self.inner_dim**-0.5
                ),
            )

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """Adapted from Mesh Tensorflow.

        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position,
        i.e. the distance in tokens from the attending position to the
        attended-to position. If bidirectional=False, then positive relative
        positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute
        relative_positions. All relative positions >= max_distance map to
        the same bucket. All relative positions <= -max_distance map to
        the same bucket. This should allow for more graceful generalization to
        longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            Tensor with the same shape as relative_position,
            containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets are used for positive relative positions:
            # they are offset by `num_buckets` (after halving).
            num_buckets //= 2
            relative_buckets += (
                ops.cast(
                    ops.greater(relative_position, 0),
                    dtype=relative_position.dtype,
                )
                * num_buckets
            )
            relative_position = ops.abs(relative_position)
        else:
            # Positive relative positions are clamped to 0 and the sign of
            # the rest is flipped.
            relative_position = -ops.minimum(relative_position, 0)
        # relative_position is now in the range [0, inf)
        # The first `max_exact` buckets map one-to-one to positions; the
        # remaining buckets are spaced logarithmically up to `max_distance`.
        max_exact = num_buckets // 2
        is_small = ops.less(relative_position, max_exact)
        relative_position_if_large = max_exact + ops.cast(
            ops.log(
                ops.cast(relative_position, "float32")
                / ops.cast(max_exact, "float32")
            )
            / ops.cast(ops.log(max_distance / max_exact), "float32")
            * (num_buckets - max_exact),
            dtype=relative_position.dtype,
        )
        # Everything past `max_distance` lands in the last bucket.
        relative_position_if_large = ops.minimum(
            relative_position_if_large, num_buckets - 1
        )
        relative_buckets += ops.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias.

        Args:
            query_length: number of query positions.
            key_length: number of key positions.

        Returns:
            Bias values gathered from `self.relative_attention_bias`, with
            shape `(1, num_heads, query_length, key_length)`.
        """
        context_position = ops.arange(query_length)[:, None]
        memory_position = ops.arange(key_length)[None, :]
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = ops.take(
            self.relative_attention_bias, relative_position_bucket, axis=0
        )  # shape (query_length, key_length, num_heads)
        values = ops.expand_dims(
            ops.transpose(values, axes=(2, 0, 1)), axis=0
        )  # shape (1, num_heads, query_length, key_length)
        return values

    def call(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        training=False,
    ):
        """Apply attention to `hidden_states`.

        Args:
            hidden_states: query input of shape
                `(batch_size, query_length, dim)`.
            mask: optional rank-3 mask; it gets a head axis inserted at
                position 1 and entries equal to 0 receive a `-1e9` bias.
                It is only consulted when `position_bias` is not supplied.
            key_value_states: optional source for keys/values
                (cross-attention). When `None`, keys/values are projected
                from `hidden_states` (self-attention).
            position_bias: optional precomputed bias. When given it is
                added to the scores as-is.
            past_key_value: optional pair `(keys, values)` of cached
                states, each `(batch_size, num_heads, length, dim_per_head)`.
            layer_head_mask: optional per-head multiplier, reshaped to
                `(1, -1, 1, 1)` and applied to the attention weights.
            query_length: optional length added to the current sequence
                length in place of the cached length when `past_key_value`
                is given.
            training: forwarded to the dropout layer.

        Returns:
            A tuple `(attention_output, position_bias)`, where
            `position_bias` is the bias that was added to the scores.

        Raises:
            ValueError: if `past_key_value` does not have exactly 2
                elements.
        """
        # Input is (batch_size, query_length, dim)
        # past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = ops.shape(hidden_states)[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            if len(past_key_value) != 2:
                raise ValueError(
                    f"Argument `past_key_value` should have 2 past states: "
                    f"keys and values. Got {len(past_key_value)} past states."
                )
            real_seq_length += (
                ops.shape(past_key_value[0])[2]
                if query_length is None
                else query_length
            )

        key_length = (
            real_seq_length
            if key_value_states is None
            else ops.shape(key_value_states)[1]
        )

        def shape(hidden_states):
            # (batch, length, inner_dim) -> (batch, heads, length, head_dim)
            return ops.transpose(
                ops.reshape(
                    hidden_states,
                    (batch_size, -1, self.num_heads, self.key_value_dim),
                ),
                axes=(0, 2, 1, 3),
            )

        def unshape(hidden_states):
            # (batch, heads, length, head_dim) -> (batch, length, inner_dim)
            return ops.reshape(
                ops.transpose(hidden_states, axes=(0, 2, 1, 3)),
                (batch_size, -1, self.inner_dim),
            )

        def project(
            hidden_states, proj_layer, key_value_states, past_key_value
        ):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attention
                # (batch_size, num_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attention
                # (batch_size, num_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attention
                    # (batch_size, num_heads, key_length, dim_per_head)
                    hidden_states = ops.concat(
                        [past_key_value, hidden_states], axis=2
                    )
                else:
                    # cross-attention
                    # The cached states are reused without re-projecting.
                    hidden_states = past_key_value
            return hidden_states

        # get query
        query_states = shape(
            self.query_projector(hidden_states)
        )  # (batch_size, num_heads, query_length, dim_per_head)

        # get key/value
        key_states = project(
            hidden_states,
            self.key_projector,
            key_value_states,
            past_key_value[0] if past_key_value is not None else None,
        )
        value_states = project(
            hidden_states,
            self.value_projector,
            key_value_states,
            past_key_value[1] if past_key_value is not None else None,
        )

        scores = ops.einsum(
            "bnqd,bnkd->bnqk", query_states, key_states
        )  # (batch_size, num_heads, query_length, key_length)

        if position_bias is None:
            if not self.use_relative_attention_bias:
                position_bias = ops.zeros(
                    (1, self.num_heads, real_seq_length, key_length),
                    self.compute_dtype,
                )
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)

            # if key and values are already calculated we want only
            # the last query position bias
            if past_key_value is not None:
                if not self.use_relative_attention_bias:
                    position_bias = position_bias[:, :, -seq_length:, :]
                else:
                    # we might have a padded past structure,
                    # in which case we want to fetch the position bias slice
                    # right after the most recently filled past index
                    most_recently_filled_past_index = ops.amax(
                        ops.where(past_key_value[0][0, 0, :, 0] != 0.0)
                    )
                    position_bias = ops.slice(
                        position_bias,
                        (0, 0, most_recently_filled_past_index + 1, 0),
                        (1, self.num_heads, seq_length, real_seq_length),
                    )

            if mask is not None:
                # Add a new mask axis for the head dim.
                mask = mask[:, np.newaxis, :, :]
                # Add a very large negative position bias for masked positions.
                mask = (1.0 - ops.cast(mask, position_bias.dtype)) * -1e9
                position_bias = position_bias + mask

        scores += ops.cast(position_bias, scores.dtype)
        weights = ops.nn.softmax(
            scores, axis=-1
        )  # (batch_size, num_heads, query_length, key_length)
        weights = self.dropout_layer(
            weights, training=training
        )  # (batch_size, num_heads, query_length, key_length)

        # Optionally mask heads
        if layer_head_mask is not None:
            weights = ops.reshape(layer_head_mask, (1, -1, 1, 1)) * weights

        attention_output = ops.matmul(
            weights, value_states
        )  # (batch_size, num_heads, query_length, dim_per_head)

        attention_output = self.output_projector(unshape(attention_output))
        return (attention_output, position_bias)