keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297) hide show
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,267 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.layers.preprocessing.preprocessing_layer import (
19
+ PreprocessingLayer,
20
+ )
21
+ from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
22
+ from keras_hub.src.utils.tensor_utils import is_int_dtype
23
+ from keras_hub.src.utils.tensor_utils import is_string_dtype
24
+
25
+ try:
26
+ import tensorflow as tf
27
+ except ImportError:
28
+ tf = None
29
+
30
+
31
@keras_hub_export("keras_hub.layers.RandomSwap")
class RandomSwap(PreprocessingLayer):
    """Augments input by randomly swapping words.

    This layer comes in handy when you need to generate new data using swap
    augmentations as described in the paper [EDA: Easy Data Augmentation
    Techniques for Boosting Performance on Text Classification Tasks]
    (https://arxiv.org/pdf/1901.11196.pdf). The layer expects the inputs to be
    pre-split into token level inputs. This allows control over the level of
    augmentation, you can split by character for character level swaps, or by
    word for word level swaps.

    Input data should be passed as tensors, `tf.RaggedTensor`s, or lists. For
    batched input, inputs should be a list of lists or a rank two tensor. For
    unbatched inputs, each element should be a list or a rank one tensor.

    At most one of `skip_list`, `skip_fn` and `skip_py_fn` may be passed; when
    none is passed, every token is a candidate for swapping.

    Args:
        rate: The probability of a given token being chosen to be swapped
            with another random token.
        max_swaps: The maximum number of swaps to be performed.
        skip_list: A list of token values that should not be considered
            candidates for swapping.
        skip_fn: A function that takes as input a scalar tensor token and
            returns as output a scalar tensor True/False value. A value of
            True indicates that the token should not be considered a
            candidate for swapping. This function must be traceable--it
            should consist of tensorflow operations.
        skip_py_fn: A function that takes as input a python token value and
            returns as output `True` or `False`. A value of True
            indicates that the token should not be considered a candidate
            for swapping. Unlike the `skip_fn` argument, this argument need
            not be traceable--it can be any python function.
        seed: A seed for the random number generator.

    Raises:
        ValueError: If `dtype` is neither an integer nor a string dtype, if
            `max_swaps` is negative, or if more than one of `skip_list`,
            `skip_fn` and `skip_py_fn` is provided.

    Examples:

    Word level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.4, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'like I Hey', b'and Keras Tensorflow'], dtype=object)>

    Character level usage.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.unicode_split(["Hey Dude", "Speed Up"], "UTF-8")
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.4, seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'deD yuHe', b'SUede pp'], dtype=object)>

    Usage with skip_list.
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.4,
    ...     skip_list=["Keras"], seed=42)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'like I Hey', b'Keras and Tensorflow'], dtype=object)>

    Usage with skip_fn.
    >>> def skip_fn(word):
    ...     return tf.strings.regex_full_match(word, r"[I, a].*")
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["Hey I like", "Keras and Tensorflow"])
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.9, max_swaps=3,
    ...     skip_fn=skip_fn, seed=11)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string,
    numpy=array([b'like I Hey', b'Keras and Tensorflow'], dtype=object)>

    Usage with skip_py_fn.
    >>> def skip_py_fn(word):
    ...     return len(word) < 4
    >>> keras.utils.set_random_seed(1337)
    >>> inputs=tf.strings.split(["He was drifting along", "With the wind"])
    >>> augmenter=keras_hub.layers.RandomSwap(rate=0.8, max_swaps=2,
    ...     skip_py_fn=skip_py_fn, seed=15)
    >>> augmented=augmenter(inputs)
    >>> tf.strings.reduce_join(augmented, separator=" ", axis=-1)
    <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'He was along drifting',
    b'wind the With'], dtype=object)>
    """

    def __init__(
        self,
        rate,
        max_swaps=None,
        skip_list=None,
        skip_fn=None,
        skip_py_fn=None,
        seed=None,
        name=None,
        dtype="int32",
        **kwargs,
    ):
        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type or a string. "
                f"Received: dtype={dtype}"
            )

        super().__init__(name=name, dtype=dtype, **kwargs)

        self.rate = rate
        self.max_swaps = max_swaps
        # `random.randint` requires integer bounds; a float bound such as
        # `1e9` raises `TypeError` on Python 3.12+.
        self.seed = random.randint(1, int(1e9)) if seed is None else seed
        self._generator = tf.random.Generator.from_seed(self.seed)
        self.skip_list = skip_list
        self.skip_fn = skip_fn
        self.skip_py_fn = skip_py_fn
        if self.max_swaps is not None and self.max_swaps < 0:
            raise ValueError(
                "max_swaps must be non-negative. "
                f"Received max_swaps={max_swaps}."
            )

        # Fewer than two `None`s means two or more skip options were given.
        if [self.skip_list, self.skip_fn, self.skip_py_fn].count(None) < 2:
            raise ValueError(
                "At most one of skip_list, skip_fn, skip_py_fn may be "
                "provided."
            )

        if self.skip_list:
            # Lookup table mapping every skip-listed token to `True`; any
            # other token resolves to the default `False`.
            self.StaticHashTable = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(
                    tf.convert_to_tensor(self.skip_list),
                    tf.convert_to_tensor([True] * len(self.skip_list)),
                ),
                default_value=False,
            )

    def call(self, inputs):
        """Randomly swaps tokens within each row of `inputs`.

        Args:
            inputs: Token data accepted by `convert_to_ragged_batch`.

        Returns:
            The swapped tokens, with the leading batch axis squeezed out
            again when the input was unbatched.
        """
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)

        # Per-token mask that is True for tokens that must not be swapped.
        skip_masks = None
        if self.skip_list:
            skip_masks = self.StaticHashTable.lookup(inputs.flat_values)
        elif self.skip_fn:
            skip_masks = tf.map_fn(
                self.skip_fn, inputs.flat_values, fn_output_signature="bool"
            )
        elif self.skip_py_fn:

            def string_fn(token):
                return self.skip_py_fn(token.numpy().decode("utf-8"))

            def int_fn(token):
                return self.skip_py_fn(token.numpy())

            py_fn = string_fn if inputs.dtype == tf.string else int_fn

            skip_masks = tf.map_fn(
                lambda x: tf.py_function(py_fn, [x], "bool"),
                inputs.flat_values,
                fn_output_signature="bool",
            )

        # Candidate positions (indices within each row) for swapping.
        positions = tf.ragged.range(inputs.row_lengths())

        if skip_masks is not None:
            # Invert the mask so it marks tokens to keep as candidates.
            skip_masks = tf.logical_not(skip_masks)
            skip_masks.set_shape([None])
            positions = tf.ragged.boolean_mask(
                positions, inputs.with_flat_values(skip_masks)
            )
        # Figure out how many we are going to select.
        token_counts = tf.cast(positions.row_lengths(), "float32")
        num_to_select = tf.random.stateless_binomial(
            shape=tf.shape(token_counts),
            seed=self._generator.make_seeds()[:, 0],
            counts=token_counts,
            probs=self.rate,
        )
        if self.max_swaps is not None:
            num_to_select = tf.math.minimum(num_to_select, self.max_swaps)
        # Never attempt more swaps than there are candidate positions.
        num_to_select = tf.math.minimum(
            num_to_select, tf.cast(positions.row_lengths(), "int32")
        )
        num_to_select = tf.cast(num_to_select, "int64")

        def _swap(x):
            # Performs `num_to_select` swaps on a single row.
            positions, inputs, num_to_select = x
            for _ in range(num_to_select):
                # Sample two indices into the candidate positions.
                index = tf.random.stateless_uniform(
                    shape=[2],
                    minval=0,
                    maxval=tf.size(positions),
                    dtype="int32",
                    seed=self._generator.make_seeds()[:, 0],
                )
                index1, index2 = positions[index[0]], positions[index[1]]
                # swap items at the sampled indices with each other
                inputs = tf.tensor_scatter_nd_update(
                    inputs,
                    [[index1], [index2]],
                    [inputs[index2], inputs[index1]],
                )
            return inputs

        swapped = tf.map_fn(
            _swap,
            (positions, inputs, num_to_select),
            fn_output_signature=tf.RaggedTensorSpec(
                ragged_rank=positions.ragged_rank - 1, dtype=inputs.dtype
            ),
        )
        swapped.flat_values.set_shape([None])

        if unbatched:
            swapped = tf.squeeze(swapped, axis=0)
        return swapped

    def get_config(self):
        """Returns the layer config, extended with this layer's arguments."""
        config = super().get_config()
        config.update(
            {
                "rate": self.rate,
                "max_swaps": self.max_swaps,
                "seed": self.seed,
                "skip_list": self.skip_list,
                "skip_fn": self.skip_fn,
                "skip_py_fn": self.skip_py_fn,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        """Returns `inputs_shape` with its last dimension set to `None`."""
        inputs_shape = list(inputs_shape)
        inputs_shape[-1] = None
        return tuple(inputs_shape)
@@ -0,0 +1,219 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.layers.preprocessing.preprocessing_layer import (
18
+ PreprocessingLayer,
19
+ )
20
+ from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
21
+
22
+ try:
23
+ import tensorflow as tf
24
+ except ImportError:
25
+ tf = None
26
+
27
+
28
@keras_hub_export("keras_hub.layers.StartEndPacker")
class StartEndPacker(PreprocessingLayer):
    """Adds start and end tokens to a sequence and pads to a fixed length.

    This layer is useful when tokenizing inputs for tasks like translation,
    where each sequence should include a start and end marker. It should
    be called after tokenization. The layer will first trim inputs to fit, then
    add start/end tokens, and finally pad, if necessary, to `sequence_length`.

    Input data should be passed as tensors, `tf.RaggedTensor`s, or lists. For
    batched input, inputs should be a list of lists or a rank two tensor. For
    unbatched inputs, each element should be a list or a rank one tensor.

    Args:
        sequence_length: int. The desired output length.
        start_value: int/str/list/tuple. The ID(s) or token(s) that are to be
            placed at the start of each sequence. The dtype must match the dtype
            of the input tensors to the layer. If `None`, no start value will be
            added.
        end_value: int/str/list/tuple. The ID(s) or token(s) that are to be
            placed at the end of each input segment. The dtype must match the
            dtype of the input tensors to the layer. If `None`, no end value
            will be added.
        pad_value: int/str. The ID or token that is to be placed into the
            unused positions after the last segment in the sequence. If `None`,
            0 or "" will be added depending on the dtype of the input tensor.
        return_padding_mask: bool. Whether to return a boolean padding mask of
            all locations that are filled in with the `pad_value`.

    Raises:
        ValueError: If `start_value` or `end_value` is not `None` and is not
            an int, str, list or tuple.

    Call arguments:
        inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings.
        sequence_length: Pass to override the configured `sequence_length` of
            the layer.
        add_start_value: Pass `False` to not prepend a start value for this
            input.
        add_end_value: Pass `False` to not append an end value for this
            input.

    Examples:

    Unbatched input (int).
    >>> inputs = [5, 6, 7]
    >>> start_end_packer = keras_hub.layers.StartEndPacker(
    ...     sequence_length=7, start_value=1, end_value=2,
    ... )
    >>> outputs = start_end_packer(inputs)
    >>> np.array(outputs)
    array([1, 5, 6, 7, 2, 0, 0], dtype=int32)

    Batched input (int).
    >>> inputs = [[5, 6, 7], [8, 9, 10, 11, 12, 13, 14]]
    >>> start_end_packer = keras_hub.layers.StartEndPacker(
    ...     sequence_length=6, start_value=1, end_value=2,
    ... )
    >>> outputs = start_end_packer(inputs)
    >>> np.array(outputs)
    array([[ 1,  5,  6,  7,  2,  0],
           [ 1,  8,  9, 10, 11,  2]], dtype=int32)

    Unbatched input (str).
    >>> inputs = tf.constant(["this", "is", "fun"])
    >>> start_end_packer = keras_hub.layers.StartEndPacker(
    ...     sequence_length=6, start_value="<s>", end_value="</s>",
    ...     pad_value="<pad>"
    ... )
    >>> outputs = start_end_packer(inputs)
    >>> np.array(outputs).astype("U")
    array(['<s>', 'this', 'is', 'fun', '</s>', '<pad>'], dtype='<U5')

    Batched input (str).
    >>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
    >>> start_end_packer = keras_hub.layers.StartEndPacker(
    ...     sequence_length=6, start_value="<s>", end_value="</s>",
    ...     pad_value="<pad>"
    ... )
    >>> outputs = start_end_packer(inputs)
    >>> np.array(outputs).astype("U")
    array([['<s>', 'this', 'is', 'fun', '</s>', '<pad>'],
           ['<s>', 'awesome', '</s>', '<pad>', '<pad>', '<pad>']], dtype='<U7')

    Multiple start tokens.
    >>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
    >>> start_end_packer = keras_hub.layers.StartEndPacker(
    ...     sequence_length=6, start_value=["</s>", "<s>"], end_value="</s>",
    ...     pad_value="<pad>"
    ... )
    >>> outputs = start_end_packer(inputs)
    >>> np.array(outputs).astype("U")
    array([['</s>', '<s>', 'this', 'is', 'fun', '</s>'],
           ['</s>', '<s>', 'awesome', '</s>', '<pad>', '<pad>']], dtype='<U7')
    """

    def __init__(
        self,
        sequence_length,
        start_value=None,
        end_value=None,
        pad_value=None,
        return_padding_mask=False,
        name=None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

        self.sequence_length = sequence_length

        # Maintain private copies for config purposes: `get_config` must
        # report the values exactly as the caller passed them, not the
        # list-normalized versions stored below.
        self._start_value = start_value
        self._end_value = end_value

        def check_special_value_type(value, value_name):
            """Normalize a special value to a sequence, validating its type."""
            if isinstance(value, (int, str)):
                # A single token is wrapped so `call` can always treat the
                # special values as rank-1.
                return [value]
            # Compare against `None` rather than relying on truthiness, so
            # that falsy values of an unsupported type (e.g. `0.0`) are
            # rejected here instead of failing later inside `call`.
            if value is not None and not isinstance(value, (list, tuple)):
                raise ValueError(
                    f"{value_name} should be of type int/str/list/tuple. "
                    f"Received type: `{type(value)}`."
                )
            return value

        start_value = check_special_value_type(start_value, "start_value")
        end_value = check_special_value_type(end_value, "end_value")

        self.start_value = start_value
        self.end_value = end_value

        self.pad_value = pad_value
        self.return_padding_mask = return_padding_mask

    def call(
        self,
        inputs,
        sequence_length=None,
        add_start_value=True,
        add_end_value=True,
    ):
        """Pack `inputs` with start/end values and pad to a fixed length.

        Returns the packed dense tensor, or a `(outputs, mask)` tuple when
        `return_padding_mask` is set on the layer.
        """
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)

        x = inputs  # Intermediate result.

        batch_size = tf.shape(x)[0]
        sequence_length = sequence_length or self.sequence_length
        dtype = inputs.dtype

        # Concatenate start and end tokens.
        if add_start_value and self.start_value is not None:
            start_value = tf.convert_to_tensor(self.start_value, dtype=dtype)
            # Tile the start value(s) once per batch row.
            start_token_id_tensor = tf.repeat(
                start_value[tf.newaxis, :], repeats=batch_size, axis=0
            )
            x = tf.concat([start_token_id_tensor, x], axis=-1)
        if add_end_value and self.end_value is not None:
            end_value = tf.convert_to_tensor(self.end_value, dtype=dtype)
            end_token_id_tensor = tf.repeat(
                end_value[tf.newaxis, :], repeats=batch_size, axis=0
            )
            # Trim to leave room for end token.
            x = x[..., : sequence_length - len(self.end_value)]
            x = tf.concat([x, end_token_id_tensor], axis=-1)

        # Pad to desired length.
        outputs = x.to_tensor(
            default_value=self.pad_value,
            shape=(batch_size, sequence_length),
        )
        outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs

        if self.return_padding_mask:
            # Densify an all-`True` ragged tensor shaped like `x` to the
            # same output shape to build the mask.
            mask = tf.ones_like(x, dtype="bool")
            mask = mask.to_tensor(shape=(batch_size, sequence_length))
            mask = tf.squeeze(mask, axis=0) if unbatched else mask
            return outputs, mask

        return outputs

    def get_config(self):
        """Return the layer config, using the caller-provided special values."""
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "start_value": self._start_value,
                "end_value": self._end_value,
                "pad_value": self.pad_value,
                "return_padding_mask": self.return_padding_mask,
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        """Return `inputs_shape` with the last axis set to `sequence_length`."""
        inputs_shape = list(inputs_shape)
        inputs_shape[-1] = self.sequence_length
        return tuple(inputs_shape)
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.