keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
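For reference, a listing like the one below can be reproduced locally from the downloaded wheel itself, since a `.whl` file is an ordinary zip archive. A minimal sketch using only the Python standard library (the filename is taken from the header above and assumed to sit in the working directory):

```python
import zipfile

# Wheel filename from the header above; adjust the path as needed.
wheel_path = "keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl"

with zipfile.ZipFile(wheel_path) as wheel:
    # Print each file in the archive with its uncompressed size.
    for info in wheel.infolist():
        print(f"{info.filename} ({info.file_size} bytes)")
```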
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/layers/modeling/transformer_encoder.py
@@ -0,0 +1,262 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (  # isort:skip
+     merge_padding_and_attention_mask,
+ )
+
+
+ @keras_hub_export("keras_hub.layers.TransformerEncoder")
+ class TransformerEncoder(keras.layers.Layer):
+     """Transformer encoder.
+
+     This class follows the architecture of the transformer encoder layer in the
+     paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
+     can instantiate multiple instances of this class to stack up an encoder.
+
+     This layer will correctly compute an attention mask from an implicit
+     Keras padding mask (for example, by passing `mask_zero=True` to a
+     `keras.layers.Embedding` layer). See the Masking and Padding
+     [guide](https://keras.io/guides/understanding_masking_and_padding/)
+     for more details.
+
+     Args:
+         intermediate_dim: int, the hidden size of feedforward network.
+         num_heads: int, the number of heads in the
+             `keras.layers.MultiHeadAttention` layer.
+         dropout: float. The dropout value, shared by
+             `keras.layers.MultiHeadAttention` and feedforward network.
+             Defaults to `0.`.
+         activation: string or `keras.activations`. The
+             activation function of feedforward network.
+             Defaults to `"relu"`.
+         layer_norm_epsilon: float. The epsilon value in layer
+             normalization components. Defaults to `1e-5`.
+         kernel_initializer: string or `keras.initializers` initializer.
+             The kernel initializer for the dense and multiheaded
+             attention layers. Defaults to `"glorot_uniform"`.
+         bias_initializer: string or `keras.initializers` initializer.
+             The bias initializer for the dense and multiheaded
+             attention layers. Defaults to `"zeros"`.
+         normalize_first: bool. If True, the inputs to the
+             attention layer and the intermediate dense layer are normalized
+             (similar to GPT-2). If set to False, outputs of attention layer and
+             intermediate dense layer are normalized (similar to BERT).
+             Defaults to `False`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `trainable`, `dtype`, etc.
+
+     Example:
+
+     ```python
+     # Create a single transformer encoder layer.
+     encoder = keras_hub.layers.TransformerEncoder(
+         intermediate_dim=64, num_heads=8)
+
+     # Create a simple model containing the encoder.
+     input = keras.Input(shape=(10, 64))
+     output = encoder(input)
+     model = keras.Model(inputs=input, outputs=output)
+
+     # Call encoder on the inputs.
+     input_data = np.random.uniform(size=(2, 10, 64))
+     output = model(input_data)
+     ```
+
+     References:
+      - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
+     """
+
+     def __init__(
+         self,
+         intermediate_dim,
+         num_heads,
+         dropout=0,
+         activation="relu",
+         layer_norm_epsilon=1e-05,
+         kernel_initializer="glorot_uniform",
+         bias_initializer="zeros",
+         normalize_first=False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.intermediate_dim = intermediate_dim
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.activation = keras.activations.get(activation)
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+         self.bias_initializer = keras.initializers.get(bias_initializer)
+         self.normalize_first = normalize_first
+         self.supports_masking = True
+
+     def build(self, inputs_shape):
+         # Infer the dimension of our hidden feature size from the build shape.
+         hidden_dim = inputs_shape[-1]
+         # Attention head size is `hidden_dim` over the number of heads.
+         key_dim = int(hidden_dim // self.num_heads)
+         if key_dim == 0:
+             raise ValueError(
+                 "Attention `key_dim` computed cannot be zero. "
+                 f"The `hidden_dim` value of {hidden_dim} has to be equal to "
+                 f"or greater than `num_heads` value of {self.num_heads}."
+             )
+
+         # Self attention layers.
+         self._self_attention_layer = keras.layers.MultiHeadAttention(
+             num_heads=self.num_heads,
+             key_dim=key_dim,
+             dropout=self.dropout,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="self_attention_layer",
+         )
+         if hasattr(self._self_attention_layer, "_build_from_signature"):
+             self._self_attention_layer._build_from_signature(
+                 query=inputs_shape,
+                 value=inputs_shape,
+             )
+         else:
+             self._self_attention_layer.build(
+                 query_shape=inputs_shape,
+                 value_shape=inputs_shape,
+             )
+         self._self_attention_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="self_attention_layer_norm",
+         )
+         self._self_attention_layer_norm.build(inputs_shape)
+         self._self_attention_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="self_attention_dropout",
+         )
+
+         # Feedforward layers.
+         self._feedforward_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="feedforward_layer_norm",
+         )
+         self._feedforward_layer_norm.build(inputs_shape)
+         self._feedforward_intermediate_dense = keras.layers.Dense(
+             self.intermediate_dim,
+             activation=self.activation,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="feedforward_intermediate_dense",
+         )
+         self._feedforward_intermediate_dense.build(inputs_shape)
+         self._feedforward_output_dense = keras.layers.Dense(
+             hidden_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="feedforward_output_dense",
+         )
+         intermediate_shape = list(inputs_shape)
+         intermediate_shape[-1] = self.intermediate_dim
+         self._feedforward_output_dense.build(tuple(intermediate_shape))
+         self._feedforward_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="feedforward_dropout",
+         )
+         self.built = True
+
+     def call(
+         self, inputs, padding_mask=None, attention_mask=None, training=None
+     ):
+         """Forward pass of the TransformerEncoder.
+
+         Args:
+             inputs: a Tensor. The input data to TransformerEncoder, should be
+                 of shape [batch_size, sequence_length, hidden_dim].
+             padding_mask: a boolean Tensor. It indicates if the token should be
+                 masked because the token is introduced due to padding.
+                 `padding_mask` should have shape [batch_size, sequence_length].
+             attention_mask: a boolean Tensor. Customized mask used to mask out
+                 certain tokens. `attention_mask` should have shape
+                 [batch_size, sequence_length, sequence_length].
+             training: a boolean indicating whether the layer should behave in
+                 training mode or in inference mode.
+
+         Returns:
+             A Tensor of the same shape as the `inputs`.
+         """
+         x = inputs  # Intermediate result.
+
+         # Compute self attention mask.
+         self_attention_mask = merge_padding_and_attention_mask(
+             inputs, padding_mask, attention_mask
+         )
+
+         # Self attention block.
+         residual = x
+         if self.normalize_first:
+             x = self._self_attention_layer_norm(x)
+         x = self._self_attention_layer(
+             query=x,
+             value=x,
+             attention_mask=self_attention_mask,
+             training=training,
+         )
+         x = self._self_attention_dropout(x, training=training)
+         x = x + residual
+         if not self.normalize_first:
+             x = self._self_attention_layer_norm(x)
+
+         # Feedforward block.
+         residual = x
+         if self.normalize_first:
+             x = self._feedforward_layer_norm(x)
+         x = self._feedforward_intermediate_dense(x)
+         x = self._feedforward_output_dense(x)
+         x = self._feedforward_dropout(x, training=training)
+         x = x + residual
+         if not self.normalize_first:
+             x = self._feedforward_layer_norm(x)
+
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "dropout": self.dropout,
+                 "activation": keras.activations.serialize(self.activation),
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self.kernel_initializer
+                 ),
+                 "bias_initializer": keras.initializers.serialize(
+                     self.bias_initializer
+                 ),
+                 "normalize_first": self.normalize_first,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, inputs_shape):
+         return inputs_shape
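For orientation, a minimal usage sketch of the layer above (assuming `keras_hub` is installed from this wheel; the vocabulary size, dimensions, and `mask_zero` embedding are illustrative, not taken from the package):

```python
import keras
import keras_hub

token_ids = keras.Input(shape=(None,), dtype="int32")
# `mask_zero=True` attaches an implicit Keras padding mask, which the
# encoder converts into an attention mask as described in its docstring.
x = keras.layers.Embedding(input_dim=1000, output_dim=64, mask_zero=True)(
    token_ids
)
# Stack a post-norm (BERT-style, the default) encoder layer...
x = keras_hub.layers.TransformerEncoder(intermediate_dim=128, num_heads=8)(x)
# ...and a pre-norm (GPT-2-style) layer via `normalize_first=True`.
x = keras_hub.layers.TransformerEncoder(
    intermediate_dim=128, num_heads=8, normalize_first=True
)(x)
model = keras.Model(token_ids, x)
```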
keras_hub/src/layers/modeling/transformer_layer_utils.py
@@ -0,0 +1,106 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from absl import logging
+ from keras import ops
+
+
+ def _check_masks_shapes(inputs, padding_mask, attention_mask):
+     mask = padding_mask
+     if hasattr(inputs, "_keras_mask") and mask is None:
+         mask = inputs._keras_mask
+     if mask is not None:
+         if len(mask.shape) != 2:
+             raise ValueError(
+                 "`padding_mask` should have shape "
+                 "(batch_size, target_length). "
+                 f"Received shape `{mask.shape}`."
+             )
+     if attention_mask is not None:
+         if len(attention_mask.shape) != 3:
+             raise ValueError(
+                 "`attention_mask` should have shape "
+                 "(batch_size, target_length, source_length). "
+                 f"Received shape `{attention_mask.shape}`."
+             )
+
+
+ def compute_causal_mask(batch_size, input_length, output_length, cache_index=0):
+     """Compute a causal attention mask for a transformer decoder.
+
+     Args:
+         batch_size: batch size for the mask.
+         input_length: the length of key/value tensors in the attention layer.
+         output_length: the length of query tensors in the attention layer.
+         cache_index: the current index for cached generation. If passed, the
+             query sequence will be considered to start at `cache_index` rather
+             than zero. For example, a causal mask with `output_length=1` and
+             `cache_index=5` would allow the query tensor to attend to the first
+             six positions of the key/value tensors.
+
+     Returns:
+         A causal attention mask with shape
+         `(batch_size, output_length, input_length)` that can be passed to an
+         attention layer.
+     """
+     i = ops.arange(output_length, dtype="float32")
+     i = i + ops.cast(cache_index, "float32")
+     i = ops.expand_dims(i, axis=1)
+     j = ops.arange(input_length, dtype="float32")
+     mask = ops.expand_dims(i >= j, axis=0)
+
+     return ops.broadcast_to(mask, (batch_size, output_length, input_length))
+
+
+ def merge_padding_and_attention_mask(
+     inputs,
+     padding_mask,
+     attention_mask,
+ ):
+     """Merge the padding mask with a customized attention mask.
+
+     Args:
+         inputs: the input sequence.
+         padding_mask: the 1D padding mask, of shape
+             [batch_size, sequence_length].
+         attention_mask: the 2D customized mask, of shape
+             [batch_size, sequence1_length, sequence2_length].
+
+     Returns:
+         A merged 2D mask or None. If only `padding_mask` is provided, the
+         returned mask is `padding_mask` with one additional axis.
+     """
+     _check_masks_shapes(inputs, padding_mask, attention_mask)
+     mask = padding_mask
+     if hasattr(inputs, "_keras_mask"):
+         if mask is None:
+             # If no padding mask is explicitly provided, we look for the
+             # padding mask from the input data.
+             mask = inputs._keras_mask
+         else:
+             logging.warning(
+                 "You are explicitly setting `padding_mask` while the `inputs` "
+                 "have built-in mask, so the built-in mask is ignored."
+             )
+     if mask is not None:
+         # Add an axis for broadcasting, the attention mask should be 2D
+         # (not including the batch axis).
+         mask = ops.cast(ops.expand_dims(mask, axis=1), "int32")
+     if attention_mask is not None:
+         attention_mask = ops.cast(attention_mask, "int32")
+         if mask is None:
+             return attention_mask
+         else:
+             return ops.minimum(mask, attention_mask)
+     return mask
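As a quick illustration of how these two helpers compose (import paths follow this wheel's private `src` layout; Keras 3 `ops` assumed; the shapes and values are just what the math above produces):

```python
from keras import ops

from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
    merge_padding_and_attention_mask,
)

# Boolean lower-triangular mask: query position i attends to keys j <= i.
causal_mask = compute_causal_mask(batch_size=1, input_length=4, output_length=4)

# Combine with a padding mask; here the last token of the sequence is padding.
inputs = ops.zeros((1, 4, 8))  # dummy [batch, sequence, hidden] activations
padding_mask = ops.convert_to_tensor([[True, True, True, False]])
merged = merge_padding_and_attention_mask(inputs, padding_mask, causal_mask)
# `merged` has shape (1, 4, 4); key position 3 is masked out for every query.
```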
keras_hub/src/layers/preprocessing/__init__.py
@@ -0,0 +1,13 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py
@@ -0,0 +1,220 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.preprocessing_layer import (
+     PreprocessingLayer,
+ )
+ from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+
+ try:
+     import tensorflow as tf
+     import tensorflow_text as tf_text
+ except ImportError:
+     tf = None
+     tf_text = None
+
+
+ @keras_hub_export("keras_hub.layers.MaskedLMMaskGenerator")
+ class MaskedLMMaskGenerator(PreprocessingLayer):
+     """Layer that applies language model masking.
+
+     This layer is useful for preparing inputs for masked language modeling
+     (MaskedLM) tasks. It follows the masking strategy described in the
+     [original BERT paper](https://arxiv.org/abs/1810.04805). Given tokenized
+     text, it randomly selects a certain number of tokens for masking. Then for
+     each selected token, it has a (configurable) chance to be replaced by the
+     mask token, replaced by a random token, or left unchanged.
+
+     Input data should be passed as tensors, `tf.RaggedTensor`s, or lists. For
+     batched input, inputs should be a list of lists or a rank two tensor. For
+     unbatched inputs, each element should be a list or a rank one tensor.
+
+     This layer can be used with `tf.data` to generate dynamic masks on the fly
+     during training.
+
+     Args:
+         vocabulary_size: int, the size of the vocabulary.
+         mask_selection_rate: float, the probability that a token is selected
+             for masking.
+         mask_token_id: int. The id of mask token.
+         mask_selection_length: int. Maximum number of tokens
+             selected for masking in each sequence. If set, the output
+             `mask_positions`, `mask_ids` and `mask_weights` will be padded
+             to dense tensors of length `mask_selection_length`, otherwise
+             the output will be a RaggedTensor. Defaults to `None`.
+         unselectable_token_ids: A list of token ids that should not be
+             considered eligible for masking. By default, we assume `0`
+             corresponds to a padding token and ignore it. Defaults to `[0]`.
+         mask_token_rate: float. `mask_token_rate` must be
+             between 0 and 1 which indicates how often the mask_token is
+             substituted for tokens selected for masking. Defaults to `0.8`.
+         random_token_rate: float. `random_token_rate` must be
+             between 0 and 1 which indicates how often a random token is
+             substituted for tokens selected for masking.
+             Note: mask_token_rate + random_token_rate <= 1, and for
+             (1 - mask_token_rate - random_token_rate), the token will not be
+             changed. Defaults to `0.1`.
+
+     Returns:
+         A Dict with 4 keys:
+             token_ids: Tensor or RaggedTensor, has the same type and shape as
+                 the input. Sequence after getting masked.
+             mask_positions: Tensor, or RaggedTensor if `mask_selection_length`
+                 is None. The positions of token_ids getting masked.
+             mask_ids: Tensor, or RaggedTensor if `mask_selection_length` is
+                 None. The original token ids at masked positions.
+             mask_weights: Tensor, or RaggedTensor if `mask_selection_length` is
+                 None. `mask_weights` has the same shape as `mask_positions` and
+                 `mask_ids`. Each element in `mask_weights` should be 0 or 1,
+                 1 means the corresponding position in `mask_positions` is an
+                 actual mask, 0 means it is a pad.
+
+     Examples:
+
+     Basic usage.
+     ```python
+     masker = keras_hub.layers.MaskedLMMaskGenerator(
+         vocabulary_size=10,
+         mask_selection_rate=0.2,
+         mask_token_id=0,
+         mask_selection_length=5
+     )
+     # Dense input.
+     masker([1, 2, 3, 4, 5])
+
+     # Ragged input.
+     masker([[1, 2], [1, 2, 3, 4]])
+     ```
+
+     Masking a batch that contains special tokens.
+     ```python
+     pad_id, cls_id, sep_id, mask_id = 0, 1, 2, 3
+     batch = [
+         [cls_id, 4, 5, 6, sep_id, 7, 8, sep_id, pad_id, pad_id],
+         [cls_id, 4, 5, sep_id, 6, 7, 8, 9, sep_id, pad_id],
+     ]
+
+     masker = keras_hub.layers.MaskedLMMaskGenerator(
+         vocabulary_size=10,
+         mask_selection_rate=0.2,
+         mask_selection_length=5,
+         mask_token_id=mask_id,
+         unselectable_token_ids=[
+             cls_id,
+             sep_id,
+             pad_id,
+         ]
+     )
+     masker(batch)
+     ```
+     """
+
+     def __init__(
+         self,
+         vocabulary_size,
+         mask_selection_rate,
+         mask_token_id,
+         mask_selection_length=None,
+         unselectable_token_ids=[0],
+         mask_token_rate=0.8,
+         random_token_rate=0.1,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.vocabulary_size = vocabulary_size
+         self.unselectable_token_ids = unselectable_token_ids
+         self.mask_selection_rate = mask_selection_rate
+         self.mask_selection_length = mask_selection_length
+         self.mask_token_rate = mask_token_rate
+         self.random_token_rate = random_token_rate
+
+         if mask_token_id >= vocabulary_size:
+             raise ValueError(
+                 f"Mask token id should be in range [0, vocabulary_size - 1], "
+                 f"but received mask_token_id={mask_token_id}."
+             )
+         self.mask_token_id = mask_token_id
+
+         max_selections = self.mask_selection_length
+         if max_selections is None:
+             # Set a large number to remove the `max_selections_per_batch` cap.
+             max_selections = 2**31 - 1
+         self._random_selector = tf_text.RandomItemSelector(
+             max_selections_per_batch=max_selections,
+             selection_rate=self.mask_selection_rate,
+             unselectable_ids=self.unselectable_token_ids,
+         )
+         self._mask_values_chooser = tf_text.MaskValuesChooser(
+             self.vocabulary_size,
+             self.mask_token_id,
+             mask_token_rate=self.mask_token_rate,
+             random_token_rate=self.random_token_rate,
+         )
+
+     def call(self, inputs):
+         inputs, unbatched, rectangular = convert_to_ragged_batch(inputs)
+
+         (
+             token_ids,
+             mask_positions,
+             mask_ids,
+         ) = tf_text.mask_language_model(
+             inputs,
+             item_selector=self._random_selector,
+             mask_values_chooser=self._mask_values_chooser,
+         )
+
+         if rectangular:
+             # If we converted the input from dense to ragged, convert back.
+             token_ids = token_ids.to_tensor()
+
+         mask_weights = tf.ones_like(mask_positions, self.compute_dtype)
+         # If `mask_selection_length` is set, convert to dense.
+         if self.mask_selection_length:
+             target_shape = tf.cast([-1, self.mask_selection_length], "int64")
+             mask_positions = mask_positions.to_tensor(shape=target_shape)
+             mask_ids = mask_ids.to_tensor(shape=target_shape)
+             mask_weights = mask_weights.to_tensor(shape=target_shape)
+
+         if unbatched:
+             # If inputs is 1D, we format the output to be 1D as well.
+             token_ids = tf.squeeze(token_ids, axis=0)
+             mask_positions = tf.squeeze(mask_positions, axis=0)
+             mask_ids = tf.squeeze(mask_ids, axis=0)
+             mask_weights = tf.squeeze(mask_weights, axis=0)
+
+         return {
+             "token_ids": token_ids,
+             "mask_positions": mask_positions,
+             "mask_ids": mask_ids,
+             "mask_weights": mask_weights,
+         }
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "mask_selection_rate": self.mask_selection_rate,
+                 "mask_selection_length": self.mask_selection_length,
+                 "unselectable_token_ids": self.unselectable_token_ids,
+                 "mask_token_id": self.mask_token_id,
+                 "mask_token_rate": self.mask_token_rate,
+                 "random_token_rate": self.random_token_rate,
+             }
+         )
+         return config
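Finally, a sketch of the `tf.data` usage the docstring above mentions, so that masks are re-sampled dynamically on every epoch (the vocabulary size, mask token id, and random token ids here are made up for illustration):

```python
import tensorflow as tf
import keras_hub

masker = keras_hub.layers.MaskedLMMaskGenerator(
    vocabulary_size=100,
    mask_selection_rate=0.15,
    mask_token_id=99,
    mask_selection_length=8,
)

# Fake pre-tokenized corpus: 16 sequences of 32 token ids in [1, 99).
corpus = tf.random.uniform((16, 32), minval=1, maxval=99, dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices(corpus)
# Running the layer inside the pipeline yields fresh masks each iteration.
ds = ds.batch(4).map(masker, num_parallel_calls=tf.data.AUTOTUNE)
```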