keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/tokenizers/word_piece_tokenizer.py
@@ -0,0 +1,544 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from typing import Iterable

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.tokenizers import tokenizer
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
from keras_hub.src.utils.tensor_utils import is_int_dtype
from keras_hub.src.utils.tensor_utils import is_string_dtype

try:
    import tensorflow as tf
    import tensorflow_text as tf_text
except ImportError:
    tf = None
    tf_text = None

VOCAB_FILENAME = "vocabulary.txt"

# Matches whitespace and control characters.
WHITESPACE_REGEX = r"|".join(
    [
        r"\s",
        # Invisible control characters
        r"\p{Cc}",
        r"\p{Cf}",
    ]
)

# Matches punctuation compatible with the original bert implementation.
PUNCTUATION_REGEX = r"|".join(
    [
        # Treat all non-letter/number ASCII as punctuation.
        # Characters such as "^", "$", and "`" are not in the Unicode
        # Punctuation class but we treat them as punctuation anyways.
        r"[!-/]",
        r"[:-@]",
        r"[\[-`]",
        r"[{-~]",
        # Unicode punctuation class.
        r"[\p{P}]",
    ]
)

# Matches CJK characters. Obtained from
# https://github.com/google-research/bert/blob/master/tokenization.py#L251.
CJK_REGEX = r"|".join(
    [
        r"[\x{4E00}-\x{9FFF}]",
        r"[\x{3400}-\x{4DBF}]",
        r"[\x{20000}-\x{2A6DF}]",
        r"[\x{2A700}-\x{2B73F}]",
        r"[\x{2B740}-\x{2B81F}]",
        r"[\x{2B820}-\x{2CEAF}]",
        r"[\x{F900}-\x{FAFF}]",
        r"[\x{2F800}-\x{2FA1F}]",
    ]
)

# Matches both whitespace and punctuation.
WHITESPACE_AND_PUNCTUATION_REGEX = r"|".join(
    [
        WHITESPACE_REGEX,
        PUNCTUATION_REGEX,
    ]
)

# Matches punctuation and CJK characters.
PUNCTUATION_AND_CJK_REGEX = r"|".join(
    [
        PUNCTUATION_REGEX,
        CJK_REGEX,
    ]
)

# Matches whitespace, punctuation, and CJK characters.
WHITESPACE_PUNCTUATION_AND_CJK_REGEX = r"|".join(
    [
        WHITESPACE_AND_PUNCTUATION_REGEX,
        CJK_REGEX,
    ]
)


def get_special_tokens_pattern(special_tokens):
    if special_tokens is None or len(special_tokens) == 0:
        return None
    return r"|".join([re.escape(token) for token in special_tokens])


