keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
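The `keras_hub/api/` modules listed above define the package's public namespaces (`keras_hub.models`, `keras_hub.layers`, `keras_hub.samplers`, `keras_hub.tokenizers`, `keras_hub.metrics`), which re-export the implementations under `keras_hub/src/`. A minimal usage sketch, assuming the nightly wheel above is installed with `pip install keras-hub-nightly`; the preset name and sampler usage are taken from the `BartSeq2SeqLM` docstring in the diff below:

```python
# Minimal sketch, assuming `pip install keras-hub-nightly` has installed the
# wheel listed above.
import keras_hub

# Load a pretrained BART seq2seq LM (and its preprocessor) from a preset.
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")

# Swap the default "top_k" sampling for greedy decoding, then generate.
bart_lm.compile(sampler="greedy")
print(bart_lm.generate("The quick brown fox", max_length=30))
```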
@@ -0,0 +1,490 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.bart.bart_backbone import BartBackbone
+ from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import (
+     BartSeq2SeqLMPreprocessor,
+ )
+ from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
+ from keras_hub.src.utils.tensor_utils import any_equal
+
+
+ @keras_hub_export("keras_hub.models.BartSeq2SeqLM")
+ class BartSeq2SeqLM(Seq2SeqLM):
+     """An end-to-end BART model for seq2seq language modeling.
+
+     A seq2seq language model (LM) is an encoder-decoder model which is used for
+     conditional text generation. The encoder is given a "context" text (fed to
+     the encoder), and the decoder predicts the next token based on both the
+     encoder inputs and the previous tokens. You can finetune `BartSeq2SeqLM` to
+     generate text for any seq2seq task (e.g., translation or summarization).
+
+     This model has a `generate()` method, which generates text based on
+     encoder inputs and an optional prompt for the decoder. The generation
+     strategy used is controlled by an additional `sampler` argument passed to
+     `compile()`. You can recompile the model with different `keras_hub.samplers`
+     objects to control the generation. By default, `"top_k"` sampling will be
+     used.
+
+     This model can optionally be configured with a `preprocessor` layer, in
+     which case it will automatically apply preprocessing to string inputs during
+     `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
+     when creating the model with `from_preset()`.
+
+     Disclaimer: Pre-trained models are provided on an "as is" basis, without
+     warranties or conditions of any kind. The underlying model is provided by a
+     third party and subject to a separate license, available
+     [here](https://github.com/facebookresearch/fairseq/).
+
+     Args:
+         backbone: A `keras_hub.models.BartBackbone` instance.
+         preprocessor: A `keras_hub.models.BartSeq2SeqLMPreprocessor` or `None`.
+             If `None`, this model will not apply preprocessing, and inputs
+             should be preprocessed before calling the model.
+
+     Examples:
+
+     Use `generate()` to do text generation, given an input context.
+     ```python
+     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
+     bart_lm.generate("The quick brown fox", max_length=30)
+
+     # Generate with batched inputs.
+     bart_lm.generate(["The quick brown fox", "The whale"], max_length=30)
+     ```
+
+     Compile the `generate()` function with a custom sampler.
+     ```python
+     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
+     bart_lm.compile(sampler="greedy")
+     bart_lm.generate("The quick brown fox", max_length=30)
+     ```
+
+     Use `generate()` with encoder inputs and an incomplete decoder input (prompt).
+     ```python
+     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
+     bart_lm.generate(
+         {
+             "encoder_text": "The quick brown fox",
+             "decoder_text": "The fast"
+         }
+     )
+     ```
+
+     Use `generate()` without preprocessing.
+     ```python
+     # Preprocessed inputs, with encoder inputs corresponding to
+     # "The quick brown fox", and the decoder inputs to "The fast". Use
+     # `"padding_mask"` to indicate values that should not be overridden.
+     prompt = {
+         "encoder_token_ids": np.array([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
+         "encoder_padding_mask": np.array(
+             [[True, True, True, True, True, True, False, False]]
+         ),
+         "decoder_token_ids": np.array([[2, 0, 133, 1769, 2, 1, 1]]),
+         "decoder_padding_mask": np.array([[True, True, True, True, True, False, False]])
+     }
+
+     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
+         "bart_base_en",
+         preprocessor=None,
+     )
+     bart_lm.generate(prompt)
+     ```
+
+     Call `fit()` on a single batch.
+     ```python
+     features = {
+         "encoder_text": ["The quick brown fox jumped.", "I forgot my homework."],
+         "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
+     }
+     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
+     bart_lm.fit(x=features, batch_size=2)
+     ```
+
+     Call `fit()` without preprocessing.
+     ```python
+     x = {
+         "encoder_token_ids": np.array([[0, 133, 2119, 2, 1]] * 2),
+         "encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2),
+         "decoder_token_ids": np.array([[2, 0, 133, 1769, 2]] * 2),
+         "decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
+     }
+     y = np.array([[0, 133, 1769, 2, 1]] * 2)
+     sw = np.array([[1, 1, 1, 1, 0]] * 2)
+
+     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
+         "bart_base_en",
+         preprocessor=None,
+     )
+     bart_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
+     ```
+
+     Custom backbone and vocabulary.
+     ```python
+     features = {
+         "encoder_text": [" afternoon sun"],
+         "decoder_text": ["noon sun"],
+     }
+     vocab = {
+         "<s>": 0,
+         "<pad>": 1,
+         "</s>": 2,
+         "Ġafter": 5,
+         "noon": 6,
+         "Ġsun": 7,
+     }
+     merges = ["Ġ a", "Ġ s", "Ġ n", "e r", "n o", "o n", "Ġs u", "Ġa f", "no on"]
+     merges += ["Ġsu n", "Ġaf t", "Ġaft er"]
+
+     tokenizer = keras_hub.models.BartTokenizer(
+         vocabulary=vocab,
+         merges=merges,
+     )
+     preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor(
+         tokenizer=tokenizer,
+         encoder_sequence_length=128,
+         decoder_sequence_length=128,
+     )
+     backbone = keras_hub.models.BartBackbone(
+         vocabulary_size=50265,
+         num_layers=6,
+         num_heads=12,
+         hidden_dim=768,
+         intermediate_dim=3072,
+         max_sequence_length=128,
+     )
+     bart_lm = keras_hub.models.BartSeq2SeqLM(
+         backbone=backbone,
+         preprocessor=preprocessor,
+     )
+     bart_lm.fit(x=features, batch_size=2)
+     ```
+     """
+
+     backbone_cls = BartBackbone
+     preprocessor_cls = BartSeq2SeqLMPreprocessor
+
+     def __init__(
+         self,
+         backbone,
+         preprocessor=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         # === Functional Model ===
+         inputs = backbone.input
+         hidden_states = backbone(inputs)["decoder_sequence_output"]
+         outputs = backbone.token_embedding(hidden_states, reverse=True)
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+     def call_decoder_with_cache(
+         self,
+         encoder_hidden_states,
+         encoder_padding_mask,
+         decoder_token_ids,
+         self_attention_cache=None,
+         self_attention_cache_update_index=None,
+         cross_attention_cache=None,
+         cross_attention_cache_update_index=None,
+     ):
+         """Forward pass with key/value caches for generative decoding.
+
+         `call_decoder_with_cache` adds an additional inference-time forward pass
+         for the model for seq2seq text generation. Unlike calling the model
+         directly, this method does two things to optimize text generation:
+
+         - Allows caching previous key/value tensors in the decoder's
+           self-attention layer to avoid recomputing the outputs of seen tokens.
+         - Allows caching key/value tensors in the decoder's cross-attention
+           layer to avoid recomputing the encoder outputs.
+
+         Args:
+             encoder_hidden_states: a dense float Tensor of shape
+                 `(batch_size, encoder_sequence_length, hidden_dim)`. The
+                 sequence of hidden states at the output of the encoder's last
+                 layer.
+             encoder_padding_mask: a dense float Tensor of shape
+                 `(batch_size, encoder_sequence_length)`. The padding mask for
+                 the encoder input.
+             decoder_token_ids: a dense int Tensor of shape
+                 `(batch_size, max_length)`. Input token ids to be fed to
+                 the decoder.
+             self_attention_cache: a dense float Tensor of shape
+                 `(batch_size, num_layers, 2, max_length, num_heads, key_dims)`.
+                 The cached key/value tensors of previously seen tokens in the
+                 decoder's self-attention layer.
+             self_attention_cache_update_index: an int or int Tensor, the index
+                 at which to update the `self_attention_cache`. Usually, this is
+                 the index of the current token being processed during decoding.
+             cross_attention_cache: a dense float Tensor of shape
+                 `(batch_size, num_layers, 2, encoder_sequence_length, num_heads, key_dims)`.
+                 The cached key/value tensors of the encoder outputs in the
+                 decoder's cross-attention layer.
+             cross_attention_cache_update_index: an int or int Tensor, the index
+                 at which to update the `cross_attention_cache`. Usually, this is
+                 either `0` (compute the entire `cross_attention_cache`), or
+                 `None` (reuse a previously computed `cross_attention_cache`).
+
+         Returns:
+             A `(logits, hidden_states, self_attention_cache, cross_attention_cache)`
+             tuple, where `logits` is the language model logits for the input
+             `decoder_token_ids`, `hidden_states` is the final hidden
+             representation of the input tokens, `self_attention_cache` is the
+             key/value cache in the decoder's self-attention layer and
+             `cross_attention_cache` is the key/value cache in the decoder's
+             cross-attention layer.
+         """
+         # Embedding layers.
+         tokens = self.backbone.token_embedding(decoder_token_ids)
+         positions = self.backbone.decoder_position_embedding(
+             tokens,
+             start_index=self_attention_cache_update_index,
+         )
+         # Sum, normalize and apply dropout to embeddings.
+         x = self.backbone.decoder_embeddings_add((tokens, positions))
+         x = self.backbone.decoder_embeddings_layer_norm(x)
+         x = self.backbone.decoder_embeddings_dropout(x)
+
+         # Every decoder layer has a separate cache for the self-attention layer
+         # and the cross-attention layer. We update all of them separately.
+         self_attention_caches = []
+         cross_attention_caches = []
+         for i, layer in enumerate(self.backbone.decoder_transformer_layers):
+             current_self_attention_cache = self_attention_cache[:, i, ...]
+             current_cross_attention_cache = cross_attention_cache[:, i, ...]
+             (
+                 x,
+                 next_self_attention_cache,
+                 next_cross_attention_cache,
+             ) = layer(
+                 decoder_sequence=x,
+                 encoder_sequence=encoder_hidden_states,
+                 encoder_padding_mask=encoder_padding_mask,
+                 self_attention_cache=current_self_attention_cache,
+                 self_attention_cache_update_index=self_attention_cache_update_index,
+                 cross_attention_cache=current_cross_attention_cache,
+                 cross_attention_cache_update_index=cross_attention_cache_update_index,
+             )
+             if self_attention_cache_update_index is not None:
+                 self_attention_caches.append(next_self_attention_cache)
+             if cross_attention_cache_update_index is not None:
+                 cross_attention_caches.append(next_cross_attention_cache)
+
+         if self_attention_cache_update_index is not None:
+             self_attention_cache = ops.stack(self_attention_caches, axis=1)
+         if cross_attention_cache_update_index is not None:
+             cross_attention_cache = ops.stack(cross_attention_caches, axis=1)
+
+         hidden_states = x
+         logits = self.backbone.token_embedding(hidden_states, reverse=True)
+         return (
+             logits,
+             hidden_states,
+             self_attention_cache,
+             cross_attention_cache,
+         )
+
+     def call_encoder(self, token_ids, padding_mask):
+         """Does a forward pass on the encoder and returns the encoder output."""
+         tokens = self.backbone.token_embedding(token_ids)
+         positions = self.backbone.encoder_position_embedding(tokens)
+         x = self.backbone.encoder_embeddings_add((tokens, positions))
+         x = self.backbone.encoder_embeddings_layer_norm(x)
+         x = self.backbone.encoder_embeddings_dropout(x)
+         for transformer_layer in self.backbone.encoder_transformer_layers:
+             x = transformer_layer(x, padding_mask=padding_mask)
+         return x
+
+     def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
+         """Initializes empty self-attention cache and cross-attention cache."""
+         batch_size = ops.shape(encoder_token_ids)[0]
+         encoder_max_length = ops.shape(encoder_token_ids)[1]
+         decoder_max_length = ops.shape(decoder_token_ids)[1]
+
+         num_layers = self.backbone.num_layers
+         num_heads = self.backbone.num_heads
+         head_dim = self.backbone.hidden_dim // self.backbone.num_heads
+
+         shape = [
+             batch_size,
+             num_layers,
+             2,
+             decoder_max_length,
+             num_heads,
+             head_dim,
+         ]
+         self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
+
+         shape[3] = encoder_max_length
+         cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
+
+         return (self_attention_cache, cross_attention_cache)
+
+     def _build_cache(
+         self, encoder_token_ids, encoder_padding_mask, decoder_token_ids
+     ):
+         """Builds the self-attention cache and the cross-attention cache (key/value pairs)."""
+         encoder_hidden_states = self.call_encoder(
+             token_ids=encoder_token_ids, padding_mask=encoder_padding_mask
+         )
+         self_attention_cache, cross_attention_cache = self._initialize_cache(
+             encoder_token_ids, decoder_token_ids
+         )
+
+         # Seed the self-attention cache and the cross-attention cache.
+         (
+             _,
+             hidden_states,
+             self_attention_cache,
+             cross_attention_cache,
+         ) = self.call_decoder_with_cache(
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_padding_mask=encoder_padding_mask,
+             decoder_token_ids=decoder_token_ids,
+             self_attention_cache=self_attention_cache,
+             self_attention_cache_update_index=0,
+             cross_attention_cache=cross_attention_cache,
+             cross_attention_cache_update_index=0,
+         )
+         return (
+             hidden_states,
+             encoder_hidden_states,
+             self_attention_cache,
+             cross_attention_cache,
+         )
+
+     def generate_step(
+         self,
+         inputs,
+         stop_token_ids=None,
+     ):
+         """A compilable generation function for a batch of inputs.
+
+         This function represents the inner, XLA-compilable, generation function
+         for a single batch of inputs. Inputs should have the same structure as
+         model inputs, a dictionary with keys `"encoder_token_ids"`,
+         `"encoder_padding_mask"`, `"decoder_token_ids"` and
+         `"decoder_padding_mask"`.
+
+         Args:
+             inputs: A dictionary with four keys - `"encoder_token_ids"`,
+                 `"encoder_padding_mask"`, `"decoder_token_ids"` and
+                 `"decoder_padding_mask"`, with batched tensor values.
+             stop_token_ids: Tuple of stop token ids to stop generation on. If
+                 all sequences have produced a new stop token, generation
+                 will stop.
+         """
+         (
+             encoder_token_ids,
+             encoder_padding_mask,
+             decoder_token_ids,
+             decoder_padding_mask,
+         ) = (
+             inputs["encoder_token_ids"],
+             inputs["encoder_padding_mask"],
+             inputs["decoder_token_ids"],
+             inputs["decoder_padding_mask"],
+         )
+
+         batch_size = ops.shape(encoder_token_ids)[0]
+
+         # Create and seed cache with a single forward pass.
+         (
+             hidden_states,
+             encoder_hidden_states,
+             self_attention_cache,
+             cross_attention_cache,
+         ) = self._build_cache(
+             encoder_token_ids, encoder_padding_mask, decoder_token_ids
+         )
+         # Compute the lengths of all user-provided token ids.
+         row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
+         # Start at the first index that has no user-provided id.
+         index = ops.min(row_lengths)
+
+         def next(prompt, cache, index):
+             # The cache index is the index of our previous token.
+             cache_index = index - 1
+             num_samples = ops.shape(prompt)[0]
+             prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])
+
+             def repeat_tensor(x):
+                 """Repeats tensors along batch axis to match dim for beam search."""
+                 if ops.shape(x)[0] == num_samples:
+                     return x
+                 return ops.repeat(x, repeats=num_samples // batch_size, axis=0)
+
+             logits, hidden_states, cache, _ = self.call_decoder_with_cache(
+                 encoder_hidden_states=repeat_tensor(encoder_hidden_states),
+                 encoder_padding_mask=repeat_tensor(encoder_padding_mask),
+                 decoder_token_ids=prompt,
+                 self_attention_cache=cache,
+                 self_attention_cache_update_index=cache_index,
+                 cross_attention_cache=repeat_tensor(cross_attention_cache),
+                 cross_attention_cache_update_index=None,
+             )
+             return (
+                 ops.squeeze(logits, axis=1),
+                 ops.squeeze(hidden_states, axis=1),
+                 cache,
+             )
+
+         decoder_token_ids = self.sampler(
+             next=next,
+             prompt=decoder_token_ids,
+             cache=self_attention_cache,
+             index=index,
+             mask=decoder_padding_mask,
+             stop_token_ids=stop_token_ids,
+             hidden_states=hidden_states,
+             model=self,
+         )
+
+         # Compute an output padding mask with the token ids we updated.
+         if stop_token_ids is not None:
+             # Build a mask of `stop_token_ids` locations not in the original
+             # prompt (not in locations where `decoder_padding_mask` is True).
+             end_locations = any_equal(
+                 decoder_token_ids,
+                 stop_token_ids,
+                 ops.logical_not(decoder_padding_mask),
+             )
+             end_locations = ops.cast(end_locations, "int32")
+             # Use cumsum to get ones in all locations after `end_locations`.
+             cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+             overflow = cumsum - end_locations
+             # Our padding mask is the inverse of these overflow locations.
+             decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+         else:
+             # Without early stopping, all locations will have been updated.
+             decoder_padding_mask = ops.ones_like(
+                 decoder_token_ids, dtype="bool"
+             )
+
+         return {
+             "decoder_token_ids": decoder_token_ids,
+             "decoder_padding_mask": decoder_padding_mask,
+         }
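For reference, the caches built by `_initialize_cache` above use the layout `(batch_size, num_layers, 2, sequence_length, num_heads, head_dim)`, where the `2` axis holds keys and values. A quick sketch of the resulting shapes, using the `num_layers=6`, `num_heads=12`, `hidden_dim=768` configuration from the class docstring; the batch size and sequence lengths below are illustrative assumptions, not values from the source:

```python
# Illustrative only: mirrors the shape arithmetic in `_initialize_cache`.
num_layers, num_heads, hidden_dim = 6, 12, 768  # from the docstring's BartBackbone example
head_dim = hidden_dim // num_heads  # 768 // 12 = 64

batch_size = 1            # assumed
decoder_max_length = 30   # assumed, e.g. `max_length=30` in `generate()`
encoder_max_length = 8    # assumed, e.g. the 8-token preprocessed prompt above

# The self-attention cache covers the decoder sequence; the cross-attention
# cache covers the encoder sequence (see `shape[3] = encoder_max_length`).
self_attention_cache_shape = (batch_size, num_layers, 2, decoder_max_length, num_heads, head_dim)
cross_attention_cache_shape = (batch_size, num_layers, 2, encoder_max_length, num_heads, head_dim)
print(self_attention_cache_shape)   # (1, 6, 2, 30, 12, 64)
print(cross_attention_cache_shape)  # (1, 6, 2, 8, 12, 64)
```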