keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
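The keras_hub/api/* modules listed above define the package's public namespaces. As a rough orientation only (a sketch inferred from the file layout, not taken from the diff itself), the layout implies imports of the following form; the exact symbols exported by this nightly build may differ:

# Orientation sketch based on the module layout above. Only
# keras_hub.tokenizers.WordPieceTokenizer is confirmed by the diff below
# (via its @keras_hub_export decorator); the other namespaces are inferred
# from the keras_hub/api/* files in the listing.
import keras_hub

keras_hub.layers      # modeling and preprocessing layers
keras_hub.metrics     # BLEU, ROUGE, perplexity, edit distance
keras_hub.models      # backbones, classifiers, causal LMs, preprocessors
keras_hub.samplers    # greedy, beam, contrastive, top-k, top-p, random
keras_hub.tokenizers  # e.g. keras_hub.tokenizers.WordPieceTokenizer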
keras_hub/src/tokenizers/word_piece_tokenizer.py (new file, +544 lines):
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from typing import Iterable

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.tokenizers import tokenizer
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
from keras_hub.src.utils.tensor_utils import is_int_dtype
from keras_hub.src.utils.tensor_utils import is_string_dtype

try:
    import tensorflow as tf
    import tensorflow_text as tf_text
except ImportError:
    tf = None
    tf_text = None

VOCAB_FILENAME = "vocabulary.txt"

# Matches whitespace and control characters.
WHITESPACE_REGEX = r"|".join(
    [
        r"\s",
        # Invisible control characters
        r"\p{Cc}",
        r"\p{Cf}",
    ]
)

# Matches punctuation compatible with the original bert implementation.
PUNCTUATION_REGEX = r"|".join(
    [
        # Treat all non-letter/number ASCII as punctuation.
        # Characters such as "^", "$", and "`" are not in the Unicode
        # Punctuation class but we treat them as punctuation anyways.
        r"[!-/]",
        r"[:-@]",
        r"[\[-`]",
        r"[{-~]",
        # Unicode punctuation class.
        r"[\p{P}]",
    ]
)

# Matches CJK characters. Obtained from
# https://github.com/google-research/bert/blob/master/tokenization.py#L251.
CJK_REGEX = r"|".join(
    [
        r"[\x{4E00}-\x{9FFF}]",
        r"[\x{3400}-\x{4DBF}]",
        r"[\x{20000}-\x{2A6DF}]",
        r"[\x{2A700}-\x{2B73F}]",
        r"[\x{2B740}-\x{2B81F}]",
        r"[\x{2B820}-\x{2CEAF}]",
        r"[\x{F900}-\x{FAFF}]",
        r"[\x{2F800}-\x{2FA1F}]",
    ]
)

# Matches both whitespace and punctuation.
WHITESPACE_AND_PUNCTUATION_REGEX = r"|".join(
    [
        WHITESPACE_REGEX,
        PUNCTUATION_REGEX,
    ]
)

# Matches punctuation and CJK characters.
PUNCTUATION_AND_CJK_REGEX = r"|".join(
    [
        PUNCTUATION_REGEX,
        CJK_REGEX,
    ]
)

# Matches whitespace, punctuation, and CJK characters.
WHITESPACE_PUNCTUATION_AND_CJK_REGEX = r"|".join(
    [
        WHITESPACE_AND_PUNCTUATION_REGEX,
        CJK_REGEX,
    ]
)

def get_special_tokens_pattern(special_tokens):
    if special_tokens is None or len(special_tokens) == 0:
        return None
    return r"|".join([re.escape(token) for token in special_tokens])


def pretokenize(
    text,
    lowercase=False,
    strip_accents=True,
    split=True,
    split_on_cjk=True,
    special_tokens_pattern=None,
):
    """Helper function that takes in a dataset element and pretokenizes it.

    Args:
        text: `tf.Tensor` or `tf.RaggedTensor`. Input to be pretokenized.
        lowercase: bool. If `True`, the input text will be
            lowercased before tokenization. Defaults to `False`.
        strip_accents: bool. If `True`, all accent marks will
            be removed from text before tokenization. Defaults to `True`.
        split: bool. If `True`, input will be split on
            whitespace and punctuation marks, and all punctuation marks will be
            kept as tokens. If `False`, input should be split ("pre-tokenized")
            before calling the tokenizer, and passed as a dense or ragged tensor
            of whole words. Defaults to `True`.
        split_on_cjk: bool. If `True`, input will be split
            on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese
            characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
            Note that this is applicable only when `split` is `True`. Defaults
            to `True`.
        special_tokens_pattern: str. A regex pattern that contains the
            special tokens that will never be split during the word-level
            splitting applied before the word-piece encoding. This can be used
            to ensure special tokens map to unique indices in the vocabulary,
            even if these special tokens contain splittable characters such as
            punctuation.

    Returns:
        A tensor containing the pre-processed and pre-tokenized `text`.
    """
    # Check for correct types.
    if not is_string_dtype(text.dtype):
        raise ValueError(
            "The dataset elements in `data` must have string dtype. "
            f"Received: {text.dtype}."
        )
    # Preprocess, lowercase, strip and split input data.
    if text.shape.rank == 0:
        text = tf.expand_dims(text, 0)
    if split_on_cjk and split:
        text = tf.strings.regex_replace(text, CJK_REGEX, r" \0 ")
    if strip_accents:
        # Normalize unicode to NFD, which splits out accent mark characters.
        text = tf_text.normalize_utf8(text, "NFD")
        # Remove the accent marks.
        text = tf.strings.regex_replace(text, r"\p{Mn}", "")
    if split:
        if split_on_cjk:
            split_pattern = WHITESPACE_PUNCTUATION_AND_CJK_REGEX
            keep_split_pattern = PUNCTUATION_AND_CJK_REGEX
        else:
            split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
            keep_split_pattern = PUNCTUATION_REGEX
        if special_tokens_pattern is not None:
            # The idea here is to pass the special tokens regex to the split
            # function as the delimiter regex pattern, so the input will be
            # split by them, while treating each of them as a single entity
            # that should not be split even if it contains other delimiter
            # patterns inside it. The special tokens regex is also passed as
            # the keep-delimiter regex pattern, so the special tokens
            # themselves are not removed.
            split_pattern = r"|".join(
                [
                    special_tokens_pattern,
                    split_pattern,
                ]
            )
            keep_split_pattern = r"|".join(
                [special_tokens_pattern, keep_split_pattern]
            )
        text = tf_text.regex_split(
            text,
            delim_regex_pattern=split_pattern,
            keep_delim_regex_pattern=keep_split_pattern,
        )
    if lowercase:
        if special_tokens_pattern is not None:
            # Do not lowercase special tokens in string space. They often
            # contain capital letters, e.g. `"[CLS]"`.
            mask = (
                tf.strings.regex_replace(text, special_tokens_pattern, "६")
                == "६"
            )
            text = tf.where(mask, text, tf_text.case_fold_utf8(text))
        else:
            text = tf_text.case_fold_utf8(text)

    return text


@keras_hub_export("keras_hub.tokenizers.WordPieceTokenizer")
class WordPieceTokenizer(tokenizer.Tokenizer):
    """A WordPiece tokenizer layer.

    This layer provides an efficient, in graph, implementation of the WordPiece
    algorithm used by BERT and other models.

    To make this layer more useful out of the box, the layer will pre-tokenize
    the input, which will optionally lower-case, strip accents, and split the
    input on whitespace and punctuation. Each of these pre-tokenization steps is
    not reversible. The `detokenize` method will join words with a space, and
    will not invert `tokenize` exactly.

    If a more custom pre-tokenization step is desired, the layer can be
    configured to apply only the strict WordPiece algorithm by passing
    `lowercase=False`, `strip_accents=False` and `split=False`. In
    this case, inputs should be pre-split string tensors or ragged tensors.

    Tokenizer outputs can either be padded and truncated with a
    `sequence_length` argument, or left un-truncated. The exact output will
    depend on the rank of the input tensors.

    If input is a batch of strings (rank > 0):
    By default, the layer will output a `tf.RaggedTensor` where the last
    dimension of the output is ragged. If `sequence_length` is set, the layer
    will output a dense `tf.Tensor` where all inputs have been padded or
    truncated to `sequence_length`.

    If input is a scalar string (rank == 0):
    By default, the layer will output a dense `tf.Tensor` with static shape
    `[None]`. If `sequence_length` is set, the output will be
    a dense `tf.Tensor` of shape `[sequence_length]`.

    The output dtype can be controlled via the `dtype` argument, which should
    be either an integer or string type.

    Args:
        vocabulary: A list of strings or a string filename path. If
            passing a list, each element of the list should be a single
            WordPiece token string. If passing a filename, the file should be a
            plain text file containing a single WordPiece token per line.
        sequence_length: int. If set, the output will be converted to a dense
            tensor and padded/trimmed so all outputs are of sequence_length.
        lowercase: bool. If `True`, the input text will be
            lowercased before tokenization. Defaults to `False`.
        strip_accents: bool. If `True`, all accent marks will
            be removed from text before tokenization. Defaults to `False`.
        split: bool. If `True`, input will be split on
            whitespace and punctuation marks, and all punctuation marks will be
            kept as tokens. If `False`, input should be split ("pre-tokenized")
            before calling the tokenizer, and passed as a dense or ragged tensor
            of whole words. Defaults to `True`.
        split_on_cjk: bool. If True, input will be split
            on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese
            characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
            Note that this is applicable only when `split` is True.
            Defaults to `True`.
        suffix_indicator: str. The characters prepended to a
            WordPiece to indicate that it is a suffix to another subword.
            E.g. "##ing". Defaults to `"##"`.
        oov_token: str. The string value to substitute for
            an unknown token. It must be included in the vocab.
            Defaults to `"[UNK]"`.
        special_tokens: list. A list of special tokens. When
            `special_tokens_in_strings` is set to `True`, the tokenizer will map
            every special token in the input strings to its id, even if these
            special tokens contain characters that would normally be split
            before tokenization, such as punctuation. `special_tokens` must be
            included in `vocabulary`.
        special_tokens_in_strings: bool. A bool to indicate if the tokenizer
            should expect special tokens in input strings that should be
            tokenized and mapped correctly to their ids. Defaults to `False`.

    References:
     - [Schuster and Nakajima, 2012](https://research.google/pubs/pub37842/)
     - [Song et al., 2020](https://arxiv.org/abs/2012.15524)

    Examples:

    Ragged outputs.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = "The quick brown fox."
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     lowercase=True,
    ... )
    >>> outputs = tokenizer(inputs)
    >>> np.array(outputs)
    array([1, 2, 3, 4, 5, 6, 7], dtype=int32)

    Dense outputs.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = ["The quick brown fox."]
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     sequence_length=10,
    ...     lowercase=True,
    ... )
    >>> outputs = tokenizer(inputs)
    >>> np.array(outputs)
    array([[1, 2, 3, 4, 5, 6, 7, 0, 0, 0]], dtype=int32)

    String output.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = "The quick brown fox."
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     lowercase=True,
    ...     dtype="string",
    ... )
    >>> outputs = tokenizer(inputs)
    >>> np.array(outputs).astype("U")
    array(['the', 'qu', '##ick', 'br', '##own', 'fox', '.'], dtype='<U5')

    Detokenization.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = "The quick brown fox."
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     lowercase=True,
    ... )
    >>> outputs = tokenizer.detokenize(tokenizer.tokenize(inputs))
    >>> np.array(outputs).astype("U")
    array('the quick brown fox .', dtype='<U21')

    Custom splitting.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = "The$quick$brown$fox"
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     split=False,
    ...     lowercase=True,
    ...     dtype='string',
    ... )
    >>> split_inputs = tf.strings.split(inputs, sep="$")
    >>> outputs = tokenizer(split_inputs)
    >>> np.array(outputs).astype("U")
    array(['the', 'qu', '##ick', 'br', '##own', 'fox'], dtype='<U5')
    """

    def __init__(
        self,
        vocabulary=None,
        sequence_length=None,
        lowercase=False,
        strip_accents=False,
        split=True,
        split_on_cjk=True,
        suffix_indicator="##",
        oov_token="[UNK]",
        special_tokens=None,
        special_tokens_in_strings=False,
        dtype="int32",
        **kwargs,
    ) -> None:
        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type or a string. "
                f"Received: dtype={dtype}"
            )

        super().__init__(dtype=dtype, **kwargs)
        if oov_token is None:
            raise ValueError("`oov_token` cannot be None.")

        self.sequence_length = sequence_length
        self.lowercase = lowercase
        self.strip_accents = strip_accents
        self.split = split
        self.split_on_cjk = split_on_cjk
        self.suffix_indicator = suffix_indicator
        self.oov_token = oov_token
        self.special_tokens = special_tokens
        self._special_tokens_pattern = None
        if self.split and special_tokens_in_strings:
            # The idea here is to pass the special tokens regex to the
            # split function as the delimiter regex pattern, so the input
            # will be split by them, while treating each of them as a
            # single entity that should not be split even if it contains
            # other delimiter patterns inside it. The special tokens regex
            # is also passed as the keep-delimiter regex pattern, so the
            # special tokens themselves are not removed.
            self._special_tokens_pattern = get_special_tokens_pattern(
                self.special_tokens
            )
        self.set_vocabulary(vocabulary)
        self.file_assets = [VOCAB_FILENAME]

    def save_assets(self, dir_path):
        path = os.path.join(dir_path, VOCAB_FILENAME)
        with open(path, "w", encoding="utf-8") as file:
            for token in self.vocabulary:
                file.write(f"{token}\n")

    def load_assets(self, dir_path):
        path = os.path.join(dir_path, VOCAB_FILENAME)
        self.set_vocabulary(path)

    def set_vocabulary(self, vocabulary):
        """Set the tokenizer vocabulary to a file or list of strings."""
        if vocabulary is None:
            self.vocabulary = None
            self._fast_word_piece = None
            return

        if isinstance(vocabulary, str):
            with open(vocabulary, "r", encoding="utf-8") as file:
                self.vocabulary = [line.rstrip() for line in file]
        elif isinstance(vocabulary, Iterable):
            # Make a defensive copy.
            self.vocabulary = list(vocabulary)
        else:
            raise ValueError(
                "Vocabulary must be a file path or a list of terms. "
                f"Received: vocabulary={vocabulary}"
            )

        if self.oov_token not in self.vocabulary:
            raise ValueError(
                f'Cannot find `oov_token="{self.oov_token}"` in the '
                "vocabulary.\n"
                "You can either update the vocabulary to include "
                f'`"{self.oov_token}"`, or pass a different value for '
                "the `oov_token` argument when creating the tokenizer."
            )

        # Check for special tokens in the vocabulary
        if self.special_tokens is not None:
            for token in self.special_tokens:
                if token not in self.vocabulary:
                    raise ValueError(
                        f"Cannot find token `'{token}'` in the provided "
                        f"`vocabulary`. Please provide `'{token}'` in your "
                        "`vocabulary` or use a pretrained `vocabulary` name."
                    )

        self._fast_word_piece = tf_text.FastWordpieceTokenizer(
            vocab=self.vocabulary,
            token_out_type=self.compute_dtype,
            suffix_indicator=self.suffix_indicator,
            unknown_token=self.oov_token,
            no_pretokenization=True,
            support_detokenization=True,
        )

    def get_vocabulary(self):
        """Get the tokenizer vocabulary as a list of string tokens."""
        self._check_vocabulary()
        return self.vocabulary

    def vocabulary_size(self):
        """Get the integer size of the tokenizer vocabulary."""
        self._check_vocabulary()
        return len(self.vocabulary)

    def id_to_token(self, id):
        """Convert an integer id to a string token."""
        self._check_vocabulary()
        if id >= self.vocabulary_size() or id < 0:
            raise ValueError(
                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
                f"Received: {id}"
            )
        return self.vocabulary[id]

    def token_to_id(self, token):
        """Convert a string token to an integer id."""
        # This will be slow, but keeps memory usage down compared to building
        # a lookup dict. Assuming the main use case is looking up a few
        # special tokens early in the vocab, this should be fine.
        self._check_vocabulary()
        return self.vocabulary.index(token)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary": None,  # Save vocabulary via an asset!
                "sequence_length": self.sequence_length,
                "lowercase": self.lowercase,
                "strip_accents": self.strip_accents,
                "split": self.split,
                "suffix_indicator": self.suffix_indicator,
                "oov_token": self.oov_token,
                "special_tokens": self.special_tokens,
            }
        )
        return config

    def _check_vocabulary(self):
        if self.vocabulary is None:
            raise ValueError(
                "No vocabulary has been set for WordPieceTokenizer. Make sure "
                "to pass a `vocabulary` argument when creating the layer."
            )

    def tokenize(self, inputs):
        self._check_vocabulary()
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)

        scalar_input = inputs.shape.rank == 0
        inputs = pretokenize(
            inputs,
            self.lowercase,
            self.strip_accents,
            self.split,
            self.split_on_cjk,
            self._special_tokens_pattern,
        )

        # Apply WordPiece and coerce shape for outputs.
        tokens = self._fast_word_piece.tokenize(inputs)
        # By default tf.text tokenizes text with two ragged dimensions (one for
        # split words and one for split subwords). We will collapse to a single
        # ragged dimension which is a better out of box default.
        tokens = tokens.merge_dims(-2, -1)

        # Convert to a dense output if `sequence_length` is set.
        if self.sequence_length:
            output_shape = tokens.shape.as_list()
            output_shape[-1] = self.sequence_length
            tokens = tokens.to_tensor(shape=output_shape)
        # Convert to a dense output if the input is scalar.
        if scalar_input:
            tokens = tf.squeeze(tokens, 0)
            tf.ensure_shape(tokens, shape=[self.sequence_length])

        return tokens

    def detokenize(self, inputs):
        self._check_vocabulary()
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
        outputs = self._fast_word_piece.detokenize(inputs)
        if unbatched:
            outputs = tf.squeeze(outputs, 0)
        return outputs

    def compute_output_spec(self, input_spec):
        return keras.KerasTensor(
            input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
        )
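
The class docstring above documents `special_tokens` and `special_tokens_in_strings` but does not demonstrate them, so here is a minimal usage sketch. It only uses constructor arguments defined in the diff; the toy vocabulary is made up for illustration, and running it assumes keras_hub, tensorflow, and tensorflow_text are installed.

# Usage sketch (not part of the diff). Special tokens must appear in the
# vocabulary, as enforced by set_vocabulary() above.
import keras_hub

vocab = ["[UNK]", "[CLS]", "[SEP]", "the", "qu", "##ick", "br", "##own", "fox"]
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    lowercase=True,
    special_tokens=["[CLS]", "[SEP]"],
    special_tokens_in_strings=True,
)
# With special_tokens_in_strings=True, "[CLS]" and "[SEP]" are kept intact
# during pre-tokenization (rather than being split on punctuation) and map
# to their vocabulary indices (1 and 2 here).
print(tokenizer("[CLS] The quick brown fox [SEP]"))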