keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/t5/t5_backbone.py
@@ -0,0 +1,261 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
+from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
+
+
+@keras_hub_export("keras_hub.models.T5Backbone")
+class T5Backbone(Backbone):
+    """T5 encoder-decoder backbone model.
+
+    T5 is an LLM pretrained on a mix of unsupervised and supervised tasks,
+    where each task is converted to a sequence-to-sequence format.
+    T5 works well on a variety of tasks out-of-the-box by prepending
+    various prefixes to the input sequence, e.g., for translation:
+    `"translate English to German: ..."`, for summarization:
+    `"summarize: ..."`.
+
+    T5 was introduced in
+    [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683).
+
+    The default constructor gives a fully customizable, randomly initialized T5
+    model with any number of layers, heads, and embedding dimensions. To load
+    preset architectures and weights, use the `from_preset` constructor.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of Transformer layers.
+        num_heads: int. The number of attention heads for each Transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The hidden size of the Transformer layers.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each Transformer layer.
+        key_value_dim: int. The dimension of each head of the key/value
+            projections in the multi-head attention layers. Defaults to
+            hidden_dim / num_heads.
+        dropout: float. Dropout probability for the Transformer layers.
+        activation: activation function (or activation string name). The
+            activation to be used in the inner dense blocks of the
+            Transformer layers. Defaults to `"relu"`.
+        use_gated_activation: boolean. Whether to use activation gating in
+            the inner dense blocks of the Transformer layers.
+            The original T5 architecture didn't use gating, but more
+            recent versions do. Defaults to `True`.
+        layer_norm_epsilon: float. Epsilon factor to be used in the
+            layer normalization layers in the Transformer layers.
+        tie_embedding_weights: boolean. If `True`, the token embedding
+            weights are shared with the weights that project language model
+            outputs from `hidden_dim` back to the vocabulary.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        key_value_dim=None,
+        dropout=0.1,
+        activation="relu",
+        use_gated_activation=True,
+        layer_norm_epsilon=1e-06,
+        tie_embedding_weights=True,
+        dtype=None,
+        **kwargs,
+    ):
+        # Token embedding layer. This layer is shared by encoder and decoder.
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            tie_weights=tie_embedding_weights,
+            embeddings_initializer=keras.initializers.TruncatedNormal(1.0),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.encoder_embedding_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=dtype,
+            name="encoder_embedding_dropout",
+        )
+        self.encoder_transformer_layers = []
+        for i in range(num_layers):
+            layer = T5TransformerLayer(
+                is_decoder=False,
+                hidden_dim=hidden_dim,
+                intermediate_dim=intermediate_dim,
+                key_value_dim=key_value_dim or hidden_dim // num_heads,
+                dropout=dropout,
+                activation=activation,
+                layer_norm_epsilon=layer_norm_epsilon,
+                num_heads=num_heads,
+                use_gated_activation=use_gated_activation,
+                use_relative_attention_bias=bool(i == 0),
+                dtype=dtype,
+                name=f"transformer_encoder_layer_{i}",
+            )
+            self.encoder_transformer_layers.append(layer)
+        self.encoder_layer_norm = T5LayerNorm(
+            epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="encoder_output_layer_norm",
+        )
+        self.encoder_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=dtype,
+            name="encoder_output_dropout",
+        )
+        self.decoder_embedding_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=dtype,
+            name="decoder_embedding_dropout",
+        )
+        self.decoder_transformer_layers = []
+        for i in range(num_layers):
+            layer = T5TransformerLayer(
+                is_decoder=True,
+                hidden_dim=hidden_dim,
+                intermediate_dim=intermediate_dim,
+                key_value_dim=key_value_dim or hidden_dim // num_heads,
+                dropout=dropout,
+                activation=activation,
+                layer_norm_epsilon=layer_norm_epsilon,
+                num_heads=num_heads,
+                use_gated_activation=use_gated_activation,
+                use_relative_attention_bias=bool(i == 0),
+                dtype=dtype,
+                name=f"transformer_decoder_layer_{i}",
+            )
+            self.decoder_transformer_layers.append(layer)
+        self.decoder_layer_norm = T5LayerNorm(
+            epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="decoder_output_layer_norm",
+        )
+        self.decoder_dropout = keras.layers.Dropout(
+            dropout,
+            dtype=dtype,
+            name="decoder_output_dropout",
+        )
+
+        # === Functional Model ===
+        encoder_token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="encoder_token_ids"
+        )
+        encoder_padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="encoder_padding_mask"
+        )
+        decoder_token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_token_ids"
+        )
+        decoder_padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_padding_mask"
+        )
+        # Encoder.
+        x = self.token_embedding(encoder_token_id_input)
+        x = self.encoder_embedding_dropout(x)
+        encoder_attention_mask = encoder_padding_mask_input[:, None, :]
+        position_bias = None
+        for transformer_layer in self.encoder_transformer_layers:
+            output = transformer_layer(
+                x,
+                attention_mask=encoder_attention_mask,
+                position_bias=position_bias,
+                use_causal_mask=False,
+            )
+            if isinstance(output, tuple):
+                x, position_bias = output
+        x = self.encoder_layer_norm(x)
+        x = self.encoder_dropout(x)
+        encoder_output = x
+        # Decoder.
+        x = self.token_embedding(decoder_token_id_input)
+        x = self.decoder_embedding_dropout(x)
+        decoder_attention_mask = decoder_padding_mask_input[:, None, :]
+        position_bias = None
+        for transformer_layer in self.decoder_transformer_layers:
+            output = transformer_layer(
+                x,
+                attention_mask=decoder_attention_mask,
+                position_bias=position_bias,
+                encoder_hidden_states=encoder_output,
+                encoder_attention_mask=encoder_attention_mask,
+                use_causal_mask=True,
+            )
+            if isinstance(output, tuple):
+                x, position_bias = output
+        x = self.decoder_layer_norm(x)
+        x = self.decoder_dropout(x)
+        decoder_output = x
+        super().__init__(
+            {
+                "encoder_token_ids": encoder_token_id_input,
+                "encoder_padding_mask": encoder_padding_mask_input,
+                "decoder_token_ids": decoder_token_id_input,
+                "decoder_padding_mask": decoder_padding_mask_input,
+            },
+            outputs={
+                "encoder_sequence_output": encoder_output,
+                "decoder_sequence_output": decoder_output,
+            },
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.activation = keras.activations.get(activation)
+        self.key_value_dim = key_value_dim
+        self.dropout = dropout
+        self.use_gated_activation = use_gated_activation
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.tie_embedding_weights = tie_embedding_weights
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "activation": keras.activations.serialize(self.activation),
+                "key_value_dim": self.key_value_dim,
+                "dropout": self.dropout,
+                "use_gated_activation": self.use_gated_activation,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "tie_embedding_weights": self.tie_embedding_weights,
+            }
+        )
+        return config
keras_hub/src/models/t5/t5_layer_norm.py
@@ -0,0 +1,35 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+from keras import ops
+
+
+class T5LayerNorm(keras.layers.Layer):
+    def __init__(self, epsilon=1e-6, **kwargs):
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+
+    def build(self, input_shape):
+        self.weight = self.add_weight(
+            name="weight",
+            shape=(input_shape[-1],),
+            initializer="ones",
+        )
+        self.built = True
+
+    def call(self, hidden_states):
+        variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True)
+        hidden_states = hidden_states * ops.rsqrt(variance + self.epsilon)
+        return self.weight * hidden_states
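Note that `T5LayerNorm` is an RMS-style norm: unlike standard layer normalization it never subtracts the mean and has no bias term, only a learned scale initialized to ones. A quick NumPy cross-check of the same computation (a standalone sketch, not package code):

import numpy as np

def t5_layer_norm_reference(x, epsilon=1e-6):
    # Mean of squares only, with no centering, matching T5LayerNorm.call.
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    # The learned scale `weight` is omitted: it is all ones at initialization.
    return x / np.sqrt(variance + epsilon)

x = np.random.randn(2, 5, 8).astype("float32")
y = t5_layer_norm_reference(x)
# Every feature vector now has approximately unit root-mean-square.
print(np.sqrt(np.mean(np.square(y), axis=-1)))  # ~1.0 everywhere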
keras_hub/src/models/t5/t5_multi_head_attention.py
@@ -0,0 +1,324 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+import numpy as np
+from keras import ops
+
+
+class T5MultiHeadAttention(keras.layers.Layer):
+    # This layer is adapted from Hugging Face
+    # Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_tf_t5.py
+    def __init__(
+        self,
+        is_decoder,
+        hidden_dim,
+        key_value_dim,
+        num_heads,
+        dropout,
+        use_relative_attention_bias=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.is_decoder = is_decoder
+        self.hidden_dim = hidden_dim
+        self.key_value_dim = key_value_dim
+        self.num_heads = num_heads
+        self.use_relative_attention_bias = use_relative_attention_bias
+
+        self.inner_dim = self.num_heads * self.key_value_dim
+        self.relative_attention_buckets = 32
+        self.relative_attention_max_distance = 128
+
+        self.query_projector = keras.layers.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_initializer=keras.initializers.RandomNormal(
+                mean=0, stddev=(self.inner_dim * self.key_value_dim) ** -0.5
+            ),
+            dtype=self.dtype_policy,
+            name="query_projector",
+        )
+        self.key_projector = keras.layers.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_initializer=keras.initializers.RandomNormal(
+                mean=0, stddev=self.inner_dim**-0.5
+            ),
+            dtype=self.dtype_policy,
+            name="key_projector",
+        )
+        self.value_projector = keras.layers.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_initializer=keras.initializers.RandomNormal(
+                mean=0, stddev=self.inner_dim**-0.5
+            ),
+            dtype=self.dtype_policy,
+            name="value_projector",
+        )
+        self.output_projector = keras.layers.Dense(
+            self.hidden_dim,
+            use_bias=False,
+            kernel_initializer=keras.initializers.RandomNormal(
+                mean=0, stddev=self.inner_dim**-0.5
+            ),
+            dtype=self.dtype_policy,
+            name="output_projector",
+        )
+        self.dropout_layer = keras.layers.Dropout(
+            dropout,
+            dtype=self.dtype_policy,
+        )
+
+        if self.use_relative_attention_bias:
+            self.relative_attention_bias = self.add_weight(
+                name="embeddings",
+                shape=[self.relative_attention_buckets, self.num_heads],
+                initializer=keras.initializers.RandomNormal(
+                    mean=0, stddev=self.inner_dim**-0.5
+                ),
+            )
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position, bidirectional=True, num_buckets=32, max_distance=128
+    ):
+        """Adapted from Mesh Tensorflow.
+
+        Translate relative position to a bucket number for relative attention.
+        The relative position is defined as memory_position - query_position,
+        i.e. the distance in tokens from the attending position to the
+        attended-to position. If bidirectional=False, then positive relative
+        positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute
+        relative_positions. All relative positions >= max_distance map to
+        the same bucket. All relative positions <= -max_distance map to
+        the same bucket. This should allow for more graceful generalization to
+        longer sequences than the model has been trained on.
+
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+
+        Returns:
+            Tensor with the same shape as relative_position,
+            containing int32 values in the range [0, num_buckets)
+        """
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (
+                ops.cast(
+                    ops.greater(relative_position, 0),
+                    dtype=relative_position.dtype,
+                )
+                * num_buckets
+            )
+            relative_position = ops.abs(relative_position)
+        else:
+            relative_position = -ops.minimum(relative_position, 0)
+        # now n is in the range [0, inf)
+        max_exact = num_buckets // 2
+        is_small = ops.less(relative_position, max_exact)
+        relative_position_if_large = max_exact + ops.cast(
+            ops.log(
+                ops.cast(relative_position, "float32")
+                / ops.cast(max_exact, "float32")
+            )
+            / ops.cast(ops.log(max_distance / max_exact), "float32")
+            * (num_buckets - max_exact),
+            dtype=relative_position.dtype,
+        )
+        relative_position_if_large = ops.minimum(
+            relative_position_if_large, num_buckets - 1
+        )
+        relative_buckets += ops.where(
+            is_small, relative_position, relative_position_if_large
+        )
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length):
+        """Compute binned relative position bias"""
+        context_position = ops.arange(query_length)[:, None]
+        memory_position = ops.arange(key_length)[None, :]
+        relative_position = (
+            memory_position - context_position
+        )  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,
+            bidirectional=(not self.is_decoder),
+            num_buckets=self.relative_attention_buckets,
+            max_distance=self.relative_attention_max_distance,
+        )
+        values = ops.take(
+            self.relative_attention_bias, relative_position_bucket, axis=0
+        )  # shape (query_length, key_length, num_heads)
+        values = ops.expand_dims(
+            ops.transpose(values, axes=(2, 0, 1)), axis=0
+        )  # shape (1, num_heads, query_length, key_length)
+        return values
+
+    def call(
+        self,
+        hidden_states,
+        mask=None,
+        key_value_states=None,
+        position_bias=None,
+        past_key_value=None,
+        layer_head_mask=None,
+        query_length=None,
+        training=False,
+    ):
+        # Input is (batch_size, query_length, dim)
+        # past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head)
+        batch_size, seq_length = ops.shape(hidden_states)[:2]
+
+        real_seq_length = seq_length
+
+        if past_key_value is not None:
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"Argument `past_key_value` should have 2 past states: "
+                    f"keys and values. Got {len(past_key_value)} past states."
+                )
+            real_seq_length += (
+                ops.shape(past_key_value[0])[2]
+                if query_length is None
+                else query_length
+            )
+
+        key_length = (
+            real_seq_length
+            if key_value_states is None
+            else ops.shape(key_value_states)[1]
+        )
+
+        def shape(hidden_states):
+            return ops.transpose(
+                ops.reshape(
+                    hidden_states,
+                    (batch_size, -1, self.num_heads, self.key_value_dim),
+                ),
+                axes=(0, 2, 1, 3),
+            )
+
+        def unshape(hidden_states):
+            return ops.reshape(
+                ops.transpose(hidden_states, axes=(0, 2, 1, 3)),
+                (batch_size, -1, self.inner_dim),
+            )
+
+        def project(
+            hidden_states, proj_layer, key_value_states, past_key_value
+        ):
+            """projects hidden states correctly to key/query states"""
+            if key_value_states is None:
+                # self-attention
+                # (batch_size, num_heads, seq_length, dim_per_head)
+                hidden_states = shape(proj_layer(hidden_states))
+            elif past_key_value is None:
+                # cross-attention
+                # (batch_size, num_heads, seq_length, dim_per_head)
+                hidden_states = shape(proj_layer(key_value_states))
+
+            if past_key_value is not None:
+                if key_value_states is None:
+                    # self-attention
+                    # (batch_size, num_heads, key_length, dim_per_head)
+                    hidden_states = ops.concat(
+                        [past_key_value, hidden_states], axis=2
+                    )
+                else:
+                    # cross-attention
+                    hidden_states = past_key_value
+            return hidden_states
+
+        # get query
+        query_states = shape(
+            self.query_projector(hidden_states)
+        )  # (batch_size, num_heads, query_length, dim_per_head)
+
+        # get key/value
+        key_states = project(
+            hidden_states,
+            self.key_projector,
+            key_value_states,
+            past_key_value[0] if past_key_value is not None else None,
+        )
+        value_states = project(
+            hidden_states,
+            self.value_projector,
+            key_value_states,
+            past_key_value[1] if past_key_value is not None else None,
+        )
+
+        scores = ops.einsum(
+            "bnqd,bnkd->bnqk", query_states, key_states
+        )  # (batch_size, num_heads, query_length, key_length)
+
+        if position_bias is None:
+            if not self.use_relative_attention_bias:
+                position_bias = ops.zeros(
+                    (1, self.num_heads, real_seq_length, key_length),
+                    self.compute_dtype,
+                )
+            else:
+                position_bias = self.compute_bias(real_seq_length, key_length)
+
+            # if key and values are already calculated we want only
+            # the last query position bias
+            if past_key_value is not None:
+                if not self.use_relative_attention_bias:
+                    position_bias = position_bias[:, :, -seq_length:, :]
+                else:
+                    # we might have a padded past structure,
+                    # in which case we want to fetch the position bias slice
+                    # right after the most recently filled past index
+                    most_recently_filled_past_index = ops.amax(
+                        ops.where(past_key_value[0][0, 0, :, 0] != 0.0)
+                    )
+                    position_bias = ops.slice(
+                        position_bias,
+                        (0, 0, most_recently_filled_past_index + 1, 0),
+                        (1, self.num_heads, seq_length, real_seq_length),
+                    )
+
+            if mask is not None:
+                # Add a new mask axis for the head dim.
+                mask = mask[:, np.newaxis, :, :]
+                # Add a very large negative position bias for masked positions.
+                mask = (1.0 - ops.cast(mask, position_bias.dtype)) * -1e9
+                position_bias = position_bias + mask
+
+        scores += ops.cast(position_bias, scores.dtype)
+        weights = ops.nn.softmax(
+            scores, axis=-1
+        )  # (batch_size, num_heads, query_length, key_length)
+        weights = self.dropout_layer(
+            weights, training=training
+        )  # (batch_size, num_heads, query_length, key_length)
+
+        # Optionally mask heads
+        if layer_head_mask is not None:
+            weights = ops.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
+
+        attention_output = ops.matmul(
+            weights, value_states
+        )  # (batch_size, num_heads, query_length, dim_per_head)
+
+        attention_output = self.output_projector(unshape(attention_output))
+        return (attention_output, position_bias)
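The non-obvious piece above is the bucketing scheme: with `bidirectional=True`, half of the 32 buckets encode the sign of the offset, half of the remaining 16 cover exact small distances, and the rest grow logarithmically out to `max_distance`. Because `_relative_position_bucket` is a staticmethod built on `keras.ops`, it can be probed directly; a small sketch with illustrative offsets (not package code):

import numpy as np
from keras import ops

from keras_hub.src.models.t5.t5_multi_head_attention import (
    T5MultiHeadAttention,
)

# Relative positions (memory_position - query_position) for a few offsets.
offsets = ops.convert_to_tensor(
    np.array([[-128, -16, -1, 0, 1, 16, 128]], dtype="int32")
)

# Encoder-style (bidirectional) bucketing, as used when is_decoder=False.
buckets = T5MultiHeadAttention._relative_position_bucket(
    offsets, bidirectional=True, num_buckets=32, max_distance=128
)
print(ops.convert_to_numpy(buckets))  # [[15 10  1  0 17 26 31]]
# Nearby offsets get their own buckets; distant offsets share coarse
# logarithmic buckets, which is what lets the bias generalize to sequences
# longer than those seen in training.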