keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/layers/modeling/transformer_encoder.py
@@ -0,0 +1,262 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+from keras_hub.src.layers.modeling.transformer_layer_utils import (  # isort:skip
+    merge_padding_and_attention_mask,
+)
+
+
+@keras_hub_export("keras_hub.layers.TransformerEncoder")
+class TransformerEncoder(keras.layers.Layer):
+    """Transformer encoder.
+
+    This class follows the architecture of the transformer encoder layer in the
+    paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
+    can instantiate multiple instances of this class to stack up an encoder.
+
+    This layer will correctly compute an attention mask from an implicit
+    Keras padding mask (for example, by passing `mask_zero=True` to a
+    `keras.layers.Embedding` layer). See the Masking and Padding
+    [guide](https://keras.io/guides/understanding_masking_and_padding/)
+    for more details.
+
+    Args:
+        intermediate_dim: int, the hidden size of feedforward network.
+        num_heads: int, the number of heads in the
+            `keras.layers.MultiHeadAttention` layer.
+        dropout: float. the dropout value, shared by
+            `keras.layers.MultiHeadAttention` and feedforward network.
+            Defaults to `0.`.
+        activation: string or `keras.activations`. the
+            activation function of feedforward network.
+            Defaults to `"relu"`.
+        layer_norm_epsilon: float. The epsilon value in layer
+            normalization components. Defaults to `1e-5`.
+        kernel_initializer: string or `keras.initializers` initializer.
+            The kernel initializer for the dense and multiheaded
+            attention layers. Defaults to `"glorot_uniform"`.
+        bias_initializer: string or `keras.initializers` initializer.
+            The bias initializer for the dense and multiheaded
+            attention layers. Defaults to `"zeros"`.
+        normalize_first: bool. If True, the inputs to the
+            attention layer and the intermediate dense layer are normalized
+            (similar to GPT-2). If set to False, outputs of attention layer and
+            intermediate dense layer are normalized (similar to BERT).
+            Defaults to `False`.
+        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+            including `name`, `trainable`, `dtype` etc.
+
+    Example:
+
+    ```python
+    # Create a single transformer encoder layer.
+    encoder = keras_hub.layers.TransformerEncoder(
+        intermediate_dim=64, num_heads=8)
+
+    # Create a simple model containing the encoder.
+    input = keras.Input(shape=(10, 64))
+    output = encoder(input)
+    model = keras.Model(inputs=input, outputs=output)
+
+    # Call encoder on the inputs.
+    input_data = np.random.uniform(size=(2, 10, 64))
+    output = model(input_data)
+    ```
+
+    References:
+     - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
+    """
+
+    def __init__(
+        self,
+        intermediate_dim,
+        num_heads,
+        dropout=0,
+        activation="relu",
+        layer_norm_epsilon=1e-05,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        normalize_first=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.intermediate_dim = intermediate_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.activation = keras.activations.get(activation)
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
+        self.normalize_first = normalize_first
+        self.supports_masking = True
+
+    def build(self, inputs_shape):
+        # Infer the dimension of our hidden feature size from the build shape.
+        hidden_dim = inputs_shape[-1]
+        # Attention head size is `hidden_dim` over the number of heads.
+        key_dim = int(hidden_dim // self.num_heads)
+        if key_dim == 0:
+            raise ValueError(
+                "Attention `key_dim` computed cannot be zero. "
+                f"The `hidden_dim` value of {hidden_dim} has to be equal to "
+                f"or greater than `num_heads` value of {self.num_heads}."
+            )
+
+        # Self attention layers.
+        self._self_attention_layer = keras.layers.MultiHeadAttention(
+            num_heads=self.num_heads,
+            key_dim=key_dim,
+            dropout=self.dropout,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="self_attention_layer",
+        )
+        if hasattr(self._self_attention_layer, "_build_from_signature"):
+            self._self_attention_layer._build_from_signature(
+                query=inputs_shape,
+                value=inputs_shape,
+            )
+        else:
+            self._self_attention_layer.build(
+                query_shape=inputs_shape,
+                value_shape=inputs_shape,
+            )
+        self._self_attention_layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="self_attention_layer_norm",
+        )
+        self._self_attention_layer_norm.build(inputs_shape)
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention_dropout",
+        )
+
+        # Feedforward layers.
+        self._feedforward_layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layer_norm",
+        )
+        self._feedforward_layer_norm.build(inputs_shape)
+        self._feedforward_intermediate_dense = keras.layers.Dense(
+            self.intermediate_dim,
+            activation=self.activation,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="feedforward_intermediate_dense",
+        )
+        self._feedforward_intermediate_dense.build(inputs_shape)
+        self._feedforward_output_dense = keras.layers.Dense(
+            hidden_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="feedforward_output_dense",
+        )
+        intermediate_shape = list(inputs_shape)
+        intermediate_shape[-1] = self.intermediate_dim
+        self._feedforward_output_dense.build(tuple(intermediate_shape))
+        self._feedforward_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="feedforward_dropout",
+        )
+        self.built = True
+
+    def call(
+        self, inputs, padding_mask=None, attention_mask=None, training=None
+    ):
+        """Forward pass of the TransformerEncoder.
+
+        Args:
+            inputs: a Tensor. The input data to TransformerEncoder, should be
+                of shape [batch_size, sequence_length, hidden_dim].
+            padding_mask: a boolean Tensor. It indicates if the token should be
+                masked because the token is introduced due to padding.
+                `padding_mask` should have shape [batch_size, sequence_length].
+            attention_mask: a boolean Tensor. Customized mask used to mask out
+                certain tokens. `attention_mask` should have shape
+                [batch_size, sequence_length, sequence_length].
+            training: a boolean indicating whether the layer should behave in
+                training mode or in inference mode.
+
+        Returns:
+            A Tensor of the same shape as the `inputs`.
+        """
+        x = inputs  # Intermediate result.
+
+        # Compute self attention mask.
+        self_attention_mask = merge_padding_and_attention_mask(
+            inputs, padding_mask, attention_mask
+        )
+
+        # Self attention block.
+        residual = x
+        if self.normalize_first:
+            x = self._self_attention_layer_norm(x)
+        x = self._self_attention_layer(
+            query=x,
+            value=x,
+            attention_mask=self_attention_mask,
+            training=training,
+        )
+        x = self._self_attention_dropout(x, training=training)
+        x = x + residual
+        if not self.normalize_first:
+            x = self._self_attention_layer_norm(x)
+
+        # Feedforward block.
+        residual = x
+        if self.normalize_first:
+            x = self._feedforward_layer_norm(x)
+        x = self._feedforward_intermediate_dense(x)
+        x = self._feedforward_output_dense(x)
+        x = self._feedforward_dropout(x, training=training)
+        x = x + residual
+        if not self.normalize_first:
+            x = self._feedforward_layer_norm(x)
+
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "dropout": self.dropout,
+                "activation": keras.activations.serialize(self.activation),
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+                "normalize_first": self.normalize_first,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        return inputs_shape
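
The `TransformerEncoder` docstring above notes that stacked instances form a full encoder and that an implicit Keras padding mask (for example from `keras.layers.Embedding(mask_zero=True)`) is turned into an attention mask automatically. The following is a minimal sketch of that pattern; it is not part of the package contents above, and the vocabulary size, dimensions, and token ids are illustrative only.

```python
import numpy as np
import keras
import keras_hub

# Token ids, with 0 reserved as the padding id (illustrative values).
token_ids = np.array([[2, 5, 8, 0, 0], [3, 9, 4, 7, 0]])

inputs = keras.Input(shape=(5,), dtype="int32")
# `mask_zero=True` attaches an implicit Keras padding mask, which the encoder
# merges into its self-attention mask via `merge_padding_and_attention_mask`.
x = keras.layers.Embedding(input_dim=100, output_dim=64, mask_zero=True)(inputs)
# Stack two encoder layers; `key_dim` is `hidden_dim // num_heads`, so the
# embedding width (64) must be at least `num_heads`.
for _ in range(2):
    x = keras_hub.layers.TransformerEncoder(intermediate_dim=128, num_heads=4)(x)
model = keras.Model(inputs, x)

print(model(token_ids).shape)  # (2, 5, 64)
```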
keras_hub/src/layers/modeling/transformer_layer_utils.py
@@ -0,0 +1,106 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl import logging
+from keras import ops
+
+
+def _check_masks_shapes(inputs, padding_mask, attention_mask):
+    mask = padding_mask
+    if hasattr(inputs, "_keras_mask") and mask is None:
+        mask = inputs._keras_mask
+    if mask is not None:
+        if len(mask.shape) != 2:
+            raise ValueError(
+                "`padding_mask` should have shape "
+                "(batch_size, target_length). "
+                f"Received shape `{mask.shape}`."
+            )
+    if attention_mask is not None:
+        if len(attention_mask.shape) != 3:
+            raise ValueError(
+                "`attention_mask` should have shape "
+                "(batch_size, target_length, source_length). "
+                f"Received shape `{mask.shape}`."
+            )
+
+
+def compute_causal_mask(batch_size, input_length, output_length, cache_index=0):
+    """Compute a causal attention mask for a transformer decoder.
+
+    Args:
+        batch_size: batch size for the mask.
+        input_length: the length of key/value tensors in the attention layer.
+        output_length: the length of query tensors in the attention layer.
+        cache_index: the current index for cached generation. If passed, the
+            query sequence will be considered to start at `cache_index` rather
+            than zero. For example, a causal mask with `output_length=1` and
+            `cache_index=5` would allow the query tensor to attend to the first
+            five positions of the key/value tensors.
+
+    Return:
+        A causal attention mask with shape
+        `(batch_size, output_length, input_length)` that can be passed to a
+        attention layer.
+    """
+    i = ops.arange(output_length, dtype="float32")
+    i = i + ops.cast(cache_index, "float32")
+    i = ops.expand_dims(i, axis=1)
+    j = ops.arange(input_length, dtype="float32")
+    mask = ops.expand_dims(i >= j, axis=0)
+
+    return ops.broadcast_to(mask, (batch_size, output_length, input_length))
+
+
+def merge_padding_and_attention_mask(
+    inputs,
+    padding_mask,
+    attention_mask,
+):
+    """Merge the padding mask with a customized attention mask.
+
+    Args:
+        inputs: the input sequence.
+        padding_mask: the 1D padding mask, of shape
+            [batch_size, sequence_length].
+        attention_mask: the 2D customized mask, of shape
+            [batch_size, sequence1_length, sequence2_length].
+
+    Return:
+        A merged 2D mask or None. If only `padding_mask` is provided, the
+        returned mask is padding_mask with one additional axis.
+    """
+    _check_masks_shapes(inputs, padding_mask, attention_mask)
+    mask = padding_mask
+    if hasattr(inputs, "_keras_mask"):
+        if mask is None:
+            # If no padding mask is explicitly provided, we look for padding
+            # mask from the input data.
+            mask = inputs._keras_mask
+        else:
+            logging.warning(
+                "You are explicitly setting `padding_mask` while the `inputs` "
+                "have built-in mask, so the built-in mask is ignored."
+            )
+    if mask is not None:
+        # Add an axis for broadcasting, the attention mask should be 2D
+        # (not including the batch axis).
+        mask = ops.cast(ops.expand_dims(mask, axis=1), "int32")
+    if attention_mask is not None:
+        attention_mask = ops.cast(attention_mask, "int32")
+        if mask is None:
+            return attention_mask
+        else:
+            return ops.minimum(mask, attention_mask)
+    return mask
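
As a reference for the two helpers above, here is a small sketch of how a decoder-style layer could combine them: `compute_causal_mask` builds the lower-triangular mask, and `merge_padding_and_attention_mask` folds a padding mask into it with an elementwise minimum. Note that this module lives under `keras_hub.src` and is internal rather than part of the public API; the shapes and values below are illustrative.

```python
import numpy as np
from keras import ops

from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
    merge_padding_and_attention_mask,
)

batch_size, seq_len, hidden_dim = 2, 4, 8
inputs = np.zeros((batch_size, seq_len, hidden_dim), dtype="float32")
padding_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])  # 0 marks padded positions

# Causal mask: query position i may only attend to key positions j <= i.
causal_mask = compute_causal_mask(batch_size, seq_len, seq_len)

# Fold the padding constraint into the causal constraint. The result is an
# int32 mask of shape (batch_size, seq_len, seq_len) that can be passed as the
# `attention_mask` argument when calling an attention layer.
merged_mask = merge_padding_and_attention_mask(inputs, padding_mask, causal_mask)
print(ops.shape(merged_mask))  # (2, 4, 4)
```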
keras_hub/src/layers/preprocessing/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py
@@ -0,0 +1,220 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.preprocessing_layer import (
+    PreprocessingLayer,
+)
+from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+
+try:
+    import tensorflow as tf
+    import tensorflow_text as tf_text
+except ImportError:
+    tf = None
+    tf_text = None
+
+
+@keras_hub_export("keras_hub.layers.MaskedLMMaskGenerator")
+class MaskedLMMaskGenerator(PreprocessingLayer):
+    """Layer that applies language model masking.
+
+    This layer is useful for preparing inputs for masked language modeling
+    (MaskedLM) tasks. It follows the masking strategy described in the
+    [original BERT paper](https://arxiv.org/abs/1810.04805). Given tokenized
+    text, it randomly selects certain number of tokens for masking. Then for
+    each selected token, it has a chance (configurable) to be replaced by
+    "mask token" or random token, or stay unchanged.
+
+    Input data should be passed as tensors, `tf.RaggedTensor`s, or lists. For
+    batched input, inputs should be a list of lists or a rank two tensor. For
+    unbatched inputs, each element should be a list or a rank one tensor.
+
+    This layer can be used with `tf.data` to generate dynamic masks on the fly
+    during training.
+
+    Args:
+        vocabulary_size: int, the size of the vocabulary.
+        mask_selection_rate: float, the probability of a token is selected for
+            masking.
+        mask_token_id: int. The id of mask token.
+        mask_selection_length: int. Maximum number of tokens
+            selected for masking in each sequence. If set, the output
+            `mask_positions`, `mask_ids` and `mask_weights` will be padded
+            to dense tensors of length `mask_selection_length`, otherwise
+            the output will be a RaggedTensor. Defaults to `None`.
+        unselectable_token_ids: A list of tokens id that should not be
+            considered eligible for masking. By default, we assume `0`
+            corresponds to a padding token and ignore it. Defaults to `[0]`.
+        mask_token_rate: float. `mask_token_rate` must be
+            between 0 and 1 which indicates how often the mask_token is
+            substituted for tokens selected for masking. Defaults to `0.8`.
+        random_token_rate: float. `random_token_rate` must be
+            between 0 and 1 which indicates how often a random token is
+            substituted for tokens selected for masking.
+            Note: mask_token_rate + random_token_rate <= 1, and for
+            (1 - mask_token_rate - random_token_rate), the token will not be
+            changed. Defaults to `0.1`.
+
+    Returns:
+        A Dict with 4 keys:
+            token_ids: Tensor or RaggedTensor, has the same type and shape of
+                input. Sequence after getting masked.
+            mask_positions: Tensor, or RaggedTensor if `mask_selection_length`
+                is None. The positions of token_ids getting masked.
+            mask_ids: Tensor, or RaggedTensor if `mask_selection_length` is
+                None. The original token ids at masked positions.
+            mask_weights: Tensor, or RaggedTensor if `mask_selection_length` is
+                None. `mask_weights` has the same shape as `mask_positions` and
+                `mask_ids`. Each element in `mask_weights` should be 0 or 1,
+                1 means the corresponding position in `mask_positions` is an
+                actual mask, 0 means it is a pad.
+
+    Examples:
+
+    Basic usage.
+    ```python
+    masker = keras_hub.layers.MaskedLMMaskGenerator(
+        vocabulary_size=10,
+        mask_selection_rate=0.2,
+        mask_token_id=0,
+        mask_selection_length=5
+    )
+    # Dense input.
+    masker([1, 2, 3, 4, 5])
+
+    # Ragged input.
+    masker([[1, 2], [1, 2, 3, 4]])
+    ```
+
+    Masking a batch that contains special tokens.
+    ```python
+    pad_id, cls_id, sep_id, mask_id = 0, 1, 2, 3
+    batch = [
+        [cls_id, 4, 5, 6, sep_id, 7, 8, sep_id, pad_id, pad_id],
+        [cls_id, 4, 5, sep_id, 6, 7, 8, 9, sep_id, pad_id],
+    ]
+
+    masker = keras_hub.layers.MaskedLMMaskGenerator(
+        vocabulary_size = 10,
+        mask_selection_rate = 0.2,
+        mask_selection_length = 5,
+        mask_token_id = mask_id,
+        unselectable_token_ids = [
+            cls_id,
+            sep_id,
+            pad_id,
+        ]
+    )
+    masker(batch)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        mask_selection_rate,
+        mask_token_id,
+        mask_selection_length=None,
+        unselectable_token_ids=[0],
+        mask_token_rate=0.8,
+        random_token_rate=0.1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.vocabulary_size = vocabulary_size
+        self.unselectable_token_ids = unselectable_token_ids
+        self.mask_selection_rate = mask_selection_rate
+        self.mask_selection_length = mask_selection_length
+        self.mask_token_rate = mask_token_rate
+        self.random_token_rate = random_token_rate
+
+        if mask_token_id >= vocabulary_size:
+            raise ValueError(
+                f"Mask token id should be in range [0, vocabulary_size - 1], "
+                f"but received mask_token_id={mask_token_id}."
+            )
+        self.mask_token_id = mask_token_id
+
+        max_selections = self.mask_selection_length
+        if max_selections is None:
+            # Set a large number to remove the `max_selections_per_batch` cap.
+            max_selections = 2**31 - 1
+        self._random_selector = tf_text.RandomItemSelector(
+            max_selections_per_batch=max_selections,
+            selection_rate=self.mask_selection_rate,
+            unselectable_ids=self.unselectable_token_ids,
+        )
+        self._mask_values_chooser = tf_text.MaskValuesChooser(
+            self.vocabulary_size,
+            self.mask_token_id,
+            mask_token_rate=self.mask_token_rate,
+            random_token_rate=self.random_token_rate,
+        )
+
+    def call(self, inputs):
+        inputs, unbatched, rectangular = convert_to_ragged_batch(inputs)
+
+        (
+            token_ids,
+            mask_positions,
+            mask_ids,
+        ) = tf_text.mask_language_model(
+            inputs,
+            item_selector=self._random_selector,
+            mask_values_chooser=self._mask_values_chooser,
+        )
+
+        if rectangular:
+            # If we converted the input from dense to ragged, convert back.
+            token_ids = token_ids.to_tensor()
+
+        mask_weights = tf.ones_like(mask_positions, self.compute_dtype)
+        # If `mask_selection_length` is set, convert to dense.
+        if self.mask_selection_length:
+            target_shape = tf.cast([-1, self.mask_selection_length], "int64")
+            mask_positions = mask_positions.to_tensor(shape=target_shape)
+            mask_ids = mask_ids.to_tensor(shape=target_shape)
+            mask_weights = mask_weights.to_tensor(shape=target_shape)
+
+        if unbatched:
+            # If inputs is 1D, we format the output to be 1D as well.
+            token_ids = tf.squeeze(token_ids, axis=0)
+            mask_positions = tf.squeeze(mask_positions, axis=0)
+            mask_ids = tf.squeeze(mask_ids, axis=0)
+            mask_weights = tf.squeeze(mask_weights, axis=0)
+
+        return {
+            "token_ids": token_ids,
+            "mask_positions": mask_positions,
+            "mask_ids": mask_ids,
+            "mask_weights": mask_weights,
+        }
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "mask_selection_rate": self.mask_selection_rate,
+                "mask_selection_length": self.mask_selection_length,
+                "unselectable_token_ids": self.unselectable_token_ids,
+                "mask_token_id": self.mask_token_id,
+                "mask_token_rate": self.mask_token_rate,
+                "random_token_rate": self.random_token_rate,
+            }
+        )
+        return config
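
The docstring above mentions that `MaskedLMMaskGenerator` can be used with `tf.data` to draw new masks on the fly during training. A minimal sketch of that pattern follows; it assumes `tensorflow` and `tensorflow-text` are installed, and the vocabulary size, token ids, and rates are illustrative only.

```python
import tensorflow as tf
import keras_hub

# Pre-tokenized sequences with 0 as the padding id (illustrative values).
token_ids = tf.constant([[2, 6, 7, 9, 0, 0], [4, 5, 8, 3, 2, 0]])

masker = keras_hub.layers.MaskedLMMaskGenerator(
    vocabulary_size=100,
    mask_selection_rate=0.15,
    mask_token_id=1,
    mask_selection_length=4,
    unselectable_token_ids=[0],  # never mask padding
)

# Mapping the layer over a tf.data pipeline resamples the masked positions
# every time the dataset is iterated, e.g. once per training epoch.
ds = (
    tf.data.Dataset.from_tensor_slices(token_ids)
    .batch(2)
    .map(masker, num_parallel_calls=tf.data.AUTOTUNE)
)
for features in ds.take(1):
    print(features["token_ids"].shape, features["mask_positions"].shape)
```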