keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (297)
  1. keras_hub/__init__.py +52 -0
  2. keras_hub/api/__init__.py +27 -0
  3. keras_hub/api/layers/__init__.py +47 -0
  4. keras_hub/api/metrics/__init__.py +24 -0
  5. keras_hub/api/models/__init__.py +249 -0
  6. keras_hub/api/samplers/__init__.py +29 -0
  7. keras_hub/api/tokenizers/__init__.py +35 -0
  8. keras_hub/src/__init__.py +13 -0
  9. keras_hub/src/api_export.py +53 -0
  10. keras_hub/src/layers/__init__.py +13 -0
  11. keras_hub/src/layers/modeling/__init__.py +13 -0
  12. keras_hub/src/layers/modeling/alibi_bias.py +143 -0
  13. keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
  14. keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
  15. keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
  16. keras_hub/src/layers/modeling/position_embedding.py +123 -0
  17. keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
  18. keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
  19. keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
  20. keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
  21. keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
  22. keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
  23. keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
  24. keras_hub/src/layers/preprocessing/__init__.py +13 -0
  25. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
  26. keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
  27. keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
  28. keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
  29. keras_hub/src/layers/preprocessing/random_swap.py +267 -0
  30. keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
  31. keras_hub/src/metrics/__init__.py +13 -0
  32. keras_hub/src/metrics/bleu.py +394 -0
  33. keras_hub/src/metrics/edit_distance.py +197 -0
  34. keras_hub/src/metrics/perplexity.py +181 -0
  35. keras_hub/src/metrics/rouge_base.py +204 -0
  36. keras_hub/src/metrics/rouge_l.py +97 -0
  37. keras_hub/src/metrics/rouge_n.py +125 -0
  38. keras_hub/src/models/__init__.py +13 -0
  39. keras_hub/src/models/albert/__init__.py +20 -0
  40. keras_hub/src/models/albert/albert_backbone.py +267 -0
  41. keras_hub/src/models/albert/albert_classifier.py +202 -0
  42. keras_hub/src/models/albert/albert_masked_lm.py +129 -0
  43. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
  44. keras_hub/src/models/albert/albert_preprocessor.py +206 -0
  45. keras_hub/src/models/albert/albert_presets.py +70 -0
  46. keras_hub/src/models/albert/albert_tokenizer.py +119 -0
  47. keras_hub/src/models/backbone.py +311 -0
  48. keras_hub/src/models/bart/__init__.py +20 -0
  49. keras_hub/src/models/bart/bart_backbone.py +261 -0
  50. keras_hub/src/models/bart/bart_preprocessor.py +276 -0
  51. keras_hub/src/models/bart/bart_presets.py +74 -0
  52. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
  53. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
  54. keras_hub/src/models/bart/bart_tokenizer.py +124 -0
  55. keras_hub/src/models/bert/__init__.py +23 -0
  56. keras_hub/src/models/bert/bert_backbone.py +227 -0
  57. keras_hub/src/models/bert/bert_classifier.py +183 -0
  58. keras_hub/src/models/bert/bert_masked_lm.py +131 -0
  59. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
  60. keras_hub/src/models/bert/bert_preprocessor.py +184 -0
  61. keras_hub/src/models/bert/bert_presets.py +147 -0
  62. keras_hub/src/models/bert/bert_tokenizer.py +112 -0
  63. keras_hub/src/models/bloom/__init__.py +20 -0
  64. keras_hub/src/models/bloom/bloom_attention.py +186 -0
  65. keras_hub/src/models/bloom/bloom_backbone.py +173 -0
  66. keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
  67. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
  68. keras_hub/src/models/bloom/bloom_decoder.py +206 -0
  69. keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
  70. keras_hub/src/models/bloom/bloom_presets.py +121 -0
  71. keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
  72. keras_hub/src/models/causal_lm.py +383 -0
  73. keras_hub/src/models/classifier.py +109 -0
  74. keras_hub/src/models/csp_darknet/__init__.py +13 -0
  75. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
  76. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
  77. keras_hub/src/models/deberta_v3/__init__.py +24 -0
  78. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
  79. keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
  80. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
  81. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
  82. keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
  83. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
  84. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
  85. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
  86. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
  87. keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
  88. keras_hub/src/models/densenet/__init__.py +13 -0
  89. keras_hub/src/models/densenet/densenet_backbone.py +210 -0
  90. keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
  91. keras_hub/src/models/distil_bert/__init__.py +26 -0
  92. keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
  93. keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
  94. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
  95. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
  96. keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
  97. keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
  98. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
  99. keras_hub/src/models/electra/__init__.py +20 -0
  100. keras_hub/src/models/electra/electra_backbone.py +247 -0
  101. keras_hub/src/models/electra/electra_preprocessor.py +154 -0
  102. keras_hub/src/models/electra/electra_presets.py +95 -0
  103. keras_hub/src/models/electra/electra_tokenizer.py +104 -0
  104. keras_hub/src/models/f_net/__init__.py +20 -0
  105. keras_hub/src/models/f_net/f_net_backbone.py +236 -0
  106. keras_hub/src/models/f_net/f_net_classifier.py +154 -0
  107. keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
  108. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
  109. keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
  110. keras_hub/src/models/f_net/f_net_presets.py +43 -0
  111. keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
  112. keras_hub/src/models/falcon/__init__.py +20 -0
  113. keras_hub/src/models/falcon/falcon_attention.py +156 -0
  114. keras_hub/src/models/falcon/falcon_backbone.py +164 -0
  115. keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
  116. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
  117. keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
  118. keras_hub/src/models/falcon/falcon_presets.py +30 -0
  119. keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
  120. keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
  121. keras_hub/src/models/feature_pyramid_backbone.py +73 -0
  122. keras_hub/src/models/gemma/__init__.py +20 -0
  123. keras_hub/src/models/gemma/gemma_attention.py +250 -0
  124. keras_hub/src/models/gemma/gemma_backbone.py +316 -0
  125. keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
  126. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
  127. keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
  128. keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
  129. keras_hub/src/models/gemma/gemma_presets.py +248 -0
  130. keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
  131. keras_hub/src/models/gemma/rms_normalization.py +40 -0
  132. keras_hub/src/models/gpt2/__init__.py +20 -0
  133. keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
  134. keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
  135. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
  136. keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
  137. keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
  138. keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
  139. keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
  140. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
  141. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
  142. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
  143. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
  144. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
  145. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
  146. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
  147. keras_hub/src/models/image_classifier.py +90 -0
  148. keras_hub/src/models/llama/__init__.py +20 -0
  149. keras_hub/src/models/llama/llama_attention.py +225 -0
  150. keras_hub/src/models/llama/llama_backbone.py +188 -0
  151. keras_hub/src/models/llama/llama_causal_lm.py +327 -0
  152. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
  153. keras_hub/src/models/llama/llama_decoder.py +246 -0
  154. keras_hub/src/models/llama/llama_layernorm.py +48 -0
  155. keras_hub/src/models/llama/llama_preprocessor.py +189 -0
  156. keras_hub/src/models/llama/llama_presets.py +80 -0
  157. keras_hub/src/models/llama/llama_tokenizer.py +84 -0
  158. keras_hub/src/models/llama3/__init__.py +20 -0
  159. keras_hub/src/models/llama3/llama3_backbone.py +84 -0
  160. keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
  161. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
  162. keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
  163. keras_hub/src/models/llama3/llama3_presets.py +69 -0
  164. keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
  165. keras_hub/src/models/masked_lm.py +101 -0
  166. keras_hub/src/models/mistral/__init__.py +20 -0
  167. keras_hub/src/models/mistral/mistral_attention.py +238 -0
  168. keras_hub/src/models/mistral/mistral_backbone.py +203 -0
  169. keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
  170. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
  171. keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
  172. keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
  173. keras_hub/src/models/mistral/mistral_presets.py +48 -0
  174. keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
  175. keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
  176. keras_hub/src/models/mix_transformer/__init__.py +13 -0
  177. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
  178. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
  179. keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
  180. keras_hub/src/models/opt/__init__.py +20 -0
  181. keras_hub/src/models/opt/opt_backbone.py +173 -0
  182. keras_hub/src/models/opt/opt_causal_lm.py +301 -0
  183. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
  184. keras_hub/src/models/opt/opt_preprocessor.py +188 -0
  185. keras_hub/src/models/opt/opt_presets.py +72 -0
  186. keras_hub/src/models/opt/opt_tokenizer.py +116 -0
  187. keras_hub/src/models/pali_gemma/__init__.py +23 -0
  188. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
  189. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
  190. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
  191. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
  192. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
  193. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
  194. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
  195. keras_hub/src/models/phi3/__init__.py +20 -0
  196. keras_hub/src/models/phi3/phi3_attention.py +260 -0
  197. keras_hub/src/models/phi3/phi3_backbone.py +224 -0
  198. keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
  199. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
  200. keras_hub/src/models/phi3/phi3_decoder.py +260 -0
  201. keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
  202. keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
  203. keras_hub/src/models/phi3/phi3_presets.py +50 -0
  204. keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
  205. keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
  206. keras_hub/src/models/preprocessor.py +207 -0
  207. keras_hub/src/models/resnet/__init__.py +13 -0
  208. keras_hub/src/models/resnet/resnet_backbone.py +612 -0
  209. keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
  210. keras_hub/src/models/roberta/__init__.py +20 -0
  211. keras_hub/src/models/roberta/roberta_backbone.py +184 -0
  212. keras_hub/src/models/roberta/roberta_classifier.py +209 -0
  213. keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
  214. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
  215. keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
  216. keras_hub/src/models/roberta/roberta_presets.py +43 -0
  217. keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
  218. keras_hub/src/models/seq_2_seq_lm.py +54 -0
  219. keras_hub/src/models/t5/__init__.py +20 -0
  220. keras_hub/src/models/t5/t5_backbone.py +261 -0
  221. keras_hub/src/models/t5/t5_layer_norm.py +35 -0
  222. keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
  223. keras_hub/src/models/t5/t5_presets.py +95 -0
  224. keras_hub/src/models/t5/t5_tokenizer.py +100 -0
  225. keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
  226. keras_hub/src/models/task.py +419 -0
  227. keras_hub/src/models/vgg/__init__.py +13 -0
  228. keras_hub/src/models/vgg/vgg_backbone.py +158 -0
  229. keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
  230. keras_hub/src/models/vit_det/__init__.py +13 -0
  231. keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
  232. keras_hub/src/models/vit_det/vit_layers.py +565 -0
  233. keras_hub/src/models/whisper/__init__.py +20 -0
  234. keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
  235. keras_hub/src/models/whisper/whisper_backbone.py +305 -0
  236. keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
  237. keras_hub/src/models/whisper/whisper_decoder.py +141 -0
  238. keras_hub/src/models/whisper/whisper_encoder.py +106 -0
  239. keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
  240. keras_hub/src/models/whisper/whisper_presets.py +148 -0
  241. keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
  242. keras_hub/src/models/xlm_roberta/__init__.py +26 -0
  243. keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
  244. keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
  245. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
  246. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
  247. keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
  248. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
  249. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
  250. keras_hub/src/models/xlnet/__init__.py +13 -0
  251. keras_hub/src/models/xlnet/relative_attention.py +459 -0
  252. keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
  253. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
  254. keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
  255. keras_hub/src/samplers/__init__.py +13 -0
  256. keras_hub/src/samplers/beam_sampler.py +207 -0
  257. keras_hub/src/samplers/contrastive_sampler.py +231 -0
  258. keras_hub/src/samplers/greedy_sampler.py +50 -0
  259. keras_hub/src/samplers/random_sampler.py +77 -0
  260. keras_hub/src/samplers/sampler.py +237 -0
  261. keras_hub/src/samplers/serialization.py +97 -0
  262. keras_hub/src/samplers/top_k_sampler.py +92 -0
  263. keras_hub/src/samplers/top_p_sampler.py +113 -0
  264. keras_hub/src/tests/__init__.py +13 -0
  265. keras_hub/src/tests/test_case.py +608 -0
  266. keras_hub/src/tokenizers/__init__.py +13 -0
  267. keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
  268. keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
  269. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
  270. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
  271. keras_hub/src/tokenizers/tokenizer.py +235 -0
  272. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
  273. keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
  274. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
  275. keras_hub/src/utils/__init__.py +13 -0
  276. keras_hub/src/utils/keras_utils.py +130 -0
  277. keras_hub/src/utils/pipeline_model.py +293 -0
  278. keras_hub/src/utils/preset_utils.py +621 -0
  279. keras_hub/src/utils/python_utils.py +21 -0
  280. keras_hub/src/utils/tensor_utils.py +206 -0
  281. keras_hub/src/utils/timm/__init__.py +13 -0
  282. keras_hub/src/utils/timm/convert.py +37 -0
  283. keras_hub/src/utils/timm/convert_resnet.py +171 -0
  284. keras_hub/src/utils/transformers/__init__.py +13 -0
  285. keras_hub/src/utils/transformers/convert.py +101 -0
  286. keras_hub/src/utils/transformers/convert_bert.py +173 -0
  287. keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
  288. keras_hub/src/utils/transformers/convert_gemma.py +187 -0
  289. keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
  290. keras_hub/src/utils/transformers/convert_llama3.py +136 -0
  291. keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
  292. keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
  293. keras_hub/src/version_utils.py +23 -0
  294. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
  295. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
  296. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
  297. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/tokenizers/byte_pair_tokenizer.py