def pretokenize(
    text,
    lowercase=False,
    strip_accents=True,
    split=True,
    split_on_cjk=True,
    special_tokens_pattern=None,
):
    """Helper function that takes in a dataset element and pretokenizes it.

    Args:
        text: `tf.Tensor` or `tf.RaggedTensor`. Input to be pretokenized.
        lowercase: bool. If `True`, the input text will be
            lowercased before tokenization. Defaults to `False`.
        strip_accents: bool. If `True`, all accent marks will
            be removed from text before tokenization. Defaults to `True`.
        split: bool. If `True`, input will be split on
            whitespace and punctuation marks, and all punctuation marks will be
            kept as tokens. If `False`, input should be split ("pre-tokenized")
            before calling the tokenizer, and passed as a dense or ragged tensor
            of whole words. Defaults to `True`.
        split_on_cjk: bool. If `True`, input will be split
            on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese
            characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
            Note that this is applicable only when `split` is `True`. Defaults
            to `True`.
        special_tokens_pattern: str. A regex pattern that contains the
            special tokens that will never be split during the word-level
            splitting applied before the WordPiece encoding. This can be used
            to ensure special tokens map to unique indices in the vocabulary,
            even if these special tokens contain splittable characters such as
            punctuation.

    Returns:
        A tensor containing the pre-processed and pre-tokenized `text`.
    """
    # Check for correct types.
    if not is_string_dtype(text.dtype):
        raise ValueError(
            "The dataset elements in `data` must have string dtype. "
            f"Received: {text.dtype}."
        )
    # Preprocess, lowercase, strip and split input data.
    if text.shape.rank == 0:
        text = tf.expand_dims(text, 0)
    if split_on_cjk and split:
        text = tf.strings.regex_replace(text, CJK_REGEX, r" \0 ")
    if strip_accents:
        # Normalize unicode to NFD, which splits out accent mark characters.
        text = tf_text.normalize_utf8(text, "NFD")
        # Remove the accent marks.
        text = tf.strings.regex_replace(text, r"\p{Mn}", "")
    if split:
        if split_on_cjk:
            split_pattern = WHITESPACE_PUNCTUATION_AND_CJK_REGEX
            keep_split_pattern = PUNCTUATION_AND_CJK_REGEX
        else:
            split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
            keep_split_pattern = PUNCTUATION_REGEX
        if special_tokens_pattern is not None:
            # The idea here is to pass the special tokens regex to the split
            # function as the delimiter regex pattern, so the input will be
            # split on them, while each special token is treated as a single
            # entity that shouldn't be split even if it contains other
            # delimiter patterns. The special tokens regex is also passed as
            # the keep delimiter regex pattern, so the special tokens are not
            # removed.
            split_pattern = r"|".join(
                [
                    special_tokens_pattern,
                    split_pattern,
                ]
            )
            keep_split_pattern = r"|".join(
                [special_tokens_pattern, keep_split_pattern]
            )
        text = tf_text.regex_split(
            text,
            delim_regex_pattern=split_pattern,
            keep_delim_regex_pattern=keep_split_pattern,
        )
    if lowercase:
        if special_tokens_pattern is not None:
            # Do not lowercase special tokens in string space. They often
            # contain capital letters, e.g. `"[CLS]"`.
            mask = (
                tf.strings.regex_replace(text, special_tokens_pattern, "६")
                == "६"
            )
            text = tf.where(mask, text, tf_text.case_fold_utf8(text))
        else:
            text = tf_text.case_fold_utf8(text)

    return text


