keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,319 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.api_export import keras_hub_export
16
+ from keras_hub.src.layers.preprocessing.preprocessing_layer import (
17
+ PreprocessingLayer,
18
+ )
19
+ from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
20
+
21
+ try:
22
+ import tensorflow as tf
23
+ import tensorflow_text as tf_text
24
+ except ImportError:
25
+ tf = None
26
+ tf_text = None
27
+
28
+
29
@keras_hub_export("keras_hub.layers.MultiSegmentPacker")
class MultiSegmentPacker(PreprocessingLayer):
    """Packs multiple sequences into a single fixed width model input.

    This layer packs multiple input sequences into a single fixed width sequence
    containing start and end delimiters, forming a dense input suitable for a
    classification task for BERT and BERT-like models.

    Takes as input a tuple of token segments. Each tuple element should contain
    the tokens for a segment, passed as tensors, `tf.RaggedTensor`s, or lists.
    For batched input, each element in the tuple of segments should be a list of
    lists or a rank two tensor. For unbatched inputs, each element should be a
    list or rank one tensor.

    The layer will process inputs as follows:
     - Truncate all input segments to fit within `sequence_length` according to
       the `truncate` strategy.
     - Concatenate all input segments, adding `start_value` at the start of
       the entire sequence, `sep_value` after every segment except the last,
       and `end_value` after the last segment.
     - Pad the resulting sequence to `sequence_length` using `pad_value`.
     - Calculate a separate tensor of "segment ids", with integer type and the
       same shape as the packed token output, where each integer is the index
       of the segment the token originated from. The segment id of the
       `start_value` is always 0, and the segment id of each `sep_value` or
       `end_value` is the index of the segment that precedes it.

    Args:
        sequence_length: int. The desired output length.
        start_value: int/str/list/tuple. The id(s) or token(s) that are to be
            placed at the start of each sequence (called "[CLS]" for BERT). The
            dtype must match the dtype of the input tensors to the layer.
        end_value: int/str/list/tuple. The id(s) or token(s) that are to be
            placed at the end of the last input segment (called "[SEP]" for
            BERT). The dtype must match the dtype of the input tensors to the
            layer.
        sep_value: int/str/list/tuple. The id(s) or token(s) that are to be
            placed at the end of every segment, except the last segment (called
            "[SEP]" for BERT). If `None`, `end_value` is used. The dtype must
            match the dtype of the input tensors to the layer.
        pad_value: int/str. The id or token that is to be placed into the unused
            positions after the last segment in the sequence
            (called "[PAD]" for BERT).
        truncate: str. The algorithm to truncate a list of batched segments to
            fit a per-example length limit. The value can be either
            `"round_robin"` or `"waterfall"`:
            - `"round_robin"`: Available space is assigned one token at a
                time in a round-robin fashion to the inputs that still need
                some, until the limit is reached.
            - `"waterfall"`: The allocation of the budget is done using a
                "waterfall" algorithm that allocates quota in a
                left-to-right manner and fills up the buckets until we run
                out of budget. It supports an arbitrary number of segments.

    Call arguments:
        inputs: A single segment or a list/tuple of segments to pack.
        sequence_length: Optional int. Overrides the layer's
            `sequence_length` for this call, both for truncation and for the
            padded output width. A falsy value (`None` or `0`) falls back to
            the layer's `sequence_length`.
        add_start_value: bool. Whether to prepend `start_value`.
        add_end_value: bool. Whether to append `end_value` after the last
            segment.

    Returns:
        A tuple with two elements. The first is the dense, packed token
        sequence. The second is an integer tensor of the same shape, containing
        the segment ids.

    Examples:

    *Pack a single input for classification.*
    >>> seq1 = [1, 2, 3, 4]
    >>> packer = keras_hub.layers.MultiSegmentPacker(
    ...     sequence_length=8, start_value=101, end_value=102
    ... )
    >>> token_ids, segment_ids = packer((seq1,))
    >>> np.array(token_ids)
    array([101, 1, 2, 3, 4, 102, 0, 0], dtype=int32)
    >>> np.array(segment_ids)
    array([0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)

    *Pack multiple inputs for classification.*
    >>> seq1 = [1, 2, 3, 4]
    >>> seq2 = [11, 12, 13, 14]
    >>> packer = keras_hub.layers.MultiSegmentPacker(
    ...     sequence_length=8, start_value=101, end_value=102
    ... )
    >>> token_ids, segment_ids = packer((seq1, seq2))
    >>> np.array(token_ids)
    array([101, 1, 2, 3, 102, 11, 12, 102], dtype=int32)
    >>> np.array(segment_ids)
    array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)

    *Pack multiple inputs for classification with different sep tokens.*
    >>> seq1 = [1, 2, 3, 4]
    >>> seq2 = [11, 12, 13, 14]
    >>> packer = keras_hub.layers.MultiSegmentPacker(
    ...     sequence_length=8,
    ...     start_value=101,
    ...     end_value=102,
    ...     sep_value=[102, 102],
    ... )
    >>> token_ids, segment_ids = packer((seq1, seq2))
    >>> np.array(token_ids)
    array([101, 1, 2, 102, 102, 11, 12, 102], dtype=int32)
    >>> np.array(segment_ids)
    array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)

    Reference:
        [Devlin et al., 2018](https://arxiv.org/abs/1810.04805).
    """

    def __init__(
        self,
        sequence_length,
        start_value,
        end_value,
        sep_value=None,
        pad_value=None,
        truncate="round_robin",
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.sequence_length = sequence_length
        if truncate not in ("round_robin", "waterfall"):
            # An f-string (rather than `%`) keeps this a ValueError even when
            # `truncate` is a tuple, which `%` would try to unpack.
            raise ValueError(
                "Only 'round_robin' and 'waterfall' algorithms are "
                f"supported. Received {truncate}"
            )
        self.truncate = truncate

        # Maintain private copies of start/end values for config purposes, so
        # `get_config()` returns exactly what the caller passed in.
        self._start_value = start_value
        self._sep_value = sep_value
        self._end_value = end_value

        def check_special_value_type(value, value_name):
            """Normalize a special value to a sequence of ids/tokens."""
            if isinstance(value, (int, str)):
                return [value]
            # `None` is passed through untouched; every other non-sequence
            # value is rejected, including falsy ones such as `0.0`.
            if value is not None and not isinstance(value, (list, tuple)):
                raise ValueError(
                    f"{value_name} should be of type int/str/list/tuple. "
                    f"Received type: `{type(value)}`."
                )
            return value

        start_value = check_special_value_type(start_value, "start_value")
        if sep_value is None:
            sep_value = end_value
        sep_value = check_special_value_type(sep_value, "sep_value")
        end_value = check_special_value_type(end_value, "end_value")

        self.start_value = start_value
        self.sep_value = sep_value
        self.end_value = end_value

        self.pad_value = pad_value

    def get_config(self):
        """Return the layer config, using the constructor's original values."""
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "start_value": self._start_value,
                "end_value": self._end_value,
                "sep_value": self._sep_value,
                "pad_value": self.pad_value,
                "truncate": self.truncate,
            }
        )
        return config

    def _sanitize_inputs(self, inputs):
        """Force inputs to a list of rank 2 ragged tensors.

        Returns:
            A tuple `(segments, unbatched)`, where `unbatched` is the flag
            reported by `convert_to_ragged_batch` for the inputs.

        Raises:
            ValueError: If no inputs are given, or if the inputs do not all
                report the same batched/unbatched status.
        """
        if not isinstance(inputs, (list, tuple)):
            inputs = (inputs,)
        if not inputs:
            raise ValueError(
                "At least one input is required for packing. "
                f"Received: `inputs={inputs}`"
            )
        # `convert_to_ragged_batch` returns a 3-tuple per input; transpose so
        # that converted tensors and their flags are grouped together.
        inputs, unbatched_list, _ = list(
            zip(*(convert_to_ragged_batch(x) for x in inputs))
        )
        if len(set(unbatched_list)) != 1:
            ranks = [1 if unbatched else 2 for unbatched in unbatched_list]
            raise ValueError(
                "All inputs for packing must have the same rank. "
                f"Received: `inputs={inputs}` with ranks {ranks}"
            )
        return inputs, unbatched_list[0]

    def _trim_inputs(self, inputs, sequence_length=None):
        """Trim inputs so they fit, with special tokens, in the output width.

        Args:
            inputs: Sequence of ragged segments to trim.
            sequence_length: Optional int. The output width to budget
                against. Defaults to `self.sequence_length`.

        Returns:
            The trimmed segments.
        """
        if sequence_length is None:
            sequence_length = self.sequence_length
        num_segments = len(inputs)
        # The budget always reserves room for the start, sep and end values,
        # regardless of the `add_start_value` / `add_end_value` call flags.
        num_special_tokens = (
            len(self.start_value)
            + (num_segments - 1) * len(self.sep_value)
            + len(self.end_value)
        )
        max_segment_tokens = sequence_length - num_special_tokens
        if self.truncate == "round_robin":
            trimmer = tf_text.RoundRobinTrimmer(max_segment_tokens)
        elif self.truncate == "waterfall":
            trimmer = tf_text.WaterfallTrimmer(max_segment_tokens)
        else:
            raise ValueError(f"Unsupported truncate: {self.truncate}")
        return trimmer.trim(inputs)

    def _combine_inputs(
        self,
        segments,
        add_start_value=True,
        add_end_value=True,
    ):
        """Combine inputs with start, sep and end values added.

        Returns:
            A tuple `(token_ids, segment_ids)` of the concatenated tokens and
            the matching int32 segment indices.
        """
        dtype = segments[0].dtype
        batch_size = segments[0].nrows()

        def tile_over_batch(values):
            # One copy of the special value(s) per batch row.
            row = tf.convert_to_tensor(values, dtype=dtype)
            return tf.repeat(row[tf.newaxis, :], repeats=batch_size, axis=0)

        start_columns = tile_over_batch(self.start_value)
        sep_columns = tile_over_batch(self.sep_value)
        end_columns = tile_over_batch(self.end_value)
        ones_sep_columns = tf.ones_like(sep_columns, dtype="int32")
        ones_end_columns = tf.ones_like(end_columns, dtype="int32")

        segments_to_combine = []
        segment_ids_to_combine = []
        if add_start_value:
            segments_to_combine.append(start_columns)
            # The start value always belongs to segment 0.
            segment_ids_to_combine.append(
                tf.zeros_like(start_columns, dtype="int32")
            )

        last_index = len(segments) - 1
        for i, seg in enumerate(segments):
            segments_to_combine.append(seg)
            segment_ids_to_combine.append(tf.ones_like(seg, dtype="int32") * i)

            # A delimiter takes the id of the segment that precedes it.
            if i < last_index:
                segments_to_combine.append(sep_columns)
                segment_ids_to_combine.append(ones_sep_columns * i)
            elif add_end_value:
                segments_to_combine.append(end_columns)
                segment_ids_to_combine.append(ones_end_columns * i)

        token_ids = tf.concat(segments_to_combine, 1)
        segment_ids = tf.concat(segment_ids_to_combine, 1)
        return token_ids, segment_ids

    def call(
        self,
        inputs,
        sequence_length=None,
        add_start_value=True,
        add_end_value=True,
    ):
        """Pack `inputs` into dense `(token_ids, segment_ids)` tensors."""
        inputs, unbatched = self._sanitize_inputs(inputs)

        # Resolve the effective width first so that truncation and padding
        # agree. Trimming against `self.sequence_length` while padding to a
        # different per-call width could cut off the end value or discard
        # tokens that would have fit.
        sequence_length = sequence_length or self.sequence_length

        segments = self._trim_inputs(inputs, sequence_length)
        token_ids, segment_ids = self._combine_inputs(
            segments,
            add_start_value=add_start_value,
            add_end_value=add_end_value,
        )
        # Pad to dense tensor output.
        shape = tf.cast([-1, sequence_length], "int64")
        token_ids = token_ids.to_tensor(
            shape=shape, default_value=self.pad_value
        )
        segment_ids = segment_ids.to_tensor(shape=shape)
        # Remove the batch dim if added.
        if unbatched:
            token_ids = tf.squeeze(token_ids, 0)
            segment_ids = tf.squeeze(segment_ids, 0)

        return (token_ids, segment_ids)

    def compute_output_shape(self, inputs_shape):
        """Return the input shape with its last dim set to `sequence_length`.

        When the first entry of `inputs_shape` is itself a tuple (one shape
        per segment), the first segment's shape is used.
        """
        # NOTE(review): `call` returns two tensors, but a single shape is
        # returned here; confirm this is what callers expect.
        if isinstance(inputs_shape[0], tuple):
            inputs_shape = inputs_shape[0]
        inputs_shape = list(inputs_shape)
        inputs_shape[-1] = self.sequence_length
        return tuple(inputs_shape)
