keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/tokenizers/byte_pair_tokenizer.py
@@ -0,0 +1,638 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Byte-pair encoder implementation.

This file implements the same logic as OpenAI BPE:
https://github.com/openai/gpt-2/blob/master/src/encoder.py,
but is TF graph compatible.
"""

import json
import os
from typing import Iterable

import keras
import regex as re

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.tokenizers import tokenizer
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
from keras_hub.src.utils.tensor_utils import is_int_dtype
from keras_hub.src.utils.tensor_utils import is_string_dtype

try:
    import tensorflow as tf
    import tensorflow_text as tf_text
except ImportError:
    tf = None
    tf_text = None

VOCAB_FILENAME = "vocabulary.json"
MERGES_FILENAME = "merges.txt"


# As Python and TF handle special spaces differently, we need to
# manually handle special spaces during string splitting.
SPECIAL_WHITESPACES = r"\x{a0}\x{2009}\x{202f}\x{3000}"

# String splitting regex pattern.
SPLIT_PATTERN_1 = (
    r"'s|'t|'re|'ve|'m|'ll|'d"
    + r"|[\s{special_spaces}]+[\n\r\t\f६{special_spaces}]| ?\p{L}+|"
    + r" ?[\p{N}]+| ?[^\s\p{L}\p{N}{special_spaces}]+"
)
SPLIT_PATTERN_1 = SPLIT_PATTERN_1.replace(
    "{special_spaces}", SPECIAL_WHITESPACES
)
SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""


def create_alts_for_unsplittable_tokens(unsplittable_tokens):
    # Create alternates for all special tokens that will not be split during
    # tokenization.
    alts = []
    prefix = "Ĵ"
    # Trim out splitters.
    replace_pattern = r"'|\s+|[^\p{L}\p{N}]+"
    for token in unsplittable_tokens:
        token = re.sub(replace_pattern, "", token)
        alts.append(prefix + token)
    return alts


def bytes_to_unicode():
    """Build a mapping from bytes to printable unicode characters."""
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    # Removes mapping an int to a whitespace character.
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    bs = [n.to_bytes(1, "little") for n in bs]
    return bs, cs  # int to string mapping


def remove_strings_from_inputs(tensor, string_to_remove):
    """Remove certain strings from input tensor."""
    non_empty_mask = tensor != string_to_remove
    flatten_indexes = tf.where(non_empty_mask)
    flatten_result = tf.gather_nd(tensor, flatten_indexes)
    row_lengths = tf.reduce_sum(tf.cast(non_empty_mask, "int64"), axis=1)
    result = tf.RaggedTensor.from_row_lengths(
        values=flatten_result,
        row_lengths=row_lengths,
    )
    return result


def split_strings_for_bpe(inputs, unsplittable_tokens=None):
    # We need to recreate the exact behavior of token presplitting in the
    # original gpt2 tokenizer, which uses a lookahead. As re2 does not
    # support lookahead matches, we use an alternative: insert a special
    # token "६" before the leading space of non-space characters and after
    # the trailing space, e.g., " keras" becomes "६ keras".
    inputs = tf.strings.regex_replace(
        inputs, rf"( )([^\s{SPECIAL_WHITESPACES}])", r"६\1\2"
    )
    inputs = tf.strings.regex_replace(
        inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
    )
    if unsplittable_tokens:
        alts = create_alts_for_unsplittable_tokens(unsplittable_tokens)
        for token, alt in zip(unsplittable_tokens, alts):
            escaped_token = re.escape(token)
            inputs = tf_text.regex_split(inputs, escaped_token, escaped_token)
            inputs = tf.strings.regex_replace(inputs, escaped_token, alt)
    raw_tokens = tf_text.regex_split(inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1)
    # Second pass splits out the last whitespace char or "६".
    raw_tokens = tf_text.regex_split(
        raw_tokens, SPLIT_PATTERN_2, SPLIT_PATTERN_2
    )
    if unsplittable_tokens:
        # Replace special token alternates with the originals.
        for token, alt in zip(unsplittable_tokens, alts):
            escaped_alt = re.escape(alt)
            raw_tokens = tf.strings.regex_replace(
                raw_tokens, escaped_alt, token
            )
    while raw_tokens.shape.rank > 2:
        raw_tokens = raw_tokens.merge_dims(1, 2)
    return remove_strings_from_inputs(raw_tokens, "६")


class BytePairTokenizerCache(tf.Module if tf is not None else object):
    """Cache that stores the encoded result of seen tokens.

    The cache key is a string tensor or python strings, and the value is
    split tokens joined by whitespace. For example, "dragonfly" => "dragon fly"

    Example:
    ```
    cache = BytePairTokenizerCache()
    cache.insert(["butterfly", "dragonfly"], ["but ter fly", "dragon fly"])
    cache.lookup(["butterfly"])
    ```
    """

    def __init__(self):
        # `tf.lookup.experimental.MutableHashTable` does not support string to
        # string mapping. So we first convert the string to an integer key,
        # and use the integer key to find the value.
        self.factors = tf.pow(
            tf.constant(256, dtype="int64"), tf.range(0, 8, dtype="int64")
        )
        self.id2value = tf.lookup.experimental.MutableHashTable(
            "int64", tf.string, ""
        )

    def _get_key(self, keys):
        """Get the hash key for given inputs."""
        # `tf.fingerprint` converts each token to a length-8 array of uint8;
        # we reduce it to a single int64 key.
        return tf.squeeze(
            tf.matmul(
                tf.cast(tf.fingerprint(keys), dtype="int64"),
                self.factors[:, tf.newaxis],
            ),
            -1,
        )

    def lookup(self, keys):
        """Look up the encoded outputs of given tokens."""
        ids = self._get_key(keys)
        result = self.id2value.lookup(ids)
        # Ensure output shape for graph mode.
        result.set_shape([None])
        return result

    def insert(self, keys, values):
        """Insert token <=> encoded outputs pairs."""
        self.id2value.insert(self._get_key(keys), values)


def create_static_hashtable(keys, values, default):
    """Create an immutable `tf.lookup.StaticHashTable` from keys and values."""
    return tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            tf.convert_to_tensor(keys),
            tf.convert_to_tensor(values),
        ),
        default_value=default,
    )


@keras_hub_export("keras_hub.tokenizers.BytePairTokenizer")
class BytePairTokenizer(tokenizer.Tokenizer):
    """Byte-pair encoding tokenizer layer.

    This BPE tokenizer provides the same functionality as the official GPT-2
    tokenizer. Given the same `vocabulary` which maps tokens to ids, and
    `merges` which describes BPE merge rules, it should provide the same
    output as the OpenAI implementation
    (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
    Different from OpenAI, this implementation is graph-compatible, so you can
    use it within a `tf.data` pipeline.

    If input is a batch of strings (rank > 0):
    By default, the layer will output a `tf.RaggedTensor` where the last
    dimension of the output is ragged. If `sequence_length` is set, the layer
    will output a dense `tf.Tensor` where all inputs have been padded or
    truncated to `sequence_length`.
    If input is a scalar string (rank == 0):
    By default, the layer will output a dense `tf.Tensor` with static shape
    `[None]`. If `sequence_length` is set, the output will be
    a dense `tf.Tensor` of shape `[sequence_length]`.

    Args:
        vocabulary: string or dict, maps token to integer ids. If it is a
            string, it should be the file path to a json file.
        merges: string or list, contains the merge rule. If it is a string,
            it should be the file path to merge rules. The merge rule file
            should have one merge rule per line.
        sequence_length: int. If set, the output will be
            padded or truncated to the `sequence_length`. Defaults to `None`.
        add_prefix_space: bool. Whether to add an
            initial space to the input. This tokenizer is whitespace aware,
            and will tokenize a word with a leading space differently. Adding
            a prefix space to the first word will cause it to be tokenized
            equivalently to all subsequent words in the sequence.
            Defaults to `False`.
        unsplittable_tokens: list. A list of strings that will
            never be split during the word-level splitting applied before the
            byte-pair encoding. This can be used to ensure special tokens map
            to unique indices in the vocabulary, even if these special tokens
            contain splittable characters such as punctuation. Special tokens
            must still be included in `vocabulary`. Defaults to `None`.

    Examples:

    Tokenize
    >>> vocab = {"butter": 1, "fly": 2}
    >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
    >>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(vocab, merge)
    >>> outputs = tokenizer("butterfly")
    >>> np.array(outputs)
    array([1, 2], dtype=int32)
    >>> seq1, seq2 = tokenizer(["butterfly", "butter"])
    >>> np.array(seq1)
    array([1, 2], dtype=int32)
    >>> np.array(seq2)
    array([1], dtype=int32)
    >>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(
    ...     vocab, merge, sequence_length=2)
    >>> seq1, seq2 = tokenizer(["butterfly", "butter"])
    >>> np.array(seq1)
    array([1, 2], dtype=int32)
    >>> np.array(seq2)
    array([1, 0], dtype=int32)

    Detokenize
    >>> vocab = {"butter": 1, "fly": 2}
    >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
    >>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(vocab, merge)
    >>> tokenizer.detokenize([[1, 2]])
    <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'butterfly'],
    dtype=object)>
    """

    def __init__(
        self,
        vocabulary=None,
        merges=None,
        sequence_length=None,
        add_prefix_space=False,
        unsplittable_tokens=None,
        dtype="int32",
        **kwargs,
    ) -> None:
        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type or a string. "
                f"Received: dtype={dtype}"
            )

        super().__init__(dtype=dtype, **kwargs)
        self.sequence_length = sequence_length
        self.add_prefix_space = add_prefix_space
        self.unsplittable_tokens = unsplittable_tokens
        self.file_assets = [VOCAB_FILENAME, MERGES_FILENAME]

        # Create byte <=> unicode mapping. This is useful for handling
        # whitespace tokens.
        byte_list, unicode_list = bytes_to_unicode()
        self.byte2unicode = create_static_hashtable(
            byte_list, unicode_list, default=""
        )
        self.unicode2byte = create_static_hashtable(
            unicode_list, byte_list, default=""
        )

        self.set_vocabulary_and_merges(vocabulary, merges)

    def save_assets(self, dir_path):
        vocab_path = os.path.join(dir_path, VOCAB_FILENAME)
        merges_path = os.path.join(dir_path, MERGES_FILENAME)
        with open(vocab_path, "w", encoding="utf-8") as file:
            file.write(json.dumps(dict(self.vocabulary)))
        with open(merges_path, "w", encoding="utf-8") as file:
            for merge in self.merges:
                file.write(f"{merge}\n")

    def load_assets(self, dir_path):
        vocab_path = os.path.join(dir_path, VOCAB_FILENAME)
        merges_path = os.path.join(dir_path, MERGES_FILENAME)
        self.set_vocabulary_and_merges(vocab_path, merges_path)

    def set_vocabulary_and_merges(self, vocabulary, merges):
        """Set the vocabulary and merge rules from data or files."""
        if vocabulary is None or merges is None:
            # Clear vocab related state.
            self.vocabulary = None
            self.merges = None
            self.cache = None
            self.id_to_token_map = None
            self.token_to_id_map = None
            self.merge_ranks_lookup_default = None
            self.merge_ranks = None
            return

        if isinstance(vocabulary, str):
            with open(vocabulary, "r", encoding="utf-8") as f:
                self.vocabulary = json.load(f)
        elif isinstance(vocabulary, dict):
            self.vocabulary = vocabulary.copy()
        else:
            raise ValueError(
                "Vocabulary must be a file path or dictionary mapping string "
                "token to int ids. Received: "
                f"`type(vocabulary)={type(vocabulary)}`."
            )
        if isinstance(merges, str):
            with open(merges, encoding="utf-8") as f:
                self.merges = [bp.rstrip() for bp in f]
        elif isinstance(merges, Iterable):
            self.merges = list(merges)
        else:
            raise ValueError(
                "Merges must be a file path or a list of merge rules. "
                f"Received: `type(merges)={type(merges)}`"
            )

        self.cache = BytePairTokenizerCache()
        if self.unsplittable_tokens:
            # Put special tokens into the cache, so they won't be further
            # split and merged.
            self.cache.insert(
                self.unsplittable_tokens, self.unsplittable_tokens
            )

        # Create mapping from string tokens to int ids, and vice versa.
        byte_pairs = [x[0] for x in self.vocabulary.items()]
        byte_pair_encoding_indices = [x[1] for x in self.vocabulary.items()]
        self.token_to_id_map = create_static_hashtable(
            byte_pairs,
            byte_pair_encoding_indices,
            default=-1,
        )
        self.id_to_token_map = create_static_hashtable(
            byte_pair_encoding_indices,
            byte_pairs,
            default="",
        )

        # Create ranking of merge rules; this is the same as the order of
        # merge pairs in `self.merges`.
        self.merge_ranks_lookup_default = len(self.merges) + 1
        self.merge_ranks = create_static_hashtable(
            self.merges,
            list(range(len(self.merges))),
            default=self.merge_ranks_lookup_default,
        )

    def get_vocabulary(self):
        """Get the tokenizer vocabulary as a list of string tokens."""
        self._check_vocabulary()
        return self.vocabulary.keys()

    def vocabulary_size(self):
        """Get the integer size of the tokenizer vocabulary."""
        self._check_vocabulary()
        return len(self.vocabulary)

    def id_to_token(self, id):
        """Convert an integer id to a string token."""
        # This will be slow, but keep memory usage down compared to building a
        # dict. Assuming the main use case is looking up a few special tokens
        # early in the vocab, this should be fine.
        self._check_vocabulary()

        keys = self.get_vocabulary()
        for token in keys:
            if self.vocabulary[token] == id:
                return token
        raise ValueError(f"`id` is out of the vocabulary. Received: {id}")

    def token_to_id(self, token):
        """Convert a string token to an integer id."""
        self._check_vocabulary()
        return self.vocabulary[token]

    def _bpe_merge_one_step(self, words, mask):
        """Perform one step of byte-pair merge."""
        # Get all word pairs.
        first, second = words[:, :-1], words[:, 1:]

        # Mask empty.
        non_empty_mask = second.nested_row_lengths()[0] != 0
        mask = mask & non_empty_mask
        if not tf.reduce_any(mask):
            return [words, mask]
        non_empty_indices = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask)
        filtered_first = tf.ragged.boolean_mask(first, mask)
        filtered_second = tf.ragged.boolean_mask(second, mask)

        # Get byte pair ranking in merge rules.
        pairs = tf.strings.join(
            [filtered_first, filtered_second], separator=" "
        )
        pair_rank = self.merge_ranks.lookup(pairs)

        # Get BPE pair ranks.
        min_pair_rank = tf.reduce_min(pair_rank, axis=1)
        pair_found_mask = min_pair_rank != self.merge_ranks_lookup_default

        # Tokens that cannot be further merged are marked as finished.
        mask = tf.tensor_scatter_nd_update(
            mask, tf.expand_dims(non_empty_indices, axis=1), pair_found_mask
        )
        if not tf.math.reduce_any(mask):
            return [words, mask]

        masked_pair_rank = tf.ragged.boolean_mask(pair_rank, pair_found_mask)
        min_pair_rank_indices = tf.math.argmin(
            masked_pair_rank.to_tensor(self.merge_ranks_lookup_default), axis=1
        )

        # Get words and pairs to process.
        unfinished_words = tf.ragged.boolean_mask(words, mask)

        pair_left = tf.gather(
            unfinished_words, min_pair_rank_indices, batch_dims=1
        )
        pair_right = tf.gather(
            unfinished_words, min_pair_rank_indices + 1, batch_dims=1
        )

        merged_pairs = tf.strings.join([pair_left, pair_right])
        empty_strs = tf.fill(tf.shape(merged_pairs), "")

        unfinished_word_indices = tf.cast(
            tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask), dtype="int64"
        )
        merged_pair_indices = tf.concat(
            [
                unfinished_word_indices[:, tf.newaxis],
                min_pair_rank_indices[:, tf.newaxis],
            ],
            axis=1,
        )
        empty_string_indices = tf.concat(
            [
                unfinished_word_indices[:, tf.newaxis],
                min_pair_rank_indices[:, tf.newaxis] + 1,
            ],
            axis=1,
        )

        tensor_words = words.to_tensor(default_value="")
        tensor_words = tf.tensor_scatter_nd_update(
            tensor_words,
            merged_pair_indices,
            merged_pairs,
        )

        words = tf.tensor_scatter_nd_update(
            tensor_words,
            empty_string_indices,
            empty_strs,
        )
        # Remove empty strings.
        words = remove_strings_from_inputs(words, "")
        return [words, mask]

    def _bpe_merge(self, inputs):
        """Perform byte-pair merge for each word in the inputs."""
        num_words = tf.shape(inputs)[0]

        # Merge bytes.
        def loop_condition(_, mask):
            return tf.math.reduce_any(mask)

        initial_mask = tf.fill((num_words,), True)
        merged_words, _ = tf.while_loop(
            loop_condition,
            tf.function(self._bpe_merge_one_step),
            loop_vars=[
                inputs,
                initial_mask,
            ],
            shape_invariants=[
                tf.TensorShape([None, None]),
                tf.TensorShape([None]),
            ],
        )
        return merged_words

    def _check_vocabulary(self):
        if self.vocabulary is None:
            raise ValueError(
                "No vocabulary has been set for BytePairTokenizer. Make sure "
                "to pass `vocabulary` and `merges` arguments when creating the "
                "layer."
            )

    def tokenize(self, inputs):
        self._check_vocabulary()
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)

        if self.add_prefix_space:
            inputs = tf.strings.join([" ", inputs])

        scalar_input = inputs.shape.rank == 0
        if scalar_input:
            inputs = tf.expand_dims(inputs, 0)

        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
        token_row_splits = raw_tokens.row_splits
        flat_tokens = raw_tokens.flat_values

        # Check cache.
        cache_lookup = self.cache.lookup(flat_tokens)
        cache_mask = cache_lookup == ""

        has_unseen_words = tf.math.reduce_any(
            (cache_lookup == "") & (flat_tokens != "")
        )

        def process_unseen_tokens():
            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
            self._bpe_merge_and_update_cache(unseen_tokens)
            return self.cache.lookup(flat_tokens)

        # If `has_unseen_words == True`, it means not all tokens are in the
        # cache, so we process the unseen tokens. Otherwise, return the cache
        # lookup.
        tokenized_words = tf.cond(
            has_unseen_words,
            process_unseen_tokens,
            lambda: cache_lookup,
        )

        tokens = tf.strings.split(tokenized_words, sep=" ")
        if self.compute_dtype != tf.string:
            # Encode merged tokens.
            tokens = self.token_to_id_map.lookup(tokens)

        # Unflatten to match input.
        tokens = tf.RaggedTensor.from_row_splits(
            tokens.flat_values,
            tf.gather(tokens.row_splits, token_row_splits),
        )

        # Convert to a dense output if `sequence_length` is set.
        if self.sequence_length:
            output_shape = tokens.shape.as_list()
            output_shape[-1] = self.sequence_length
            tokens = tokens.to_tensor(shape=output_shape)

        # Convert to a dense output if input is scalar.
        if scalar_input:
            tokens = tf.squeeze(tokens, 0)
            tf.ensure_shape(tokens, shape=[self.sequence_length])

        return tokens

    def detokenize(self, inputs):
        self._check_vocabulary()
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
        inputs = tf.cast(inputs, self.dtype)
        unicode_text = tf.strings.reduce_join(
            self.id_to_token_map.lookup(inputs), axis=-1
        )
        split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8")
        outputs = tf.strings.reduce_join(
            self.unicode2byte.lookup(split_unicode_text), axis=-1
        )

        if unbatched:
            outputs = tf.squeeze(outputs, 0)
        return outputs

    def compute_output_spec(self, input_spec):
        return keras.KerasTensor(
            input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
        )

    def _transform_bytes(self, tokens):
        """Map token bytes to unicode using `byte2unicode`."""
        split_bytes = tf.strings.bytes_split(tokens)
        split_unicode = self.byte2unicode.lookup(split_bytes)
        return split_unicode

    def _bpe_merge_and_update_cache(self, tokens):
        """Process unseen tokens and add to cache."""
        words = self._transform_bytes(tokens)
        tokenized_words = self._bpe_merge(words)

        # For each word, join all its tokens with a whitespace,
        # e.g., ["dragon", "fly"] => "dragon fly" for hashing purposes.
        tokenized_words = tf.strings.reduce_join(
            tokenized_words, axis=1, separator=" "
        )
        self.cache.insert(tokens, tokenized_words)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "add_prefix_space": self.add_prefix_space,
                "unsplittable_tokens": self.unsplittable_tokens,
            }
        )
        return config