keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,239 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
|
20
|
+
|
21
|
+
@keras_hub_export("keras_hub.layers.MaskedLMHead")
|
22
|
+
class MaskedLMHead(keras.layers.Layer):
|
23
|
+
"""Masked Language Model (MaskedLM) head.
|
24
|
+
|
25
|
+
This layer takes two inputs:
|
26
|
+
|
27
|
+
- `inputs`: which should be a tensor of encoded tokens with shape
|
28
|
+
`(batch_size, sequence_length, hidden_dim)`.
|
29
|
+
- `mask_positions`: which should be a tensor of integer positions to
|
30
|
+
predict with shape `(batch_size, masks_per_sequence)`.
|
31
|
+
|
32
|
+
The token encodings should usually be the last output of an encoder model,
|
33
|
+
and mask positions should be the integer positions you would like to
|
34
|
+
predict for the MaskedLM task.
|
35
|
+
|
36
|
+
The layer will first gather the token encodings at the mask positions. These
|
37
|
+
gathered tokens will be passed through a dense layer the same size as
|
38
|
+
encoding dimension, then transformed to predictions the same size as the
|
39
|
+
input vocabulary. This layer will produce a single output with shape
|
40
|
+
`(batch_size, masks_per_sequence, vocabulary_size)`, which can be used to
|
41
|
+
compute a MaskedLM loss function.
|
42
|
+
|
43
|
+
This layer is often paired with `keras_hub.layers.MaskedLMMaskGenerator`,
|
44
|
+
which will help prepare inputs for the MaskedLM task.
|
45
|
+
|
46
|
+
Args:
|
47
|
+
vocabulary_size: The total size of the vocabulary for predictions.
|
48
|
+
token_embedding: Optional. A `keras_hub.layers.ReversibleEmbedding`
|
49
|
+
instance. If passed, the layer will be used to project from the
|
50
|
+
`hidden_dim` of the model to the output `vocabulary_size`.
|
51
|
+
intermediate_activation: The activation function of the intermediate dense layer.
|
52
|
+
activation: The activation function for the outputs of the layer.
|
53
|
+
Usually either `None` (return logits), or `"softmax"`
|
54
|
+
(return probabilities).
|
55
|
+
layer_norm_epsilon: float. The epsilon value in layer
|
56
|
+
normalization components. Defaults to `1e-5`.
|
57
|
+
kernel_initializer: string or `keras.initializers` initializer.
|
58
|
+
The kernel initializer for the intermediate dense layer and the
|
59
|
+
output kernel. Defaults to `"glorot_uniform"`.
|
60
|
+
bias_initializer: string or `keras.initializers` initializer.
|
61
|
+
The bias initializer for the intermediate dense layer and the
|
61
|
+
output bias. Defaults to `"zeros"`.
|
63
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
64
|
+
including `name`, `trainable`, `dtype` etc.
|
65
|
+
|
66
|
+
Example:
|
67
|
+
|
68
|
+
```python
|
69
|
+
batch_size = 16
|
70
|
+
vocab_size = 100
|
71
|
+
hidden_dim = 32
|
72
|
+
seq_length = 50
|
73
|
+
|
74
|
+
# Generate random inputs.
|
75
|
+
token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
|
76
|
+
# Choose random positions as the masked inputs.
|
77
|
+
mask_positions = np.random.randint(seq_length, size=(batch_size, 5))
|
78
|
+
|
79
|
+
# Embed tokens in a `hidden_dim` feature space.
|
80
|
+
token_embedding = keras_hub.layers.ReversibleEmbedding(
|
81
|
+
vocab_size,
|
82
|
+
hidden_dim,
|
83
|
+
)
|
84
|
+
hidden_states = token_embedding(token_ids)
|
85
|
+
|
86
|
+
preds = keras_hub.layers.MaskedLMHead(
|
87
|
+
vocabulary_size=vocab_size,
|
88
|
+
token_embedding=token_embedding,
|
89
|
+
activation="softmax",
|
90
|
+
)(hidden_states, mask_positions)
|
91
|
+
```
|
92
|
+
|
93
|
+
References:
|
94
|
+
- [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
|
95
|
+
"""
|
96
|
+
|
97
|
+
def __init__(
|
98
|
+
self,
|
99
|
+
vocabulary_size=None,
|
100
|
+
token_embedding=None,
|
101
|
+
intermediate_activation="relu",
|
102
|
+
activation=None,
|
103
|
+
layer_norm_epsilon=1e-05,
|
104
|
+
kernel_initializer="glorot_uniform",
|
105
|
+
bias_initializer="zeros",
|
106
|
+
**kwargs,
|
107
|
+
):
|
108
|
+
super().__init__(**kwargs, autocast=False)
|
109
|
+
|
110
|
+
self.vocabulary_size = vocabulary_size
|
111
|
+
self.token_embedding = token_embedding
|
112
|
+
self.intermediate_activation = keras.activations.get(
|
113
|
+
intermediate_activation
|
114
|
+
)
|
115
|
+
self.activation = keras.activations.get(activation)
|
116
|
+
self.layer_norm_epsilon = layer_norm_epsilon
|
117
|
+
self.kernel_initializer = keras.initializers.get(kernel_initializer)
|
118
|
+
self.bias_initializer = keras.initializers.get(bias_initializer)
|
119
|
+
|
120
|
+
if vocabulary_size is None and token_embedding is None:
|
121
|
+
raise ValueError(
|
122
|
+
"One of `vocabulary_size` or `token_embedding` must be set. "
|
123
|
+
"Received: `vocabulary_size=None`, `token_embedding=None`"
|
124
|
+
)
|
125
|
+
|
126
|
+
if token_embedding:
|
127
|
+
if vocabulary_size and vocabulary_size != token_embedding.input_dim:
|
128
|
+
raise ValueError(
|
129
|
+
"`vocabulary_size` should match the input dimension of the "
|
130
|
+
"of `token_embedding`. Received: "
|
131
|
+
f"`vocabulary_size={vocabulary_size}`, "
|
132
|
+
f"`token_embedding.input_dim={token_embedding.input_dim}`"
|
133
|
+
)
|
134
|
+
self.vocabulary_size = token_embedding.input_dim
|
135
|
+
|
136
|
+
def build(self, inputs_shape, mask_positions_shape=None):
|
137
|
+
if self.token_embedding is not None:
|
138
|
+
feature_size = self.token_embedding.output_dim
|
139
|
+
else:
|
140
|
+
feature_size = inputs_shape[-1]
|
141
|
+
|
142
|
+
self._intermediate_dense = keras.layers.Dense(
|
143
|
+
feature_size,
|
144
|
+
activation=self.intermediate_activation,
|
145
|
+
kernel_initializer=self.kernel_initializer,
|
146
|
+
bias_initializer=self.bias_initializer,
|
147
|
+
dtype=self.dtype_policy,
|
148
|
+
name="intermediate_dense",
|
149
|
+
)
|
150
|
+
self._intermediate_layer_norm = keras.layers.LayerNormalization(
|
151
|
+
epsilon=self.layer_norm_epsilon,
|
152
|
+
dtype=self.dtype_policy,
|
153
|
+
name="intermediate_layer_norm",
|
154
|
+
)
|
155
|
+
# The gather length does not affect any of our built variables, so
|
156
|
+
# we can pass any value here.
|
157
|
+
gather_length = None
|
158
|
+
shape = (inputs_shape[0], gather_length, inputs_shape[-1])
|
159
|
+
self._intermediate_dense.build(shape)
|
160
|
+
shape = (inputs_shape[0], gather_length, feature_size)
|
161
|
+
self._intermediate_layer_norm.build(shape)
|
162
|
+
if self.token_embedding is None:
|
163
|
+
self._kernel = self.add_weight(
|
164
|
+
name="output_kernel",
|
165
|
+
shape=[feature_size, self.vocabulary_size],
|
166
|
+
initializer=self.kernel_initializer,
|
167
|
+
dtype=self.dtype,
|
168
|
+
)
|
169
|
+
self._bias = self.add_weight(
|
170
|
+
name="output_bias",
|
171
|
+
shape=[self.vocabulary_size],
|
172
|
+
initializer=self.bias_initializer,
|
173
|
+
dtype=self.dtype,
|
174
|
+
)
|
175
|
+
self.built = True
|
176
|
+
|
177
|
+
def call(self, inputs, mask_positions):
|
178
|
+
if keras.config.backend() == "tensorflow":
|
179
|
+
import tensorflow as tf
|
180
|
+
|
181
|
+
# On the tf backend, we need to work around an issue with dynamic
|
182
|
+
# shape broadcasting in take_along_axis.
|
183
|
+
x = tf.gather(inputs, mask_positions, batch_dims=1)
|
184
|
+
else:
|
185
|
+
# Gather the encoded tokens at the masked indices.
|
186
|
+
mask_positions = ops.expand_dims(mask_positions, axis=-1)
|
187
|
+
x = ops.take_along_axis(inputs, mask_positions, axis=1)
|
188
|
+
|
189
|
+
# Apply a trainable linear transformation and a layer norm.
|
190
|
+
x = self._intermediate_dense(x)
|
191
|
+
x = self._intermediate_layer_norm(x)
|
192
|
+
|
193
|
+
# Transform encodings to vocabulary_size predictions.
|
194
|
+
if self.token_embedding:
|
195
|
+
outputs = self.token_embedding(x, reverse=True)
|
196
|
+
else:
|
197
|
+
outputs = ops.matmul(x, self._kernel)
|
198
|
+
outputs = ops.cast(outputs, self.compute_dtype)
|
199
|
+
outputs = outputs + self._bias
|
200
|
+
|
201
|
+
# Apply a final activation.
|
202
|
+
if self.activation is not None:
|
203
|
+
outputs = self.activation(outputs)
|
204
|
+
|
205
|
+
return outputs
|
206
|
+
|
207
|
+
@classmethod
|
208
|
+
def from_config(cls, config):
|
209
|
+
embedding = config.get("token_embedding")
|
210
|
+
if embedding:
|
211
|
+
config["token_embedding"] = keras.layers.deserialize(embedding)
|
212
|
+
return super().from_config(config)
|
213
|
+
|
214
|
+
def get_config(self):
|
215
|
+
config = super().get_config()
|
216
|
+
embedding_config = None
|
217
|
+
if self.token_embedding:
|
218
|
+
embedding_config = keras.layers.serialize(self.token_embedding)
|
219
|
+
config.update(
|
220
|
+
{
|
221
|
+
"vocabulary_size": self.vocabulary_size,
|
222
|
+
"token_embedding": embedding_config,
|
223
|
+
"intermediate_activation": keras.activations.serialize(
|
224
|
+
self.intermediate_activation
|
225
|
+
),
|
226
|
+
"activation": keras.activations.serialize(self.activation),
|
227
|
+
"layer_norm_epsilon": self.layer_norm_epsilon,
|
228
|
+
"kernel_initializer": keras.initializers.serialize(
|
229
|
+
self.kernel_initializer
|
230
|
+
),
|
231
|
+
"bias_initializer": keras.initializers.serialize(
|
232
|
+
self.bias_initializer
|
233
|
+
),
|
234
|
+
}
|
235
|
+
)
|
236
|
+
return config
|
237
|
+
|
238
|
+
def compute_output_shape(self, inputs_shape, mask_positions_shape):
|
239
|
+
return mask_positions_shape + (self.vocabulary_size,)
|
@@ -0,0 +1,123 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
|
20
|
+
|
21
|
+
@keras_hub_export("keras_hub.layers.PositionEmbedding")
|
22
|
+
class PositionEmbedding(keras.layers.Layer):
|
23
|
+
"""A layer which learns a position embedding for input sequences.
|
24
|
+
|
25
|
+
This class assumes that in the input tensor, the last dimension corresponds
|
26
|
+
to the features, and the dimension before the last corresponds to the
|
27
|
+
sequence.
|
28
|
+
|
29
|
+
This layer does not support masking, but can be combined with a
|
30
|
+
`keras.layers.Embedding` for padding mask support.
|
31
|
+
|
32
|
+
Args:
|
33
|
+
sequence_length: The maximum length of the dynamic sequence.
|
34
|
+
initializer: The initializer to use for the embedding weights. Defaults
|
35
|
+
to `"glorot_uniform"`.
|
36
|
+
seq_axis: The axis of the input tensor where we add the embeddings.
|
37
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
38
|
+
including `name`, `trainable`, `dtype` etc.
|
39
|
+
|
40
|
+
Call arguments:
|
41
|
+
inputs: The tensor inputs to compute an embedding for, with shape
|
42
|
+
`(batch_size, sequence_length, hidden_dim)`. Only the input shape
|
43
|
+
will be used, as the position embedding does not depend on the
|
44
|
+
input sequence content.
|
45
|
+
start_index: An integer or integer tensor. The starting position to
|
46
|
+
compute the position embedding from. This is useful during cached
|
47
|
+
decoding, where each position is predicted separately in a loop.
|
48
|
+
|
49
|
+
Example:
|
50
|
+
|
51
|
+
Called directly on input.
|
52
|
+
>>> layer = keras_hub.layers.PositionEmbedding(sequence_length=10)
|
53
|
+
>>> layer(np.zeros((8, 10, 16)))
|
54
|
+
|
55
|
+
Combine with a token embedding.
|
56
|
+
```python
|
57
|
+
seq_length = 50
|
58
|
+
vocab_size = 5000
|
59
|
+
embed_dim = 128
|
60
|
+
inputs = keras.Input(shape=(seq_length,))
|
61
|
+
token_embeddings = keras.layers.Embedding(
|
62
|
+
input_dim=vocab_size, output_dim=embed_dim
|
63
|
+
)(inputs)
|
64
|
+
position_embeddings = keras_hub.layers.PositionEmbedding(
|
65
|
+
sequence_length=seq_length
|
66
|
+
)(token_embeddings)
|
67
|
+
outputs = token_embeddings + position_embeddings
|
68
|
+
```
|
69
|
+
|
70
|
+
Reference:
|
71
|
+
- [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)
|
72
|
+
"""
|
73
|
+
|
74
|
+
def __init__(
|
75
|
+
self,
|
76
|
+
sequence_length,
|
77
|
+
initializer="glorot_uniform",
|
78
|
+
**kwargs,
|
79
|
+
):
|
80
|
+
super().__init__(**kwargs)
|
81
|
+
if sequence_length is None:
|
82
|
+
raise ValueError(
|
83
|
+
"`sequence_length` must be an Integer, received `None`."
|
84
|
+
)
|
85
|
+
self.sequence_length = int(sequence_length)
|
86
|
+
self.initializer = keras.initializers.get(initializer)
|
87
|
+
|
88
|
+
def get_config(self):
|
89
|
+
config = super().get_config()
|
90
|
+
config.update(
|
91
|
+
{
|
92
|
+
"sequence_length": self.sequence_length,
|
93
|
+
"initializer": keras.initializers.serialize(self.initializer),
|
94
|
+
}
|
95
|
+
)
|
96
|
+
return config
|
97
|
+
|
98
|
+
def build(self, inputs_shape):
|
99
|
+
feature_size = inputs_shape[-1]
|
100
|
+
self.position_embeddings = self.add_weight(
|
101
|
+
name="embeddings",
|
102
|
+
shape=[self.sequence_length, feature_size],
|
103
|
+
initializer=self.initializer,
|
104
|
+
trainable=True,
|
105
|
+
)
|
106
|
+
self.built = True
|
107
|
+
|
108
|
+
def call(self, inputs, start_index=0):
    """Return position embeddings broadcast to the shape of `inputs`.

    Args:
        inputs: Tensor whose last two axes are read as the sequence
            length and the feature length, in that order.
        start_index: Row of the embedding table at which the slice
            starts. Defaults to 0.

    Returns:
        A slice of the embedding table broadcast to `ops.shape(inputs)`.
    """
    input_shape = ops.shape(inputs)
    seq_len, feature_dim = input_shape[-2], input_shape[-1]

    # The input sequence may be shorter than the layer's table, so only
    # the rows `[start_index, start_index + seq_len)` are taken.
    table = ops.convert_to_tensor(self.position_embeddings)
    window = ops.slice(table, (start_index, 0), (seq_len, feature_dim))
    return ops.broadcast_to(window, input_shape)