@@ -0,0 +1,62 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import tree
17
+
18
+ from keras_hub.src.utils.tensor_utils import assert_tf_libs_installed
19
+ from keras_hub.src.utils.tensor_utils import (
20
+ convert_to_backend_tensor_or_python_list,
21
+ )
22
+
23
+ try:
24
+ import tensorflow as tf
25
+ except ImportError:
26
+ tf = None
27
+
28
+
29
class PreprocessingLayer(keras.layers.Layer):
    """Base class shared by the preprocessing layers.

    Runs the wrapped layer call under a CPU device scope and, on non-TensorFlow
    backends outside of a TensorFlow graph, maps every output through
    `convert_to_backend_tensor_or_python_list`.
    """

    def __init__(self, **kwargs):
        # Checked before anything else is set up.
        # NOTE(review): presumably raises when the TensorFlow libraries are
        # missing -- confirm against `assert_tf_libs_installed`.
        assert_tf_libs_installed(type(self).__name__)

        super().__init__(**kwargs)
        # NOTE(review): these look like Keras-internal switches that let raw
        # python inputs reach `call` unconverted -- confirm against Keras.
        self._allow_non_tensor_positional_args = True
        self._convert_input_args = False
        # Most preprocessing layers have nothing to build.
        if not hasattr(self, "build"):
            self.built = True

    def get_build_config(self):
        """Return `None`; no build config is recorded for these layers."""
        return None

    def __call__(self, *args, **kwargs):
        # Keep preprocessing on the CPU so data is not copied back and forth
        # to the GPU ahead of the trainable model.
        with tf.device("cpu"):
            outputs = super().__call__(*args, **kwargs)

        # Jax and Torch have no native string or ragged types. Outputs are
        # only converted when running on one of those backends *and* outside
        # a tf.function (i.e. not inside a tf.data pipeline); otherwise they
        # are returned untouched.
        if keras.config.backend() == "tensorflow":
            return outputs
        if not tf.executing_eagerly():
            return outputs
        return tree.map_structure(
            convert_to_backend_tensor_or_python_list, outputs
        )
