keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,299 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import numpy as np
|
16
|
+
|
17
|
+
try:
|
18
|
+
import tensorflow as tf
|
19
|
+
except ImportError:
|
20
|
+
raise ImportError(
|
21
|
+
"To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
|
22
|
+
"The TensorFlow package is required for data preprocessing with any backend."
|
23
|
+
)
|
24
|
+
|
25
|
+
from keras_hub.src.api_export import keras_hub_export
|
26
|
+
from keras_hub.src.tokenizers import tokenizer
|
27
|
+
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
|
28
|
+
from keras_hub.src.utils.tensor_utils import is_int_dtype
|
29
|
+
|
30
|
+
try:
|
31
|
+
import tensorflow_text as tf_text
|
32
|
+
except ImportError:
|
33
|
+
tf_text = None
|
34
|
+
|
35
|
+
|
36
|
+
@keras_hub_export("keras_hub.tokenizers.ByteTokenizer")
class ByteTokenizer(tokenizer.Tokenizer):
    """Raw byte tokenizer.

    This tokenizer is a vocabulary-free tokenizer which will tokenize text as
    raw bytes from [0, 256).

    Tokenizer outputs can either be padded and truncated with a
    `sequence_length` argument, or left un-truncated. The exact output will
    depend on the rank of the input tensors.

    If input is a batch of strings:
    By default, the layer will output a `tf.RaggedTensor` where the last
    dimension of the output is ragged. If `sequence_length` is set, the layer
    will output a dense `tf.Tensor` where all inputs have been padded or
    truncated to `sequence_length`.

    If input is a scalar string:
    There are two cases here. If `sequence_length` is set, the output will be
    a dense `tf.Tensor` of shape `[sequence_length]`. Otherwise, the output will
    be a dense `tf.Tensor` of shape `[None]`.

    The output dtype can be controlled via the
    `dtype` argument, which should be an integer type
    ("int16", "int32", etc.).

    Args:
        lowercase: boolean. If True, the input text will be case-folded to
            lowercase before tokenization.
        sequence_length: int. If set, the output will be converted to a dense
            tensor and padded/trimmed so all outputs are of `sequence_length`.
        normalization_form: string. One of the following values: (None, "NFC",
            "NFKC", "NFD", "NFKD"). If set, every UTF-8 string in the input
            tensor text will be normalized to the given form before tokenizing.
        errors: One of ('replace', 'ignore', 'strict'). Specifies the
            `detokenize()` behavior when an invalid byte sequence is
            encountered. The value of `'strict'` will cause the operation to
            produce an `InvalidArgument` error on any invalid input formatting.
            A value of `'replace'` will cause the tokenizer to replace any
            invalid formatting in the input with the `replacement_char`
            codepoint. A value of `'ignore'` will cause the tokenizer to skip
            any invalid formatting in the input and produce no corresponding
            output character. Defaults to `'replace'`.
        replacement_char: int. The replacement codepoint to use when an
            invalid byte sequence is encountered and `errors` is set to
            `'replace'` (same behaviour as `tf.strings.unicode_transcode`).
            Defaults to `65533` (U+FFFD, the Unicode replacement character).

    Examples:

    Basic usage.
    >>> tokenizer = keras_hub.tokenizers.ByteTokenizer()
    >>> outputs = tokenizer("hello")
    >>> np.array(outputs)
    array([104, 101, 108, 108, 111], dtype=int32)

    Ragged outputs.
    >>> inputs = ["hello", "hi"]
    >>> tokenizer = keras_hub.tokenizers.ByteTokenizer()
    >>> seq1, seq2 = tokenizer(inputs)
    >>> np.array(seq1)
    array([104, 101, 108, 108, 111], dtype=int32)
    >>> np.array(seq2)
    array([104, 105], dtype=int32)

    Dense outputs.
    >>> inputs = ["hello", "hi"]
    >>> tokenizer = keras_hub.tokenizers.ByteTokenizer(sequence_length=8)
    >>> seq1, seq2 = tokenizer(inputs)
    >>> np.array(seq1)
    array([104, 101, 108, 108, 111,   0,   0,   0], dtype=int32)
    >>> np.array(seq2)
    array([104, 105,   0,   0,   0,   0,   0,   0], dtype=int32)

    Tokenize, then batch for ragged outputs.
    >>> tokenizer = keras_hub.tokenizers.ByteTokenizer()
    >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"])
    >>> ds = ds.map(tokenizer)
    >>> ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(2))
    >>> ds.take(1).get_single_element()
    <tf.RaggedTensor [[104, 101, 108, 108, 111], [102, 117, 110]]>

    Batch, then tokenize for ragged outputs.
    >>> tokenizer = keras_hub.tokenizers.ByteTokenizer()
    >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"])
    >>> ds = ds.batch(2).map(tokenizer)
    >>> ds.take(1).get_single_element()
    <tf.RaggedTensor [[104, 101, 108, 108, 111], [102, 117, 110]]>

    Tokenize, then batch for dense outputs (`sequence_length` provided).
    >>> tokenizer = keras_hub.tokenizers.ByteTokenizer(sequence_length=5)
    >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"])
    >>> ds = ds.map(tokenizer)
    >>> ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(2))
    >>> ds.take(1).get_single_element()
    <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
    array([[104, 101, 108, 108, 111],
           [102, 117, 110,   0,   0]], dtype=int32)>

    Batch, then tokenize for dense outputs (`sequence_length` provided).
    >>> tokenizer = keras_hub.tokenizers.ByteTokenizer(sequence_length=5)
    >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"])
    >>> ds = ds.batch(2).map(tokenizer)
    >>> ds.take(1).get_single_element()
    <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
    array([[104, 101, 108, 108, 111],
           [102, 117, 110,   0,   0]], dtype=int32)>

    Detokenization.
    >>> inputs = [104, 101, 108, 108, 111]
    >>> tokenizer = keras_hub.tokenizers.ByteTokenizer()
    >>> outputs = tokenizer.detokenize(inputs)
    >>> np.array(outputs).astype("U")
    array('hello', dtype='<U5')

    Detokenization with invalid bytes.
    >>> # The 255 below is invalid utf-8.
    >>> inputs = [104, 101, 255, 108, 108, 111]
    >>> tokenizer = keras_hub.tokenizers.ByteTokenizer(
    ...     errors="replace", replacement_char=88)
    >>> outputs = tokenizer.detokenize(inputs)
    >>> np.array(outputs).astype("U")
    array('heXllo', dtype='<U6')
    """

    def __init__(
        self,
        lowercase=True,
        sequence_length=None,
        normalization_form=None,
        errors="replace",
        replacement_char=65533,
        dtype="int32",
        **kwargs,
    ):
        if not is_int_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type. "
                f"Received: dtype={dtype}"
            )

        # Check normalization_form.
        if normalization_form not in (None, "NFC", "NFKC", "NFD", "NFKD"):
            raise ValueError(
                '`normalization_form` must be one of None, "NFC", "NFKC", '
                '"NFD", "NFKD". Received: normalization_form='
                f"{normalization_form}"
            )

        # Check errors.
        if errors not in ("strict", "replace", "ignore"):
            raise ValueError(
                '`errors` must be one of "strict", "replace", "ignore". '
                f"Received: errors={errors}"
            )

        super().__init__(dtype=dtype, **kwargs)

        self.lowercase = lowercase
        self.sequence_length = sequence_length
        self.normalization_form = normalization_form
        self.errors = errors
        self.replacement_char = replacement_char

        # Lookup table: index i holds the single raw byte with value i. Used
        # by `detokenize()` to gather ids back into byte strings.
        self._char_lst = tf.constant(
            [i.tobytes() for i in np.arange(256, dtype=np.uint8)]
        )

    def vocabulary_size(self):
        """Get the integer size of the tokenizer vocabulary."""
        return 256

    def get_vocabulary(self):
        """Return a dict mapping `chr(i)` to `i` for every id in the vocab."""
        return {chr(i): i for i in range(self.vocabulary_size())}

    def tokenize(self, inputs):
        """Split string input into raw byte ids.

        Returns a ragged tensor for batched input, or a dense tensor when
        `sequence_length` is set or the input is a scalar string.
        """
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)

        # Promote a scalar string to a batch of one so the ragged ops below
        # see a consistent rank; the extra axis is squeezed off at the end.
        scalar_input = inputs.shape.rank == 0
        if scalar_input:
            inputs = tf.expand_dims(inputs, 0)

        # Optional: Lowercase the input.
        if self.lowercase:
            inputs = tf_text.case_fold_utf8(inputs)

        # Optional: Normalize unicode.
        if self.normalization_form is not None:
            inputs = tf_text.normalize_utf8(inputs, self.normalization_form)

        # Tokenize input strings: split into single-byte strings, decode each
        # one to a uint8 (adds a trailing length-1 axis, squeezed away).
        tokens = tf.strings.bytes_split(inputs)
        tokens = tf.squeeze(
            tf.ragged.map_flat_values(tf.io.decode_raw, tokens, tf.uint8), -1
        )
        tokens = tf.cast(tokens, self.compute_dtype)

        # Convert to a dense output if `sequence_length` is set.
        if self.sequence_length:
            output_shape = tokens.shape.as_list()
            output_shape[-1] = self.sequence_length
            tokens = tokens.to_tensor(shape=output_shape)

        if scalar_input:
            tokens = tf.squeeze(tokens, 0)
        return tokens

    def detokenize(self, inputs):
        """Convert byte ids back into UTF-8 strings.

        Invalid UTF-8 byte sequences are handled according to `self.errors`
        and `self.replacement_char`.
        """
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
        # Drop padding tokens so "\x00" bytes don't show up in the
        # detokenized output. Note this removes every 0 id, not only
        # trailing ones.
        inputs = tf.ragged.boolean_mask(inputs, tf.not_equal(inputs, 0))

        outputs = tf.strings.reduce_join(
            tf.gather(self._char_lst, inputs), axis=-1
        )

        # Handle errors if an invalid byte sequence is encountered.
        outputs = tf.strings.unicode_transcode(
            outputs,
            "UTF-8",
            "UTF-8",
            errors=self.errors,
            replacement_char=self.replacement_char,
        )
        if unbatched:
            outputs = tf.squeeze(outputs, 0)
        return outputs

    def id_to_token(self, id):
        """Convert an integer id to a string token.

        Raises:
            ValueError: if `id` is outside `[0, vocabulary_size())`.
        """
        if not 0 <= id < self.vocabulary_size():
            raise ValueError(
                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
                f"Received: {id}"
            )
        return chr(id)

    def token_to_id(self, token):
        """Convert a string token to an integer id.

        Raises:
            ValueError: if the token's codepoint is not a single byte value.
        """
        token_id = ord(token)
        if token_id >= self.vocabulary_size():
            raise ValueError(
                f"Token {token} is not supported by `ByteTokenizer`."
            )
        return token_id

    def get_config(self):
        """Return the layer config, including all constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "lowercase": self.lowercase,
                "sequence_length": self.sequence_length,
                "normalization_form": self.normalization_form,
                "errors": self.errors,
                "replacement_char": self.replacement_char,
            }
        )
        return config
|
@@ -0,0 +1,267 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import base64
|
16
|
+
import binascii
|
17
|
+
import os
|
18
|
+
|
19
|
+
import keras
|
20
|
+
|
21
|
+
try:
|
22
|
+
import tensorflow as tf
|
23
|
+
except ImportError:
|
24
|
+
raise ImportError(
|
25
|
+
"To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
|
26
|
+
"The TensorFlow package is required for data preprocessing with any backend."
|
27
|
+
)
|
28
|
+
|
29
|
+
from keras_hub.src.api_export import keras_hub_export
|
30
|
+
from keras_hub.src.tokenizers import tokenizer
|
31
|
+
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
|
32
|
+
from keras_hub.src.utils.tensor_utils import is_int_dtype
|
33
|
+
from keras_hub.src.utils.tensor_utils import is_string_dtype
|
34
|
+
from keras_hub.src.utils.tensor_utils import tensor_to_list
|
35
|
+
|
36
|
+
try:
|
37
|
+
import tensorflow_text as tf_text
|
38
|
+
except ImportError:
|
39
|
+
tf_text = None
|
40
|
+
|
41
|
+
|
42
|
+
# Filename for the serialized SentencePiece proto ("vocabulary.spm").
# NOTE(review): presumably the asset name used when saving/loading the
# tokenizer's proto — confirm against the save/load methods.
VOCAB_FILENAME = "vocabulary.spm"
|
43
|
+
|
44
|
+
|
45
|
+
@keras_hub_export("keras_hub.tokenizers.SentencePieceTokenizer")
class SentencePieceTokenizer(tokenizer.Tokenizer):
    """A SentencePiece tokenizer layer.

    This layer provides an implementation of SentencePiece tokenization
    as described in the [SentencePiece paper](https://arxiv.org/abs/1808.06226)
    and the [SentencePiece package](https://pypi.org/project/sentencepiece/).
    The tokenization will run entirely within the Tensorflow graph, and can
    be saved inside a `keras.Model`.

    By default, the layer will output a `tf.RaggedTensor` where the last
    dimension of the output is ragged after whitespace splitting and sub-word
    tokenizing. If `sequence_length` is set, the layer will output a dense
    `tf.Tensor` where all inputs have been padded or truncated to
    `sequence_length`. The output dtype can be controlled via the `dtype`
    argument, which should be either an integer or string type.

    Args:
        proto: Either a `string` path to a SentencePiece proto file, or a
            `bytes` object with a serialized SentencePiece proto. See the
            [SentencePiece repository](https://github.com/google/sentencepiece)
            for more details on the format.
        sequence_length: If set, the output will be converted to a dense
            tensor and padded/trimmed so all outputs are of `sequence_length`.
        dtype: The output dtype. Must be an integer type (token ids) or a
            string type (token strings). Defaults to `"int32"`.

    References:
        - [Kudo and Richardson, 2018](https://arxiv.org/abs/1808.06226)

    Examples:

    From bytes.
    ```python
    def train_sentence_piece_bytes(ds, size):
        bytes_io = io.BytesIO()
        sentencepiece.SentencePieceTrainer.train(
            sentence_iterator=ds.as_numpy_iterator(),
            model_writer=bytes_io,
            vocab_size=size,
        )
        return bytes_io.getvalue()

    # Train a sentencepiece proto.
    ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
    proto = train_sentence_piece_bytes(ds, 20)
    # Tokenize inputs.
    tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(proto=proto)
    ds = ds.map(tokenizer)
    ```

    From a file.
    ```python
    def train_sentence_piece_file(ds, path, size):
        with open(path, "wb") as model_file:
            sentencepiece.SentencePieceTrainer.train(
                sentence_iterator=ds.as_numpy_iterator(),
                model_writer=model_file,
                vocab_size=size,
            )

    # Train a sentencepiece proto and write it to "model.spm".
    ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
    train_sentence_piece_file(ds, "model.spm", 20)
    # Tokenize inputs.
    tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(proto="model.spm")
    ds = ds.map(tokenizer)
    ```
    """

    def __init__(
        self,
        proto=None,
        sequence_length=None,
        dtype="int32",
        **kwargs,
    ) -> None:
        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type or a string. "
                f"Received: dtype={dtype}"
            )

        super().__init__(dtype=dtype, **kwargs)

        self.proto = None
        self.sequence_length = sequence_length
        # `set_proto` reads `self.compute_dtype`, so it must run after
        # `super().__init__` has configured the layer dtype.
        self.set_proto(proto)
        self.file_assets = [VOCAB_FILENAME]

    def save_assets(self, dir_path):
        """Write the serialized proto bytes to `dir_path/VOCAB_FILENAME`."""
        path = os.path.join(dir_path, VOCAB_FILENAME)
        with open(path, "wb") as file:
            file.write(self.proto)

    def load_assets(self, dir_path):
        """Restore the proto written by `save_assets` from `dir_path`."""
        path = os.path.join(dir_path, VOCAB_FILENAME)
        self.set_proto(path)

    def set_proto(self, proto):
        """Set, or clear, the SentencePiece model used by this layer.

        Args:
            proto: `None` to clear the model; a `bytes` serialized proto; or a
                `str` that is either a filepath to a proto file or a base64
                encoded serialized proto (see the heuristic below).

        Raises:
            ValueError: If `proto` is not `None`, a `str`, or `bytes`.
            ImportError: If `tensorflow_text` is not installed.
        """
        if proto is None:
            self.proto = None
            self._sentence_piece = None
            return

        if isinstance(proto, str):
            # A string could be either a filepath, or a base64 encoded byte
            # array (which we need for serialization). We will heuristically
            # try to distinguish, by checking if a string is both longer
            # than 2048 characters and valid base64 characters.
            is_base64 = False
            if len(proto) > 2048:
                try:
                    proto_bytes = base64.b64decode(proto, validate=True)
                    is_base64 = True
                except binascii.Error:
                    pass
            if not is_base64:
                # Use a context manager so the file handle is always closed.
                with open(proto, "rb") as proto_file:
                    proto_bytes = proto_file.read()
        elif isinstance(proto, bytes):
            proto_bytes = proto
        else:
            raise ValueError(
                "SentencePiece `proto` argument should be either a `string` "
                "filepath or a `bytes` sequence. "
                f"Received unknown type: {type(proto)}"
            )

        # `tensorflow_text` is an optional import at module level; fail with
        # an actionable message instead of an AttributeError on `None`.
        if tf_text is None:
            raise ImportError(
                "SentencePieceTokenizer requires the `tensorflow-text` "
                "package. Please install it with `pip install tensorflow-text`."
            )

        self._sentence_piece = tf_text.SentencepieceTokenizer(
            model=proto_bytes,
            out_type=self.compute_dtype,
        )
        # Keep the raw serialized bytes. They are written to disk by
        # `save_assets`; `get_config` deliberately does not include them.
        self.proto = proto_bytes

    def vocabulary_size(self):
        """Get the integer size of the tokenizer vocabulary."""
        self._check_vocabulary()
        return int(self._sentence_piece.vocab_size().numpy())

    def get_vocabulary(self):
        """Get the tokenizer vocabulary as a list of tokens, ordered by id."""
        self._check_vocabulary()
        vocab_size = self.vocabulary_size()
        return tensor_to_list(
            self._sentence_piece.id_to_string(tf.range(vocab_size))
        )

    def id_to_token(self, id):
        """Convert an integer id to a string token.

        Raises:
            ValueError: If `id` is outside `[0, vocabulary_size() - 1]`.
        """
        self._check_vocabulary()
        if id >= self.vocabulary_size() or id < 0:
            raise ValueError(
                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
                f"Received: {id}"
            )
        return tensor_to_list(self._sentence_piece.id_to_string(id))

    def token_to_id(self, token):
        """Convert a string token to an integer id."""
        self._check_vocabulary()
        return int(self._sentence_piece.string_to_id(token).numpy())

    def get_config(self):
        """Return the layer config; the proto itself is saved as an asset."""
        config = super().get_config()
        config.update(
            {
                "proto": None,  # Save vocabulary via an asset!
                "sequence_length": self.sequence_length,
            }
        )
        return config

    def _check_vocabulary(self):
        """Raise a `ValueError` if no proto has been set on this layer."""
        if self.proto is None:
            raise ValueError(
                "No vocabulary has been set for SentencePieceTokenizer. Make "
                "sure to pass a `proto` argument when creating the layer."
            )

    def tokenize(self, inputs):
        """Tokenize string `inputs` with the SentencePiece model.

        A rank-0 (scalar) input produces an unbatched rank-1 result. Batched
        inputs produce a ragged result, or a dense one padded/truncated to
        `sequence_length` when that is set.
        """
        # `_check_vocabulary` also guarantees `self._sentence_piece` is set,
        # because `set_proto` assigns `proto` and `_sentence_piece` together.
        self._check_vocabulary()
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)
        scalar_input = inputs.shape.rank == 0
        if scalar_input:
            inputs = tf.expand_dims(inputs, 0)

        tokens = self._sentence_piece.tokenize(inputs)

        # Convert to a dense output if `sequence_length` is set.
        if self.sequence_length:
            output_shape = tokens.shape.as_list()
            output_shape[-1] = self.sequence_length
            tokens = tokens.to_tensor(shape=output_shape)

        # Convert to a dense output if input was a scalar.
        if scalar_input:
            tokens = tf.squeeze(tokens, 0)
            # `tf.ensure_shape` returns a new tensor that carries both the
            # runtime check and the refined static shape; keep the result.
            tokens = tf.ensure_shape(tokens, shape=[self.sequence_length])

        return tokens

    def detokenize(self, inputs):
        """Convert token ids back into strings."""
        self._check_vocabulary()
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
        # tf-text sentencepiece does not handle int64.
        inputs = tf.cast(inputs, "int32")
        outputs = self._sentence_piece.detokenize(inputs)
        if unbatched:
            outputs = tf.squeeze(outputs, 0)
        return outputs

    def compute_output_spec(self, input_spec):
        """Symbolic output: input shape plus a `sequence_length` last axis."""
        return keras.KerasTensor(
            input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
        )
|