@keras_hub_export("keras_hub.tokenizers.WordPieceTokenizer")
class WordPieceTokenizer(tokenizer.Tokenizer):
    """A WordPiece tokenizer layer.

    This layer provides an efficient, in-graph implementation of the WordPiece
    algorithm used by BERT and other models.

    To make this layer more useful out of the box, the layer will pre-tokenize
    the input, which will optionally lower-case, strip accents, and split the
    input on whitespace and punctuation. Each of these pre-tokenization steps
    is not reversible. The `detokenize` method will join words with a space,
    and will not invert `tokenize` exactly.

    If a more custom pre-tokenization step is desired, the layer can be
    configured to apply only the strict WordPiece algorithm by passing
    `lowercase=False`, `strip_accents=False` and `split=False`. In
    this case, inputs should be pre-split string tensors or ragged tensors.

    Tokenizer outputs can either be padded and truncated with a
    `sequence_length` argument, or left un-truncated. The exact output will
    depend on the rank of the input tensors.

    If input is a batch of strings (rank > 0):
    By default, the layer will output a `tf.RaggedTensor` where the last
    dimension of the output is ragged. If `sequence_length` is set, the layer
    will output a dense `tf.Tensor` where all inputs have been padded or
    truncated to `sequence_length`.

    If input is a scalar string (rank == 0):
    By default, the layer will output a dense `tf.Tensor` with static shape
    `[None]`. If `sequence_length` is set, the output will be
    a dense `tf.Tensor` of shape `[sequence_length]`.

    The output dtype can be controlled via the `dtype` argument, which should
    be either an integer or string type.

    Args:
        vocabulary: A list of strings or a string filename path. If
            passing a list, each element of the list should be a single
            WordPiece token string. If passing a filename, the file should be a
            plain text file containing a single WordPiece token per line.
        sequence_length: int. If set, the output will be converted to a dense
            tensor and padded/trimmed so all outputs are of `sequence_length`.
        lowercase: bool. If `True`, the input text will be
            lowercased before tokenization. Defaults to `False`.
        strip_accents: bool. If `True`, all accent marks will
            be removed from text before tokenization. Defaults to `False`.
        split: bool. If `True`, input will be split on
            whitespace and punctuation marks, and all punctuation marks will be
            kept as tokens. If `False`, input should be split ("pre-tokenized")
            before calling the tokenizer, and passed as a dense or ragged tensor
            of whole words. Defaults to `True`.
        split_on_cjk: bool. If `True`, input will be split
            on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese
            characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
            Note that this is applicable only when `split` is `True`.
            Defaults to `True`.
        suffix_indicator: str. The characters prepended to a
            WordPiece to indicate that it is a suffix to another subword.
            E.g. "##ing". Defaults to `"##"`.
        oov_token: str. The string value to substitute for
            an unknown token. It must be included in the vocab.
            Defaults to `"[UNK]"`.
        special_tokens: list. A list of special tokens. When
            `special_tokens_in_strings` is set to `True`, the tokenizer will
            map every special token in the input strings to its id, even if
            these special tokens contain characters that should be split
            before tokenization such as punctuation. `special_tokens` must be
            included in `vocabulary`.
        special_tokens_in_strings: bool. A bool to indicate if the tokenizer
            should expect special tokens in input strings that should be
            tokenized and mapped correctly to their ids. Defaults to `False`.

    References:
     - [Schuster and Nakajima, 2012](https://research.google/pubs/pub37842/)
     - [Song et al., 2020](https://arxiv.org/abs/2012.15524)

    Examples:

    Ragged outputs.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = "The quick brown fox."
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     lowercase=True,
    ... )
    >>> outputs = tokenizer(inputs)
    >>> np.array(outputs)
    array([1, 2, 3, 4, 5, 6, 7], dtype=int32)

    Dense outputs.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = ["The quick brown fox."]
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     sequence_length=10,
    ...     lowercase=True,
    ... )
    >>> outputs = tokenizer(inputs)
    >>> np.array(outputs)
    array([[1, 2, 3, 4, 5, 6, 7, 0, 0, 0]], dtype=int32)

    String output.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = "The quick brown fox."
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     lowercase=True,
    ...     dtype="string",
    ... )
    >>> outputs = tokenizer(inputs)
    >>> np.array(outputs).astype("U")
    array(['the', 'qu', '##ick', 'br', '##own', 'fox', '.'], dtype='<U5')

    Detokenization.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = "The quick brown fox."
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     lowercase=True,
    ... )
    >>> outputs = tokenizer.detokenize(tokenizer.tokenize(inputs))
    >>> np.array(outputs).astype("U")
    array('the quick brown fox .', dtype='<U21')

    Custom splitting.
    >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
    >>> inputs = "The$quick$brown$fox"
    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    ...     vocabulary=vocab,
    ...     split=False,
    ...     lowercase=True,
    ...     dtype='string',
    ... )
    >>> split_inputs = tf.strings.split(inputs, sep="$")
    >>> outputs = tokenizer(split_inputs)
    >>> np.array(outputs).astype("U")
    array(['the', 'qu', '##ick', 'br', '##own', 'fox'], dtype='<U5')
    """

    def __init__(
        self,
        vocabulary=None,
        sequence_length=None,
        lowercase=False,
        strip_accents=False,
        split=True,
        split_on_cjk=True,
        suffix_indicator="##",
        oov_token="[UNK]",
        special_tokens=None,
        special_tokens_in_strings=False,
        dtype="int32",
        **kwargs,
    ) -> None:
        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type or a string. "
                f"Received: dtype={dtype}"
            )

        super().__init__(dtype=dtype, **kwargs)
        if oov_token is None:
            raise ValueError("`oov_token` cannot be None.")

        self.sequence_length = sequence_length
        self.lowercase = lowercase
        self.strip_accents = strip_accents
        self.split = split
        self.split_on_cjk = split_on_cjk
        self.suffix_indicator = suffix_indicator
        self.oov_token = oov_token
        self.special_tokens = special_tokens
        self._special_tokens_pattern = None
        if self.split and special_tokens_in_strings:
            # The idea here is to pass the special tokens regex to the
            # split function as the delimiter regex pattern, so the input
            # will be split on them, while each special token is treated as
            # a single entity that shouldn't be split even if it contains
            # other delimiter regex patterns. The special tokens regex is
            # also passed as the keep delimiter regex pattern, so the
            # special tokens will not be removed.
            self._special_tokens_pattern = get_special_tokens_pattern(
                self.special_tokens
            )
        self.set_vocabulary(vocabulary)
        self.file_assets = [VOCAB_FILENAME]

    def save_assets(self, dir_path):
        path = os.path.join(dir_path, VOCAB_FILENAME)
        with open(path, "w", encoding="utf-8") as file:
            for token in self.vocabulary:
                file.write(f"{token}\n")

    def load_assets(self, dir_path):
        path = os.path.join(dir_path, VOCAB_FILENAME)
        self.set_vocabulary(path)

    def set_vocabulary(self, vocabulary):
        """Set the tokenizer vocabulary to a file or list of strings."""
        if vocabulary is None:
            self.vocabulary = None
            self._fast_word_piece = None
            return

        if isinstance(vocabulary, str):
            with open(vocabulary, "r", encoding="utf-8") as file:
                self.vocabulary = [line.rstrip() for line in file]
        elif isinstance(vocabulary, Iterable):
            # Make a defensive copy.
            self.vocabulary = list(vocabulary)
        else:
            raise ValueError(
                "Vocabulary must be a file path or list of terms. "
                f"Received: vocabulary={vocabulary}"
            )

        if self.oov_token not in self.vocabulary:
            raise ValueError(
                f'Cannot find `oov_token="{self.oov_token}"` in the '
                "vocabulary.\n"
                "You can either update the vocabulary to include "
                f'`"{self.oov_token}"`, or pass a different value for '
                "the `oov_token` argument when creating the tokenizer."
            )

        # Check for special tokens in the vocabulary.
        if self.special_tokens is not None:
            for token in self.special_tokens:
                if token not in self.vocabulary:
                    raise ValueError(
                        f"Cannot find token `'{token}'` in the provided "
                        f"`vocabulary`. Please provide `'{token}'` in your "
                        "`vocabulary` or use a pretrained `vocabulary` name."
                    )

        self._fast_word_piece = tf_text.FastWordpieceTokenizer(
            vocab=self.vocabulary,
            token_out_type=self.compute_dtype,
            suffix_indicator=self.suffix_indicator,
            unknown_token=self.oov_token,
            no_pretokenization=True,
            support_detokenization=True,
        )

    def get_vocabulary(self):
        """Get the tokenizer vocabulary as a list of string tokens."""
        self._check_vocabulary()
        return self.vocabulary

    def vocabulary_size(self):
        """Get the integer size of the tokenizer vocabulary."""
        self._check_vocabulary()
        return len(self.vocabulary)

    def id_to_token(self, id):
        """Convert an integer id to a string token."""
        self._check_vocabulary()
        if id >= self.vocabulary_size() or id < 0:
            raise ValueError(
                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
                f"Received: {id}"
            )
        return self.vocabulary[id]

    def token_to_id(self, token):
        """Convert a string token to an integer id."""
        # This will be slow, but keep memory usage down compared to building a
        # dict. Assuming the main use case is looking up a few special tokens
        # early in the vocab, this should be fine.
        self._check_vocabulary()
        return self.vocabulary.index(token)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary": None,  # Save vocabulary via an asset!
                "sequence_length": self.sequence_length,
                "lowercase": self.lowercase,
                "strip_accents": self.strip_accents,
                "split": self.split,
                "suffix_indicator": self.suffix_indicator,
                "oov_token": self.oov_token,
                "special_tokens": self.special_tokens,
            }
        )
        return config

    def _check_vocabulary(self):
        if self.vocabulary is None:
            raise ValueError(
                "No vocabulary has been set for WordPieceTokenizer. Make sure "
                "to pass a `vocabulary` argument when creating the layer."
            )

    def tokenize(self, inputs):
        self._check_vocabulary()
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)

        scalar_input = inputs.shape.rank == 0
        inputs = pretokenize(
            inputs,
            self.lowercase,
            self.strip_accents,
            self.split,
            self.split_on_cjk,
            self._special_tokens_pattern,
        )

        # Apply WordPiece and coerce shape for outputs.
        tokens = self._fast_word_piece.tokenize(inputs)
        # By default tf.text tokenizes text with two ragged dimensions (one for
        # split words and one for split subwords). We will collapse to a single
        # ragged dimension which is a better out of box default.
        tokens = tokens.merge_dims(-2, -1)

        # Convert to a dense output if `sequence_length` is set.
        if self.sequence_length:
            output_shape = tokens.shape.as_list()
            output_shape[-1] = self.sequence_length
            tokens = tokens.to_tensor(shape=output_shape)
        # Convert to a dense output if input is scalar.
        if scalar_input:
            tokens = tf.squeeze(tokens, 0)
            tf.ensure_shape(tokens, shape=[self.sequence_length])

        return tokens

    def detokenize(self, inputs):
        self._check_vocabulary()
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
        outputs = self._fast_word_piece.detokenize(inputs)
        if unbatched:
            outputs = tf.squeeze(outputs, 0)
        return outputs

    def compute_output_spec(self, input_spec):
        return keras.KerasTensor(
            input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
        )