@@ -0,0 +1,271 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.layers.preprocessing.preprocessing_layer import (
19
+ PreprocessingLayer,
20
+ )
21
+ from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
22
+ from keras_hub.src.utils.tensor_utils import is_int_dtype
23
+ from keras_hub.src.utils.tensor_utils import is_string_dtype
24
+
25
+ try:
26
+ import tensorflow as tf
27
+ except ImportError:
28
+ tf = None
29
+
30
+
31
+ @keras_hub_export("keras_hub.layers.RandomDeletion")
32
+ class RandomDeletion(PreprocessingLayer):
33
+ """Augments input by randomly deleting tokens.
34
+
35
+ This layer comes in handy when you need to generate new data using deletion
36
+ augmentation as described in the paper [EDA: Easy Data Augmentation
37
+ Techniques for Boosting Performance on Text Classification Tasks]
38
+ (https://arxiv.org/pdf/1901.11196.pdf). The layer expects the inputs to be
39
+ pre-split into token level inputs. This allows control over the level of
40
+ augmentation, you can split by character for character level swaps, or by
41
+ word for word level swaps.
42
+
43
+ Input data should be passed as tensors, `tf.RaggedTensor`s, or lists. For
44
+ batched input, inputs should be a list of lists or a rank two tensor. For
45
+ unbatched inputs, each element should be a list or a rank one tensor.
46
+
47
+ Args:
48
+ rate: The probability of a token being chosen for deletion.
49
+ max_deletions: The maximum number of tokens to delete.
50
+ skip_list: A list of token values that should not be considered
51
+ candidates for deletion.
52
+ skip_fn: A function that takes as input a scalar tensor token and
53
+ returns as output a scalar tensor True/False value. A value of
54
+ True indicates that the token should not be considered a
55
+ candidate for deletion. This function must be tracable--it
56
+ should consist of tensorflow operations.
57
+ skip_py_fn: A function that takes as input a python token value and
58
+ returns as output `True` or `False`. A value of True
59
+ indicates that should not be considered a candidate for deletion.
60
+ Unlike the `skip_fn` argument, this argument need not be
61
+ tracable--it can be any python function.
62
+ seed: A seed for the random number generator.
63
+
64
+ Examples:
65
+
66
+ Word level usage.
67
+ >>> keras.utils.set_random_seed(1337)
68
+ >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
69
+ >>> augmenter=keras_hub.layers.RandomDeletion(rate=0.4, seed=42)
70
+ >>> augmented=augmenter(inputs)
71
+ >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
72
+ <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'],
73
+ dtype=object)>
74
+
75
+ Character level usage.
76
+ >>> keras.utils.set_random_seed(1337)
77
+ >>> inputs=tf.strings.unicode_split(["Hey Dude", "Speed Up"], "UTF-8")
78
+ >>> augmenter=keras_hub.layers.RandomDeletion(rate=0.4, seed=42)
79
+ >>> augmented=augmenter(inputs)
80
+ >>> tf.strings.reduce_join(augmented, axis=-1)
81
+ <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'H Dude', b'pedUp'],
82
+ dtype=object)>
83
+
84
+ Usage with skip_list.
85
+ >>> keras.utils.set_random_seed(1337)
86
+ >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
87
+ >>> augmenter=keras_hub.layers.RandomDeletion(rate=0.4,
88
+ ... skip_list=["Keras", "Tensorflow"], seed=42)
89
+ >>> augmented=augmenter(inputs)
90
+ >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
91
+ <tf.Tensor: shape=(2,), dtype=string,
92
+ numpy=array([b'I like', b'Keras Tensorflow'], dtype=object)>
93
+
94
+ Usage with skip_fn.
95
+ >>> def skip_fn(word):
96
+ ... return tf.strings.regex_full_match(word, r"\\pP")
97
+ >>> keras.utils.set_random_seed(1337)
98
+ >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
99
+ >>> augmenter=keras_hub.layers.RandomDeletion(rate=0.4,
100
+ ... skip_fn=skip_fn, seed=42)
101
+ >>> augmented=augmenter(inputs)
102
+ >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
103
+ <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'I like', b'and'],
104
+ dtype=object)>
105
+
106
+ Usage with skip_py_fn.
107
+ >>> def skip_py_fn(word):
108
+ ... return len(word) < 4
109
+ >>> keras.utils.set_random_seed(1337)
110
+ >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
111
+ >>> augmenter=RandomDeletion(rate=0.4,
112
+ ... skip_py_fn=skip_py_fn, seed=42)
113
+ >>> augmented=augmenter(inputs)
114
+ >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
115
+ <tf.Tensor: shape=(2,), dtype=string,
116
+ numpy=array([b'Hey I', b'and Tensorflow'], dtype=object)>
117
+ """
118
+
119
+ def __init__(
120
+ self,
121
+ rate,
122
+ max_deletions=None,
123
+ skip_list=None,
124
+ skip_fn=None,
125
+ skip_py_fn=None,
126
+ seed=None,
127
+ name=None,
128
+ dtype="int32",
129
+ **kwargs,
130
+ ):
131
+ if not is_int_dtype(dtype) and not is_string_dtype(dtype):
132
+ raise ValueError(
133
+ "Output dtype must be an integer type or a string. "
134
+ f"Received: dtype={dtype}"
135
+ )
136
+
137
+ super().__init__(dtype=dtype, name=name, **kwargs)
138
+
139
+ self.rate = rate
140
+ self.max_deletions = max_deletions
141
+ self.seed = random.randint(1, 1e9) if seed is None else seed
142
+ self._generator = tf.random.Generator.from_seed(self.seed)
143
+ self.skip_list = skip_list
144
+ self.skip_fn = skip_fn
145
+ self.skip_py_fn = skip_py_fn
146
+ if self.max_deletions is not None and self.max_deletions < 0:
147
+ raise ValueError(
148
+ "max_deletions must be non-negative."
149
+ f"Received max_deletions={max_deletions}."
150
+ )
151
+
152
+ if self.rate > 1 or self.rate < 0:
153
+ raise ValueError(
154
+ "Rate must be between 0 and 1 (both inclusive)."
155
+ f"Received: rate={rate}"
156
+ )
157
+
158
+ if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2:
159
+ raise ValueError(
160
+ "Exactly one of `skip_list`, `skip_fn`, `skip_py_fn` must be "
161
+ "provided."
162
+ )
163
+
164
+ if self.skip_list:
165
+ self.StaticHashTable = tf.lookup.StaticHashTable(
166
+ tf.lookup.KeyValueTensorInitializer(
167
+ tf.convert_to_tensor(self.skip_list),
168
+ tf.convert_to_tensor([True] * len(self.skip_list)),
169
+ ),
170
+ default_value=False,
171
+ )
172
+
173
+ def call(self, inputs):
174
+ inputs, unbatched, _ = convert_to_ragged_batch(inputs)
175
+
176
+ skip_masks = None
177
+ if self.skip_list:
178
+ skip_masks = self.StaticHashTable.lookup(inputs.flat_values)
179
+ elif self.skip_fn:
180
+ skip_masks = tf.map_fn(
181
+ self.skip_fn, inputs.flat_values, fn_output_signature="bool"
182
+ )
183
+ elif self.skip_py_fn:
184
+
185
+ def string_fn(token):
186
+ return self.skip_py_fn(token.numpy().decode("utf-8"))
187
+
188
+ def int_fn(token):
189
+ return self.skip_py_fn(token.numpy())
190
+
191
+ py_fn = string_fn if inputs.dtype == tf.string else int_fn
192
+
193
+ skip_masks = tf.map_fn(
194
+ lambda x: tf.py_function(py_fn, [x], "bool"),
195
+ inputs.flat_values,
196
+ fn_output_signature="bool",
197
+ )
198
+
199
+ positions_flat = tf.range(tf.size(inputs.flat_values))
200
+ positions = inputs.with_flat_values(positions_flat)
201
+ if skip_masks is not None:
202
+ skip_masks = tf.logical_not(skip_masks)
203
+ skip_masks.set_shape([None])
204
+ positions = tf.ragged.boolean_mask(
205
+ positions, inputs.with_flat_values(skip_masks)
206
+ )
207
+
208
+ # Figure out how many we are going to select.
209
+ token_counts = tf.cast(positions.row_lengths(), "float32")
210
+ num_to_select = tf.random.stateless_binomial(
211
+ shape=tf.shape(token_counts),
212
+ seed=self._generator.make_seeds()[:, 0],
213
+ counts=token_counts,
214
+ probs=self.rate,
215
+ )
216
+ if self.max_deletions is not None:
217
+ num_to_select = tf.math.minimum(num_to_select, self.max_deletions)
218
+ num_to_select = tf.cast(num_to_select, "int64")
219
+
220
+ # Shuffle and trim to items that are going to be selected.
221
+ def _shuffle_and_trim(x):
222
+ positions, top_n = x
223
+ shuffled = tf.random.shuffle(positions, seed=self.seed)
224
+ return shuffled[:top_n]
225
+
226
+ selected_for_mask = tf.map_fn(
227
+ _shuffle_and_trim,
228
+ (positions, num_to_select),
229
+ fn_output_signature=tf.RaggedTensorSpec(
230
+ ragged_rank=positions.ragged_rank - 1, dtype=positions.dtype
231
+ ),
232
+ )
233
+ selected_for_mask.flat_values.set_shape([None])
234
+
235
+ # Construct the mask which is a boolean RT
236
+ # Scatter 0's to positions that have been selector for deletion.
237
+ update_values = tf.zeros_like(selected_for_mask.flat_values, "int32")
238
+ update_indices = selected_for_mask.flat_values
239
+ update_indices = tf.expand_dims(update_indices, -1)
240
+ update_indices = tf.cast(update_indices, "int32")
241
+ mask_flat = tf.ones_like(inputs.flat_values, dtype="int32")
242
+ mask_flat = tf.tensor_scatter_nd_update(
243
+ mask_flat, update_indices, update_values
244
+ )
245
+ mask = tf.cast(inputs.with_flat_values(mask_flat), "bool")
246
+
247
+ inputs = tf.ragged.boolean_mask(inputs, mask)
248
+
249
+ if unbatched:
250
+ inputs = tf.squeeze(inputs, axis=0)
251
+
252
+ return inputs
253
+
254
+ def get_config(self):
255
+ config = super().get_config()
256
+ config.update(
257
+ {
258
+ "rate": self.rate,
259
+ "max_deletions": self.max_deletions,
260
+ "seed": self.seed,
261
+ "skip_list": self.skip_list,
262
+ "skip_fn": self.skip_fn,
263
+ "skip_py_fn": self.skip_py_fn,
264
+ }
265
+ )
266
+ return config
267
+
268
+ def compute_output_shape(self, inputs_shape):
269
+ inputs_shape = list(inputs_shape)
270
+ inputs_shape[-1] = None
271
+ return tuple(inputs_shape)