@@ -0,0 +1,638 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Byte-pair encoder implementation.
+
+This file implements the same logic as openai BPE:
+https://github.com/openai/gpt-2/blob/master/src/encoder.py,
+but is TF graph compatible.
+"""
+
+import json
+import os
+from typing import Iterable
+
+import keras
+import regex as re
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.tokenizers import tokenizer
+from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import is_int_dtype
+from keras_hub.src.utils.tensor_utils import is_string_dtype
+
+try:
+    import tensorflow as tf
+    import tensorflow_text as tf_text
+except ImportError:
+    tf = None
+    tf_text = None
+
+VOCAB_FILENAME = "vocabulary.json"
+MERGES_FILENAME = "merges.txt"
+
+
+# As Python and TF handle special spaces differently, we need to
+# manually handle special spaces during string split.
+SPECIAL_WHITESPACES = r"\x{a0}\x{2009}\x{202f}\x{3000}"
+
+# String splitting regex pattern.
+SPLIT_PATTERN_1 = (
+    r"'s|'t|'re|'ve|'m|'ll|'d"
+    + r"|[\s{special_spaces}]+[\n\r\t\f६{special_spaces}]| ?\p{L}+|"
+    + r" ?[\p{N}]+| ?[^\s\p{L}\p{N}{special_spaces}]+"
+)
+SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace(
+    "{special_spaces}", SPECIAL_WHITESPACES
+)
+SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
+
+
+def create_alts_for_unsplittable_tokens(unsplittable_tokens):
+    # Create alternates for all special tokens that will not be split during
+    # tokenization.
+    alts = []
+    prefix = "Ĵ"
+    # Trim out splitters.
+    replace_pattern = r"'|\s+|[^\p{L}\p{N}]+"
+    for token in unsplittable_tokens:
+        token = re.sub(replace_pattern, "", token)
+        alts.append(prefix + token)
+    return alts
+
+
+def bytes_to_unicode():
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    # Avoid mapping any byte to a whitespace character.
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    bs = [n.to_bytes(1, "little") for n in bs]
+    return bs, cs  # int to string mapping
+
+
+def remove_strings_from_inputs(tensor, string_to_remove):
+    """Remove certain strings from input tensor."""
+    non_empty_mask = tensor != string_to_remove
+    flatten_indexes = tf.where(non_empty_mask)
+    flatten_result = tf.gather_nd(tensor, flatten_indexes)
+    row_lengths = tf.reduce_sum(tf.cast(non_empty_mask, "int64"), axis=1)
+    result = tf.RaggedTensor.from_row_lengths(
+        values=flatten_result,
+        row_lengths=row_lengths,
+    )
+    return result
+
+
+def split_strings_for_bpe(inputs, unsplittable_tokens=None):
+    # We need to recreate the exact behavior of token presplitting in the
+    # original gpt2 tokenizer which uses a lookahead. As re2 does not
+    # support lookahead match, we are using an alternative: insert a special
+    # token "६" before the leading space of non-space characters and after the
+    # trailing space, e.g., " keras" will be "६ keras".
+    inputs = tf.strings.regex_replace(
+        inputs, rf"( )([^\s{SPECIAL_WHITESPACES}])", r"६\1\2"
+    )
+    inputs = tf.strings.regex_replace(
+        inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
+    )
+    if unsplittable_tokens:
+        alts = create_alts_for_unsplittable_tokens(unsplittable_tokens)
+        for token, alt in zip(unsplittable_tokens, alts):
+            escaped_token = re.escape(token)
+            inputs = tf_text.regex_split(inputs, escaped_token, escaped_token)
+            inputs = tf.strings.regex_replace(inputs, escaped_token, alt)
+    raw_tokens = tf_text.regex_split(inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1)
+    # Second pass splits out the last whitespace char or "६".
+    raw_tokens = tf_text.regex_split(
+        raw_tokens, SPLIT_PATTERN_2, SPLIT_PATTERN_2
+    )
+    if unsplittable_tokens:
+        # Replace special token alternates with the originals.
+        for token, alt in zip(unsplittable_tokens, alts):
+            escaped_alt = re.escape(alt)
+            raw_tokens = tf.strings.regex_replace(
+                raw_tokens, escaped_alt, token
+            )
+    while raw_tokens.shape.rank > 2:
+        raw_tokens = raw_tokens.merge_dims(1, 2)
+    return remove_strings_from_inputs(raw_tokens, "६")
+
+
+class BytePairTokenizerCache(tf.Module if tf is not None else object):
+    """Cache that stores the encoded result of seen tokens.
+
+    The cache key is string tensor or python strings, and the value is split
+    tokens joined by whitespace. For example, "dragonfly" => "dragon fly"
+
+    Example:
+    ```
+    cache = BytePairTokenizerCache()
+    cache.insert(["butterfly", "dragonfly"], ["but ter fly", "dragon fly"])
+    cache.lookup(["butterfly"])
+    ```
+    """
+
+    def __init__(self):
+        # `tf.lookup.experimental.MutableHashTable` does not support string to
+        # string mapping. So we first convert the string to an integer key, and
+        # use the integer key to find the value.
+        self.factors = tf.pow(
+            tf.constant(256, dtype="int64"), tf.range(0, 8, dtype="int64")
+        )
+        self.id2value = tf.lookup.experimental.MutableHashTable(
+            "int64", tf.string, ""
+        )
+
+    def _get_key(self, keys):
+        """Get the hash key for given inputs."""
+        # `tf.fingerprint` converts a token to an array of uint8 of length 8;
+        # we need to convert it to a uint64.
+        return tf.squeeze(
+            tf.matmul(
+                tf.cast(tf.fingerprint(keys), dtype="int64"),
+                self.factors[:, tf.newaxis],
+            ),
+            -1,
+        )
+
+    def lookup(self, keys):
+        """Look up the encoded outputs of given tokens."""
+        ids = self._get_key(keys)
+        result = self.id2value.lookup(ids)
+        # Ensure output shape for graph mode.
+        result.set_shape([None])
+        return result
+
+    def insert(self, keys, values):
+        """Insert token <=> encoded outputs pairs."""
+        self.id2value.insert(self._get_key(keys), values)
+
+
+def create_static_hashtable(keys, values, default):
+    return tf.lookup.StaticHashTable(
+        tf.lookup.KeyValueTensorInitializer(
+            tf.convert_to_tensor(keys),
+            tf.convert_to_tensor(values),
+        ),
+        default_value=default,
+    )
+
+
+@keras_hub_export("keras_hub.tokenizers.BytePairTokenizer")
+class BytePairTokenizer(tokenizer.Tokenizer):
+    """Byte-pair encoding tokenizer layer.
+
+    This BPE tokenizer provides the same functionality as the official GPT-2
+    tokenizer. Given the same `vocabulary` which maps tokens to ids, and `merges`
+    which describes BPE merge rules, it should provide the same output
+    as the OpenAI implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
+    Different from OpenAI, this implementation is graph-compatible, so you can
+    use it within a `tf.data` pipeline.
+
+    If input is a batch of strings (rank > 0):
+    By default, the layer will output a `tf.RaggedTensor` where the last
+    dimension of the output is ragged. If `sequence_length` is set, the layer
+    will output a dense `tf.Tensor` where all inputs have been padded or
+    truncated to `sequence_length`.
+    If input is a scalar string (rank == 0):
+    By default, the layer will output a dense `tf.Tensor` with static shape
+    `[None]`. If `sequence_length` is set, the output will be
+    a dense `tf.Tensor` of shape `[sequence_length]`.
+
+    Args:
+        vocabulary: string or dict, maps token to integer ids. If it is a
+            string, it should be the file path to a json file.
+        merges: string or list, contains the merge rule. If it is a string,
+            it should be the file path to merge rules. The merge rule file
+            should have one merge rule per line.
+        sequence_length: int. If set, the output will be
+            padded or truncated to the `sequence_length`. Defaults to `None`.
+        add_prefix_space: bool. Whether to add an
+            initial space to the input. This tokenizer is whitespace aware,
+            and will tokenize a word with a leading space differently. Adding
+            a prefix space to the first word will cause it to be tokenized
+            equivalently to all subsequent words in the sequence.
+            Defaults to `False`.
+        unsplittable_tokens: list. A list of strings that will
+            never be split during the word-level splitting applied before the
+            byte-pair encoding. This can be used to ensure special tokens map to
+            unique indices in the vocabulary, even if these special tokens
+            contain splittable characters such as punctuation. Special tokens
+            must still be included in `vocabulary`. Defaults to `None`.
+
+    Examples:
+
+    Tokenize
+    >>> vocab = {"butter": 1, "fly": 2}
+    >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
+    >>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(vocab, merge)
+    >>> outputs = tokenizer("butterfly")
+    >>> np.array(outputs)
+    array([1, 2], dtype=int32)
+    >>> seq1, seq2 = tokenizer(["butterfly", "butter"])
+    >>> np.array(seq1)
+    array([1, 2], dtype=int32)
+    >>> np.array(seq2)
+    array([1], dtype=int32)
+    >>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(
+    ...     vocab, merge, sequence_length=2)
+    >>> seq1, seq2 = tokenizer(["butterfly", "butter"])
+    >>> np.array(seq1)
+    array([1, 2], dtype=int32)
+    >>> np.array(seq2)
+    array([1, 0], dtype=int32)
+
+    Detokenize
+    >>> vocab = {"butter": 1, "fly": 2}
+    >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
+    >>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(vocab, merge)
+    >>> tokenizer.detokenize([[1, 2]])
+    <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'butterfly'],
+    dtype=object)>
+    """
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        sequence_length=None,
+        add_prefix_space=False,
+        unsplittable_tokens=None,
+        dtype="int32",
+        **kwargs,
+    ) -> None:
+        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
+            raise ValueError(
+                "Output dtype must be an integer type or a string. "
+                f"Received: dtype={dtype}"
+            )
+
+        super().__init__(dtype=dtype, **kwargs)
+        self.sequence_length = sequence_length
+        self.add_prefix_space = add_prefix_space
+        self.unsplittable_tokens = unsplittable_tokens
+        self.file_assets = [VOCAB_FILENAME, MERGES_FILENAME]
+
+        # Create byte <=> unicode mapping. This is useful for handling
+        # whitespace tokens.
+        byte_list, unicode_list = bytes_to_unicode()
+        self.byte2unicode = create_static_hashtable(
+            byte_list, unicode_list, default=""
+        )
+        self.unicode2byte = create_static_hashtable(
+            unicode_list, byte_list, default=""
+        )
+
+        self.set_vocabulary_and_merges(vocabulary, merges)
+
+    def save_assets(self, dir_path):
+        vocab_path = os.path.join(dir_path, VOCAB_FILENAME)
+        merges_path = os.path.join(dir_path, MERGES_FILENAME)
+        with open(vocab_path, "w", encoding="utf-8") as file:
+            file.write(json.dumps(dict(self.vocabulary)))
+        with open(merges_path, "w", encoding="utf-8") as file:
+            for merge in self.merges:
+                file.write(f"{merge}\n")
+
+    def load_assets(self, dir_path):
+        vocab_path = os.path.join(dir_path, VOCAB_FILENAME)
+        merges_path = os.path.join(dir_path, MERGES_FILENAME)
+        self.set_vocabulary_and_merges(vocab_path, merges_path)
+
+    def set_vocabulary_and_merges(self, vocabulary, merges):
+        """Set the vocabulary and merge rules from data or files."""
+        if vocabulary is None or merges is None:
+            # Clear vocab related state.
+            self.vocabulary = None
+            self.merges = None
+            self.cache = None
+            self.id_to_token_map = None
+            self.token_to_id_map = None
+            self.merge_ranks_lookup_default = None
+            self.merge_ranks = None
+            return
+
+        if isinstance(vocabulary, str):
+            with open(vocabulary, "r", encoding="utf-8") as f:
+                self.vocabulary = json.load(f)
+        elif isinstance(vocabulary, dict):
+            self.vocabulary = vocabulary.copy()
+        else:
+            raise ValueError(
+                "Vocabulary must be a file path or dictionary mapping string "
+                "token to int ids. Received: "
+                f"`type(vocabulary)={type(vocabulary)}`."
+            )
+        if isinstance(merges, str):
+            with open(merges, encoding="utf-8") as f:
+                self.merges = [bp.rstrip() for bp in f]
+        elif isinstance(merges, Iterable):
+            self.merges = list(merges)
+        else:
+            raise ValueError(
+                "Merges must be a file path or a list of merge rules. "
+                f"Received: `type(merges)={type(merges)}`"
+            )
+
+        self.cache = BytePairTokenizerCache()
+        if self.unsplittable_tokens:
+            # Put special tokens into the cache, so they won't be further
+            # split and merged.
+            self.cache.insert(
+                self.unsplittable_tokens, self.unsplittable_tokens
+            )
+
+        # Create mapping from string tokens to int ids, and vice versa.
+        byte_pairs = [x[0] for x in self.vocabulary.items()]
+        byte_pair_encoding_indices = [x[1] for x in self.vocabulary.items()]
+        self.token_to_id_map = create_static_hashtable(
+            byte_pairs,
+            byte_pair_encoding_indices,
+            default=-1,
+        )
+        self.id_to_token_map = create_static_hashtable(
+            byte_pair_encoding_indices,
+            byte_pairs,
+            default="",
+        )
+
+        # Create ranking of merge rules; this is the same as the order of
+        # merge pairs in `self.merges`.
+        self.merge_ranks_lookup_default = len(self.merges) + 1
+        self.merge_ranks = create_static_hashtable(
+            self.merges,
+            list(range(len(self.merges))),
+            default=self.merge_ranks_lookup_default,
+        )
+
+    def get_vocabulary(self):
+        """Get the tokenizer vocabulary as a list of string tokens."""
+        self._check_vocabulary()
+        return self.vocabulary.keys()
+
+    def vocabulary_size(self):
+        """Get the integer size of the tokenizer vocabulary."""
+        self._check_vocabulary()
+        return len(self.vocabulary)
+
+    def id_to_token(self, id):
+        """Convert an integer id to a string token."""
+        # This will be slow, but keep memory usage down compared to building a
+        # dict. Assuming the main use case is looking up a few special tokens
+        # early in the vocab, this should be fine.
+        self._check_vocabulary()
+
+        keys = self.get_vocabulary()
+        for token in keys:
+            if self.vocabulary[token] == id:
+                return token
+        raise ValueError(f"`id` is out of the vocabulary. Received: {id}")
+
+    def token_to_id(self, token):
+        """Convert a string token to an integer id."""
+        self._check_vocabulary()
+        return self.vocabulary[token]
+
+    def _bpe_merge_one_step(self, words, mask):
+        """Perform one step of byte-pair merge."""
+        # Get all word pairs.
+        first, second = words[:, :-1], words[:, 1:]
+
+        # Mask empty.
+        non_empty_mask = second.nested_row_lengths()[0] != 0
+        mask = mask & non_empty_mask
+        if not tf.reduce_any(mask):
+            return [words, mask]
+        non_empty_indices = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask)
+        filtered_first = tf.ragged.boolean_mask(first, mask)
+        filtered_second = tf.ragged.boolean_mask(second, mask)
+
+        # Get byte pair ranking in merge rules.
+        pairs = tf.strings.join([filtered_first, filtered_second], separator=" ")
+        pair_rank = self.merge_ranks.lookup(pairs)
+
+        # Get BPE pair ranks.
+        min_pair_rank = tf.reduce_min(pair_rank, axis=1)
+        pair_found_mask = min_pair_rank != self.merge_ranks_lookup_default
+
+        # Tokens that cannot be further merged are marked as finished.
+        mask = tf.tensor_scatter_nd_update(
+            mask, tf.expand_dims(non_empty_indices, axis=1), pair_found_mask
+        )
+        if not tf.math.reduce_any(mask):
+            return [words, mask]
+
+        masked_pair_rank = tf.ragged.boolean_mask(pair_rank, pair_found_mask)
+        min_pair_rank_indices = tf.math.argmin(
+            masked_pair_rank.to_tensor(self.merge_ranks_lookup_default), axis=1
+        )
+
+        # Get words and pairs to process.
+        unfinished_words = tf.ragged.boolean_mask(words, mask)
+
+        pair_left = tf.gather(
+            unfinished_words, min_pair_rank_indices, batch_dims=1
+        )
+        pair_right = tf.gather(
+            unfinished_words, min_pair_rank_indices + 1, batch_dims=1
+        )
+
+        merged_pairs = tf.strings.join([pair_left, pair_right])
+        empty_strs = tf.fill(tf.shape(merged_pairs), "")
+
+        unfinished_word_indices = tf.cast(
+            tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask), dtype="int64"
+        )
+        merged_pair_indices = tf.concat(
+            [
+                unfinished_word_indices[:, tf.newaxis],
+                min_pair_rank_indices[:, tf.newaxis],
+            ],
+            axis=1,
+        )
+        empty_string_indices = tf.concat(
+            [
+                unfinished_word_indices[:, tf.newaxis],
+                min_pair_rank_indices[:, tf.newaxis] + 1,
+            ],
+            axis=1,
+        )
+
+        tensor_words = words.to_tensor(default_value="")
+        tensor_words = tf.tensor_scatter_nd_update(
+            tensor_words,
+            merged_pair_indices,
+            merged_pairs,
+        )
+
+        words = tf.tensor_scatter_nd_update(
+            tensor_words,
+            empty_string_indices,
+            empty_strs,
+        )
+        # Remove empty strings.
+        words = remove_strings_from_inputs(words, "")
+        return [words, mask]
+
+    def _bpe_merge(self, inputs):
+        """Perform byte-pair merge for each word in the inputs."""
+        num_words = tf.shape(inputs)[0]
+
+        # Merge bytes.
+        def loop_condition(_, mask):
+            return tf.math.reduce_any(mask)
+
+        initial_mask = tf.fill((num_words,), True)
+        merged_words, _ = tf.while_loop(
+            loop_condition,
+            tf.function(self._bpe_merge_one_step),
+            loop_vars=[
+                inputs,
+                initial_mask,
+            ],
+            shape_invariants=[
+                tf.TensorShape([None, None]),
+                tf.TensorShape([None]),
+            ],
+        )
+        return merged_words
+
+    def _check_vocabulary(self):
+        if self.vocabulary is None:
+            raise ValueError(
+                "No vocabulary has been set for BytePairTokenizer. Make sure "
+                "to pass `vocabulary` and `merges` arguments when creating the "
+                "layer."
+            )
+
+    def tokenize(self, inputs):
+        self._check_vocabulary()
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+
+        if self.add_prefix_space:
+            inputs = tf.strings.join([" ", inputs])
+
+        scalar_input = inputs.shape.rank == 0
+        if scalar_input:
+            inputs = tf.expand_dims(inputs, 0)
+
+        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+        token_row_splits = raw_tokens.row_splits
+        flat_tokens = raw_tokens.flat_values
+
+        # Check cache.
+        cache_lookup = self.cache.lookup(flat_tokens)
+        cache_mask = cache_lookup == ""
+
+        has_unseen_words = tf.math.reduce_any(
+            (cache_lookup == "") & (flat_tokens != "")
+        )
+
+        def process_unseen_tokens():
+            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
+            self._bpe_merge_and_update_cache(unseen_tokens)
+            return self.cache.lookup(flat_tokens)
+
+        # If `has_unseen_words == True`, not all tokens are in the cache, so
+        # we will process the unseen tokens. Otherwise return the cache lookup.
+        tokenized_words = tf.cond(
+            has_unseen_words,
+            process_unseen_tokens,
+            lambda: cache_lookup,
+        )
+
+        tokens = tf.strings.split(tokenized_words, sep=" ")
+        if self.compute_dtype != tf.string:
+            # Encode merged tokens.
+            tokens = self.token_to_id_map.lookup(tokens)
+
+        # Unflatten to match input.
+        tokens = tf.RaggedTensor.from_row_splits(
+            tokens.flat_values,
+            tf.gather(tokens.row_splits, token_row_splits),
+        )
+
+        # Convert to a dense output if `sequence_length` is set.
+        if self.sequence_length:
+            output_shape = tokens.shape.as_list()
+            output_shape[-1] = self.sequence_length
+            tokens = tokens.to_tensor(shape=output_shape)
+
+        # Convert to a dense output if input is scalar.
+        if scalar_input:
+            tokens = tf.squeeze(tokens, 0)
+            tf.ensure_shape(tokens, shape=[self.sequence_length])
+
+        return tokens
+
+    def detokenize(self, inputs):
+        self._check_vocabulary()
+        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
+        inputs = tf.cast(inputs, self.dtype)
+        unicode_text = tf.strings.reduce_join(
+            self.id_to_token_map.lookup(inputs), axis=-1
+        )
+        split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8")
+        outputs = tf.strings.reduce_join(
+            self.unicode2byte.lookup(split_unicode_text), axis=-1
+        )
+
+        if unbatched:
+            outputs = tf.squeeze(outputs, 0)
+        return outputs
+
+    def compute_output_spec(self, input_spec):
+        return keras.KerasTensor(
+            input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
+        )
+
+    def _transform_bytes(self, tokens):
+        """Map token bytes to unicode using `byte2unicode`."""
+        split_bytes = tf.strings.bytes_split(tokens)
+        split_unicode = self.byte2unicode.lookup(split_bytes)
+        return split_unicode
+
+    def _bpe_merge_and_update_cache(self, tokens):
+        """Process unseen tokens and add to cache."""
+        words = self._transform_bytes(tokens)
+        tokenized_words = self._bpe_merge(words)
+
+        # For each word, join all its tokens with a whitespace,
+        # e.g., ["dragon", "fly"] => "dragon fly", for hashing purposes.
+        tokenized_words = tf.strings.reduce_join(
+            tokenized_words, axis=1, separator=" "
+        )
+        self.cache.insert(tokens, tokenized_words)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sequence_length": self.sequence_length,
+                "add_prefix_space": self.add_prefix_space,
+                "unsplittable_tokens": self.unsplittable_tokens,
+            }
+        )
+        return config