keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py
@@ -0,0 +1,227 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras

from keras_hub.src.models.deberta_v3.disentangled_self_attention import (
    DisentangledSelfAttention,
)
from keras_hub.src.utils.keras_utils import clone_initializer

from keras_hub.src.layers.modeling.transformer_layer_utils import (  # isort:skip
    merge_padding_and_attention_mask,
)


class DisentangledAttentionEncoder(keras.layers.Layer):
    """Disentangled attention encoder.

    This class follows the architecture of the disentangled attention encoder
    layer in the paper
    ["DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing"](https://arxiv.org/abs/2111.09543).
    Users can instantiate multiple instances of this class to stack up an
    encoder model which has disentangled self-attention.

    `DisentangledAttentionEncoder` is similar to
    `keras_hub.layers.TransformerEncoder`, except for the attention layer - it
    uses disentangled self-attention instead of multi-head attention.

    Args:
        intermediate_dim: int, the hidden size of feedforward network.
        num_heads: int, the number of heads in the attention layer.
        max_position_embeddings: int. The maximum input
            sequence length. Defaults to `512`.
        bucket_size: int. The size of the relative position
            buckets. Generally equal to `max_sequence_length // 2`.
            Defaults to `256`.
        dropout: float. The dropout value, shared by
            the attention layer and feedforward network.
            Defaults to `0.0`.
        activation: string or `keras.activations`. The
            activation function of the feedforward network.
            Defaults to `"relu"`.
        layer_norm_epsilon: float. The epsilon value in layer
            normalization components. Defaults to `1e-5`.
        kernel_initializer: string or `keras.initializers` initializer.
            The kernel initializer for the dense and disentangled
            self-attention layers. Defaults to `"glorot_uniform"`.
        bias_initializer: string or `keras.initializers` initializer.
            The bias initializer for the dense and disentangled
            self-attention layers. Defaults to `"zeros"`.
    """

    def __init__(
        self,
        intermediate_dim,
        num_heads,
        max_position_embeddings=512,
        bucket_size=256,
        dropout=0,
        activation="relu",
        layer_norm_epsilon=1e-05,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings
        self.bucket_size = bucket_size
        self.dropout = dropout
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self._built = False
        self.supports_masking = True

    def build(self, inputs_shape):
        # Infer the dimension of our hidden feature size from the build shape.
        hidden_dim = inputs_shape[-1]

        # Self attention layers.
        self._self_attention_layer = DisentangledSelfAttention(
            num_heads=self.num_heads,
            hidden_dim=hidden_dim,
            max_position_embeddings=self.max_position_embeddings,
            bucket_size=self.bucket_size,
            dropout=self.dropout,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="self_attention_layer",
        )
        self._self_attention_layer.build(inputs_shape)
        self._self_attention_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="self_attention_layer_norm",
        )
        self._self_attention_layer_norm.build(inputs_shape)
        self._self_attention_dropout = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
            name="self_attention_dropout",
        )

        # Feedforward layers.
        self._feedforward_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="feedforward_layer_norm",
        )
        self._feedforward_layer_norm.build(inputs_shape)
        self._feedforward_intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            activation=self.activation,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="feedforward_intermediate_dense",
        )
        self._feedforward_intermediate_dense.build(inputs_shape)
        self._feedforward_output_dense = keras.layers.Dense(
            hidden_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            bias_initializer=clone_initializer(self.bias_initializer),
            dtype=self.dtype_policy,
            name="feedforward_output_dense",
        )
        intermediate_shape = list(inputs_shape)
        intermediate_shape[-1] = self.intermediate_dim
        self._feedforward_output_dense.build(tuple(intermediate_shape))
        self._feedforward_dropout = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
            name="feedforward_dropout",
        )
        self.built = True

    def call(
        self,
        inputs,
        rel_embeddings,
        padding_mask=None,
        attention_mask=None,
    ):
        """Forward pass of `DisentangledAttentionEncoder`.

        Args:
            inputs: a Tensor. The input data to `DisentangledAttentionEncoder`, should be
                of shape [batch_size, sequence_length, hidden_dim].
            rel_embeddings: a Tensor. The relative position embedding matrix,
                should be of shape `[batch_size, 2 * bucket_size, hidden_dim]`.
            padding_mask: a boolean Tensor. It indicates if the token should be
                masked because the token is introduced due to padding.
                `padding_mask` should have shape [batch_size, sequence_length].
                False means the certain token is masked out.
            attention_mask: a boolean Tensor. Customized mask used to mask out
                certain tokens. `attention_mask` should have shape
                [batch_size, sequence_length, sequence_length].

        Returns:
            A Tensor of the same shape as the `inputs`.
        """
        x = inputs

        # Compute self attention mask.
        self_attention_mask = merge_padding_and_attention_mask(
            inputs, padding_mask, attention_mask
        )

        # Self attention block.
        residual = x
        x = self._self_attention_layer(
            x,
            rel_embeddings=rel_embeddings,
            attention_mask=self_attention_mask,
        )
        x = self._self_attention_dropout(x)
        x = x + residual
        x = self._self_attention_layer_norm(x)

        # Feedforward block.
        residual = x
        x = self._feedforward_intermediate_dense(x)
        x = self._feedforward_output_dense(x)
        x = self._feedforward_dropout(x)
        x = x + residual
        x = self._feedforward_layer_norm(x)

        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "max_position_embeddings": self.max_position_embeddings,
                "bucket_size": self.bucket_size,
                "dropout": self.dropout,
                "activation": keras.activations.serialize(self.activation),
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
            }
        )
        return config

    def compute_output_shape(self, inputs_shape):
        return inputs_shape
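A minimal usage sketch of the encoder layer added above (the sizes, hyperparameters, and random inputs below are illustrative assumptions, not values taken from this package); it follows the call signature documented in the file, where `rel_embeddings` has shape `[batch_size, 2 * bucket_size, hidden_dim]`:

import numpy as np

from keras_hub.src.models.deberta_v3.disentangled_attention_encoder import (
    DisentangledAttentionEncoder,
)

# Illustrative sizes only (assumptions for this sketch).
batch_size, seq_len, hidden_dim, bucket_size = 2, 128, 768, 256

encoder = DisentangledAttentionEncoder(
    intermediate_dim=3072,
    num_heads=12,
    max_position_embeddings=512,
    bucket_size=bucket_size,
)

# Hidden states, relative position embeddings, and padding mask shaped per the docstring.
hidden_states = np.random.uniform(
    size=(batch_size, seq_len, hidden_dim)
).astype("float32")
rel_embeddings = np.random.uniform(
    size=(batch_size, 2 * bucket_size, hidden_dim)
).astype("float32")
padding_mask = np.ones((batch_size, seq_len), dtype="int32")

outputs = encoder(
    hidden_states,
    rel_embeddings=rel_embeddings,
    padding_mask=padding_mask,
)
print(outputs.shape)  # (2, 128, 768): same shape as the inputs.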
keras_hub/src/models/deberta_v3/disentangled_self_attention.py
@@ -0,0 +1,412 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import keras
from keras import ops

from keras_hub.src.utils.keras_utils import clone_initializer


class DisentangledSelfAttention(keras.layers.Layer):
    """DisentangledSelfAttention layer.

    This is an implementation of disentangled self-attention as described in the
    paper ["DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing"](https://arxiv.org/abs/2111.09543).
    Effectively, this layer implements Multi-Head Self Attention with relative
    attention, i.e., to get the final attention score, we compute the
    content-to-position and position-to-content attention scores, and add these
    scores to the vanilla multi-head self-attention scores.

    Args:
        num_heads: int. Number of attention heads.
        hidden_dim: int. Hidden dimension of the input, i.e., `hidden_states`.
        max_position_embeddings: int. The maximum input
            sequence length. Defaults to `512`.
        bucket_size: int. The size of the relative position
            buckets. Generally equal to `max_sequence_length // 2`.
            Defaults to `256`.
        dropout: float. Dropout probability. Defaults to `0.1`.
        kernel_initializer: string or `keras.initializers` initializer.
            The kernel initializer for the dense layers.
            Defaults to `"glorot_uniform"`.
        bias_initializer: string or `keras.initializers` initializer.
            The bias initializer for the dense layers.
            Defaults to `"zeros"`.
    """

    def __init__(
        self,
        num_heads,
        hidden_dim,
        max_position_embeddings=512,
        bucket_size=256,
        dropout=0.1,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Passed args.
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.max_position_embeddings = max_position_embeddings
        self.bucket_size = bucket_size
        self.dropout = dropout

        # Initializers.
        self._kernel_initializer = keras.initializers.get(kernel_initializer)
        self._bias_initializer = keras.initializers.get(bias_initializer)

        # Derived args.
        self.attn_head_size = hidden_dim // num_heads

        # We have three types of attention - MHA, p2c and c2p.
        num_type_attn = 3
        self.scale_factor = 1.0 / math.sqrt(
            float(num_type_attn * self.attn_head_size)
        )

    def build(self, inputs_shape, rel_embeddings_shape=None):
        # Q, K, V linear layers.
        self._query_dense = keras.layers.EinsumDense(
            equation="abc,cde->abde",
            output_shape=(None, self.num_heads, self.attn_head_size),
            bias_axes="de",
            **self._get_common_kwargs_for_sublayer(use_bias=True),
            dtype=self.dtype_policy,
            name="query",
        )
        self._query_dense.build(inputs_shape)
        self._key_dense = keras.layers.EinsumDense(
            equation="abc,cde->abde",
            output_shape=(None, self.num_heads, self.attn_head_size),
            bias_axes="de",
            **self._get_common_kwargs_for_sublayer(use_bias=True),
            dtype=self.dtype_policy,
            name="key",
        )
        self._key_dense.build(inputs_shape)
        self._value_dense = keras.layers.EinsumDense(
            equation="abc,cde->abde",
            output_shape=(None, self.num_heads, self.attn_head_size),
            bias_axes="de",
            **self._get_common_kwargs_for_sublayer(use_bias=True),
            dtype=self.dtype_policy,
            name="value",
        )
        self._value_dense.build(inputs_shape)

        # Relative attention.
        self._position_dropout_layer = keras.layers.Dropout(
            self.dropout,
            dtype=self.dtype_policy,
        )

        self._attn_dropout_layer = keras.layers.Dropout(
            self.dropout,
            dtype=self.dtype_policy,
            name="attention_dropout",
        )
        self._softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        # Output.
        self._output_dense = keras.layers.EinsumDense(
            equation="abc,cd->abd",
            output_shape=(None, self.hidden_dim),
            bias_axes="d",
            **self._get_common_kwargs_for_sublayer(use_bias=True),
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self._output_dense.build(inputs_shape)
        self.built = True

    def _get_common_kwargs_for_sublayer(self, use_bias=True):
        common_kwargs = {}

        kernel_initializer = clone_initializer(self._kernel_initializer)
        bias_initializer = clone_initializer(self._bias_initializer)

        common_kwargs["kernel_initializer"] = kernel_initializer
        if use_bias:
            common_kwargs["bias_initializer"] = bias_initializer

        return common_kwargs

    def _masked_softmax(self, attention_scores, attention_mask=None):
        """Normalizes the attention scores to probabilities using softmax.

        This implementation is similar to the one present in
        `keras.layers.MultiHeadAttention`.
        """

        if attention_mask is not None:
            mask_expansion_axis = -3
            for _ in range(
                len(attention_scores.shape) - len(attention_mask.shape)
            ):
                attention_mask = ops.expand_dims(
                    attention_mask, axis=mask_expansion_axis
                )
        return self._softmax(attention_scores, attention_mask)

    def _compute_attention(
        self,
        query,
        key,
        value,
        rel_embeddings,
        attention_mask=None,
        training=None,
    ):
        """Computes the attention score and returns the attended outputs.

        This function computes vanilla MHA score, and relative attention scores
        (p2c and c2p). It then sums them up to get the final attention score,
        which is used to compute the attended outputs.
        """

        attention_scores = ops.einsum(
            "aecd,abcd->acbe",
            key,
            query,
        )
        attention_scores = ops.multiply(attention_scores, self.scale_factor)

        rel_embeddings = self._position_dropout_layer(
            rel_embeddings,
            training=training,
        )

        rel_attn_scores = self._compute_disentangled_attention(
            query=query,
            key=key,
            rel_embeddings=rel_embeddings,
        )

        if rel_attn_scores is not None:
            attention_scores += rel_attn_scores

        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        attention_scores = self._attn_dropout_layer(
            attention_scores, training=training
        )
        attention_output = ops.einsum(
            "acbe,aecd->abcd", attention_scores, value
        )

        return attention_output, attention_scores

    def _make_log_bucket_position(self, rel_pos):
        dtype = rel_pos.dtype
        sign = ops.sign(rel_pos)
        mid = self.bucket_size // 2
        mid = ops.cast(mid, dtype=dtype)

        # If `rel_pos[i][j]` is out of bounds, assign value `mid`.
        abs_pos = ops.where(
            condition=(rel_pos < mid) & (rel_pos > -mid),
            x1=mid - 1,
            x2=ops.abs(rel_pos),
        )

        def _get_log_pos(abs_pos, mid):
            numerator = ops.log(abs_pos / mid)
            numerator = numerator * ops.cast(mid - 1, dtype=numerator.dtype)
            denominator = ops.log((self.max_position_embeddings - 1) / mid)
            val = ops.ceil(numerator / denominator)
            val = ops.cast(val, dtype=mid.dtype)
            val = val + mid
            return val

        log_pos = _get_log_pos(abs_pos, mid)

        bucket_pos = ops.where(
            condition=abs_pos <= mid,
            x1=rel_pos,
            x2=log_pos * sign,
        )
        bucket_pos = ops.cast(bucket_pos, dtype="int")

        return bucket_pos

    def _get_rel_pos(self, num_positions):
        ids = ops.arange(num_positions)
        ids = ops.cast(ids, dtype="int")
        query_ids = ops.expand_dims(ids, axis=-1)
        key_ids = ops.expand_dims(ids, axis=0)
        key_ids = ops.repeat(key_ids, repeats=num_positions, axis=0)

        rel_pos = query_ids - key_ids
        rel_pos = self._make_log_bucket_position(rel_pos)

        rel_pos = ops.expand_dims(ops.expand_dims(rel_pos, axis=0), axis=0)
        return rel_pos

    def _compute_disentangled_attention(
        self,
        query,
        key,
        rel_embeddings,
    ):
        """Computes relative attention scores (p2c and c2p)."""

        batch_size = ops.shape(query)[0]
        num_positions = ops.shape(query)[1]

        rel_pos = self._get_rel_pos(num_positions)

        rel_attn_span = self.bucket_size
        score = 0

        pos_query = self._query_dense(rel_embeddings)
        pos_key = self._key_dense(rel_embeddings)

        # c2p
        c2p_attn_scores = ops.einsum(
            "aecd,abcd->acbe",
            pos_key,
            query,
        )
        c2p_pos = ops.clip(rel_pos + rel_attn_span, 0, rel_attn_span * 2 - 1)
        c2p_pos = ops.broadcast_to(
            c2p_pos,
            shape=(
                batch_size,
                self.num_heads,
                num_positions,
                num_positions,
            ),
        )

        if keras.config.backend() == "tensorflow":
            # Work around dynamic shape bug on tensorflow backend.
            import tensorflow as tf

            c2p_attn_scores = tf.gather(
                c2p_attn_scores,
                indices=c2p_pos,
                batch_dims=3,
            )
        else:
            c2p_attn_scores = ops.take_along_axis(
                c2p_attn_scores,
                indices=c2p_pos,
                axis=3,
            )
        c2p_attn_scores = ops.multiply(c2p_attn_scores, self.scale_factor)
        score += c2p_attn_scores

        # p2c
        p2c_attn_scores = ops.einsum(
            "aecd,abcd->acbe",
            pos_query,
            key,
        )
        p2c_pos = ops.clip(-rel_pos + rel_attn_span, 0, rel_attn_span * 2 - 1)
        p2c_pos = ops.broadcast_to(
            p2c_pos,
            shape=(
                batch_size,
                self.num_heads,
                num_positions,
                num_positions,
            ),
        )
        if keras.config.backend() == "tensorflow":
            # Work around dynamic shape bug on tensorflow backend.
            import tensorflow as tf

            p2c_attn_scores = tf.gather(
                p2c_attn_scores,
                indices=p2c_pos,
                batch_dims=3,
            )
        else:
            p2c_attn_scores = ops.take_along_axis(
                p2c_attn_scores,
                indices=p2c_pos,
                axis=3,
            )
        p2c_attn_scores = ops.transpose(p2c_attn_scores, [0, 1, 3, 2])
        p2c_attn_scores = ops.multiply(p2c_attn_scores, self.scale_factor)
        score += p2c_attn_scores

        return score

    def call(
        self,
        inputs,
        rel_embeddings,
        attention_mask=None,
        return_attention_scores=False,
        training=None,
    ):
        # `query`, `key`, `value` shape:
        # `(batch_size, sequence_length, num_heads, attn_head_size)`.
        query = self._query_dense(inputs)
        key = self._key_dense(inputs)
        value = self._value_dense(inputs)

        attention_output, attention_scores = self._compute_attention(
            query=query,
            key=key,
            value=value,
            rel_embeddings=rel_embeddings,
            attention_mask=attention_mask,
            training=training,
        )

        # Reshape `attention_output` to `(batch_size, sequence_length, hidden_dim)`.
        attention_output = ops.reshape(
            attention_output,
            [
                ops.shape(attention_output)[0],
                ops.shape(attention_output)[1],
                self.hidden_dim,
            ],
        )
        attention_output = self._output_dense(attention_output)

        if return_attention_scores:
            return attention_output, attention_scores
        return attention_output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "max_position_embeddings": self.max_position_embeddings,
                "bucket_size": self.bucket_size,
                "dropout": self.dropout,
                "kernel_initializer": keras.initializers.serialize(
                    self._kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self._bias_initializer
                ),
            }
        )
        return config