|
121
|
+
|
122
|
+
def compute_output_shape(self, input_shape):
    """Return `input_shape` unchanged; the output mirrors the input."""
    output_shape = input_shape
    return output_shape
|
@@ -0,0 +1,311 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
from packaging.version import parse
|
18
|
+
|
19
|
+
from keras_hub.src.api_export import keras_hub_export
|
20
|
+
from keras_hub.src.utils.keras_utils import assert_quantization_support
|
21
|
+
|
22
|
+
|
23
|
+
@keras_hub_export("keras_hub.layers.ReversibleEmbedding")
class ReversibleEmbedding(keras.layers.Embedding):
    """An embedding layer which can project backwards to the input dim.

    This layer is an extension of `keras.layers.Embedding` for language models.
    This layer can be called "in reverse" with `reverse=True`, in which case the
    layer will linearly project from `output_dim` back to `input_dim`.

    By default, the reverse projection will use the transpose of the
    `embeddings` weights to project to `input_dim` (weights are "tied"). If
    `tie_weights=False`, the model will use a separate, trainable variable for
    reverse projection.

    This layer has no bias terms.

    Args:
        input_dim: Integer. Size of the vocabulary,
            i.e. maximum integer index + 1.
        output_dim: Integer. Dimension of the dense embedding.
        tie_weights: Boolean, whether or not the matrix for embedding and
            the matrix for the `reverse` projection should share the same
            weights.
        embeddings_initializer: Initializer for the `embeddings`
            matrix (see `keras.initializers`).
        embeddings_regularizer: Regularizer function applied to
            the `embeddings` matrix (see `keras.regularizers`).
        embeddings_constraint: Constraint function applied to
            the `embeddings` matrix (see `keras.constraints`).
        mask_zero: Boolean, whether or not the input value 0 is a special
            "padding" value that should be masked out.
        reverse_dtype: The dtype for the reverse projection computation.
            Defaults to the `compute_dtype` of the layer.
        logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
            output logits will be scaled by
            `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
            range of output logits and can improve training.
        **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
            including `name`, `trainable`, `dtype` etc.

    Call arguments:
        inputs: The tensor inputs to the layer.
        reverse: Boolean. If `True` the layer will perform a linear projection
            from `output_dim` to `input_dim`, instead of a normal embedding
            call. Defaults to `False`.

    Example:
    ```python
    batch_size = 16
    vocab_size = 100
    hidden_dim = 32
    seq_length = 50

    # Generate random inputs.
    token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))

    embedding = keras_hub.layers.ReversibleEmbedding(vocab_size, hidden_dim)
    # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
    hidden_states = embedding(token_ids)
    # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
    logits = embedding(hidden_states, reverse=True)
    ```

    References:
    - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
    - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        tie_weights=True,
        embeddings_initializer="uniform",
        embeddings_regularizer=None,
        embeddings_constraint=None,
        mask_zero=False,
        reverse_dtype=None,
        logit_soft_cap=None,
        **kwargs,
    ):
        """Configure the embedding and remember the reverse-call options."""
        super().__init__(
            input_dim,
            output_dim,
            embeddings_initializer=embeddings_initializer,
            embeddings_regularizer=embeddings_regularizer,
            embeddings_constraint=embeddings_constraint,
            mask_zero=mask_zero,
            **kwargs,
        )
        self.tie_weights = tie_weights
        self.reverse_dtype = reverse_dtype
        self.logit_soft_cap = logit_soft_cap

    def build(self, inputs_shape=None):
        """Build the parent weights and, if untied, the reverse kernel.

        Args:
            inputs_shape: Forwarded to the parent `build`. May be `None`.
        """
        super().build(inputs_shape)
        # In int8 mode the reverse kernel is created by `_int8_build`
        # instead, so it must not be added a second time here.
        if (
            not self.tie_weights
            and getattr(self, "quantization_mode", None) != "int8"
        ):
            self.reverse_embeddings = self.add_weight(
                name="reverse_embeddings",
                shape=(self.output_dim, self.input_dim),
                initializer=self.embeddings_initializer,
                dtype=self.dtype,
            )

    def call(self, inputs, reverse=False):
        """Embed `inputs`, or project them back to `input_dim`.

        Args:
            inputs: The tensor inputs to the layer.
            reverse: If `True`, multiply `inputs` by the reverse kernel
                (the transposed `embeddings` when weights are tied,
                `reverse_embeddings` otherwise) and optionally soft-cap
                the result. If `False`, defer to the parent `call`.

        Returns:
            The logits when `reverse=True`, otherwise the parent layer's
            output.
        """
        if reverse:
            if self.tie_weights:
                kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
            else:
                kernel = self.reverse_embeddings
            if self.reverse_dtype is not None:
                inputs = ops.cast(inputs, self.reverse_dtype)
                kernel = ops.cast(kernel, self.reverse_dtype)
            logits = ops.matmul(inputs, kernel)
            # Optionally soft-cap logits.
            if self.logit_soft_cap is not None:
                soft_cap = self.logit_soft_cap
                logits = ops.tanh(logits / soft_cap) * soft_cap
            return logits

        return super().call(inputs)

    def get_config(self):
        """Return the parent config plus this layer's extra arguments."""
        config = super().get_config()
        config.update(
            {
                "tie_weights": self.tie_weights,
                "reverse_dtype": self.reverse_dtype,
                "logit_soft_cap": self.logit_soft_cap,
            }
        )
        return config

    def save_own_variables(self, store):
        """Save the layer variables, appending the reverse weights last.

        Nothing is written when the layer has not been built.

        Args:
            store: Mapping-like store; extra variables are written under
                consecutive string indices starting at `len(store)`.
        """
        if not self.built:
            return
        super().save_own_variables(store)
        # Before Keras 3.2, the reverse weight is saved in the super() call.
        # After Keras 3.2, the reverse weight must be saved manually.
        if parse(keras.version()) < parse("3.2.0"):
            return
        target_variables = []
        if not self.tie_weights:
            # Store the reverse embedding weights as the last weights.
            target_variables.append(self.reverse_embeddings)
            if getattr(self, "quantization_mode", None) == "int8":
                target_variables.append(self.reverse_embeddings_scale)
        for i, variable in enumerate(target_variables, start=len(store)):
            store[str(i)] = variable

    def load_own_variables(self, store):
        """Load the layer variables, reading reverse weights from the end.

        The layer is built first if needed.

        Args:
            store: Mapping-like store indexed by string integers, as
                written by `save_own_variables`.
        """
        if not self.built:
            self.build()
        super().load_own_variables(store)
        if not self.tie_weights:
            # Last weights in the stores are the reverse embedding weights.
            target_variables = [self.reverse_embeddings]
            if getattr(self, "quantization_mode", None) == "int8":
                target_variables.append(self.reverse_embeddings_scale)
            for i, variable in enumerate(
                target_variables, start=len(store) - len(target_variables)
            ):
                variable.assign(store[str(i)])

    def compute_output_spec(self, inputs, reverse=False):
        """Return the symbolic output for a forward or reverse call.

        Args:
            inputs: Object exposing a `shape` attribute.
            reverse: If `True`, the last axis becomes `input_dim`;
                otherwise an `output_dim` axis is appended.

        Returns:
            A `keras.KerasTensor` with the layer's `compute_dtype`.
        """
        output_shape = list(inputs.shape)
        if reverse:
            output_shape[-1] = self.input_dim
        else:
            output_shape += [self.output_dim]
        return keras.KerasTensor(output_shape, dtype=self.compute_dtype)

    # Quantization-related (int8) methods

    def quantized_call(self, inputs, reverse=False):
        """Dispatch a call on a quantized layer.

        Args:
            inputs: The tensor inputs to the layer.
            reverse: Forwarded to `_int8_call`.

        Returns:
            The result of `_int8_call` when the mode is `"int8"`.

        Raises:
            Exception: The error built by `_quantization_mode_error` for
                any other quantization mode.
        """
        # TODO (hongyu): This function could be removed once we add `*args` and
        # `**kwargs` for `Embedding.quantized_call`
        if self.quantization_mode == "int8":
            return self._int8_call(inputs, reverse=reverse)
        # `quantize()` raises the value returned by this helper; do the same
        # here so an unsupported mode cannot silently yield `None`.
        raise self._quantization_mode_error(self.quantization_mode)

    def _int8_build(
        self,
        embeddings_initializer="zeros",
        embeddings_scale_initializer="ones",
        reverse_embeddings_initializer="zeros",
        reverse_embeddings_scale_initializer="ones",
    ):
        """Create the int8 variables, including the untied reverse kernel.

        Args:
            embeddings_initializer: Forwarded to the parent `_int8_build`.
            embeddings_scale_initializer: Forwarded to the parent
                `_int8_build`.
            reverse_embeddings_initializer: Initializer of the int8
                `reverse_embeddings` weight (untied weights only).
            reverse_embeddings_scale_initializer: Initializer of the
                `reverse_embeddings_scale` weight (untied weights only).
        """
        super()._int8_build(
            embeddings_initializer, embeddings_scale_initializer
        )
        self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
        if not self.tie_weights:
            self.reverse_embeddings = self.add_weight(
                name="reverse_embeddings",
                shape=(self.output_dim, self.input_dim),
                initializer=reverse_embeddings_initializer,
                dtype="int8",
                trainable=False,
            )
            self.reverse_embeddings_scale = self.add_weight(
                name="reverse_embeddings_scale",
                shape=(self.input_dim,),
                initializer=reverse_embeddings_scale_initializer,
                trainable=False,
            )

    def _int8_call(self, inputs, reverse=False):
        """Run the int8 forward or reverse computation.

        For `reverse=True` the inputs are quantized with
        `inputs_quantizer`, multiplied by the int8 kernel, cast to the
        layer's `compute_dtype`, divided by the product of the input and
        kernel scales, and optionally soft-capped. Otherwise the parent
        `_int8_call` is used.
        """
        if reverse:
            if self.tie_weights:
                kernel = ops.transpose(self._embeddings)
                scale = ops.transpose(self.embeddings_scale)
            else:
                kernel = self.reverse_embeddings
                scale = self.reverse_embeddings_scale
            inputs, inputs_scale = self.inputs_quantizer(inputs)
            logits = ops.matmul(inputs, kernel)
            # De-scale outputs
            logits = ops.cast(logits, self.compute_dtype)
            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
            # Optionally soft-cap logits.
            if self.logit_soft_cap is not None:
                soft_cap = self.logit_soft_cap
                logits = ops.tanh(logits / soft_cap) * soft_cap
            return logits

        return super()._int8_call(inputs)

    def quantize(self, mode, type_check=True):
        """Quantize the layer's weights in place.

        Args:
            mode: Quantization mode; only `"int8"` is handled here.
            type_check: If `True`, refuse to run on subclasses of
                `ReversibleEmbedding`.

        Raises:
            NotImplementedError: If `type_check` is set and the instance
                is not exactly a `ReversibleEmbedding`.
            Exception: The error built by `_quantization_mode_error` when
                `mode` is not `"int8"`.
        """
        import gc
        import inspect

        assert_quantization_support()
        if type_check and type(self) is not ReversibleEmbedding:
            raise NotImplementedError(
                f"Layer {self.__class__.__name__} does not have a `quantize()` "
                "method implemented."
            )
        self._check_quantize_args(mode, self.compute_dtype)

        # `keras<=3.4.1` doesn't support `to_numpy`; probe the signature once
        # rather than on every helper call.
        quantize_fn = keras.quantizers.abs_max_quantize
        supports_to_numpy = (
            "to_numpy" in inspect.signature(quantize_fn).parameters
        )

        def abs_max_quantize(inputs, axis):
            if supports_to_numpy:
                return quantize_fn(inputs, axis=axis, to_numpy=True)
            return quantize_fn(inputs, axis=axis)

        self._tracker.unlock()
        try:
            if mode == "int8":
                embeddings, embeddings_scale = abs_max_quantize(
                    self._embeddings, axis=-1
                )
                embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
                self._untrack_variable(self._embeddings)
                del self._embeddings
                if not self.tie_weights:
                    reverse_embeddings, reverse_embeddings_scale = (
                        abs_max_quantize(self.reverse_embeddings, axis=0)
                    )
                    reverse_embeddings_scale = ops.squeeze(
                        reverse_embeddings_scale, axis=0
                    )
                    self._untrack_variable(self.reverse_embeddings)
                    del self.reverse_embeddings
                else:
                    reverse_embeddings = None
                    reverse_embeddings_scale = None
                # The initializers ignore their arguments and hand back the
                # already-quantized values computed above.
                self._int8_build(
                    lambda shape, dtype: embeddings,
                    lambda shape, dtype: embeddings_scale,
                    lambda shape, dtype: reverse_embeddings,
                    lambda shape, dtype: reverse_embeddings_scale,
                )
            else:
                raise self._quantization_mode_error(mode)
        finally:
            # Re-lock even when quantization fails so the tracker is never
            # left unlocked after an error.
            self._tracker.lock()

        if self.dtype_policy.quantization_mode is None:
            policy = keras.dtype_policies.get(
                f"{mode}_from_{self.dtype_policy.name}"
            )
            self.dtype_policy = policy
        gc.collect()
|