keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,496 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
from keras_hub.src.layers.modeling.cached_multi_head_attention import (
|
20
|
+
CachedMultiHeadAttention,
|
21
|
+
)
|
22
|
+
from keras_hub.src.utils.keras_utils import clone_initializer
|
23
|
+
|
24
|
+
from keras_hub.src.layers.modeling.transformer_layer_utils import ( # isort:skip
|
25
|
+
compute_causal_mask,
|
26
|
+
merge_padding_and_attention_mask,
|
27
|
+
)
|
28
|
+
|
29
|
+
|
30
|
+
@keras_hub_export("keras_hub.layers.TransformerDecoder")
|
31
|
+
class TransformerDecoder(keras.layers.Layer):
|
32
|
+
"""Transformer decoder.
|
33
|
+
|
34
|
+
This class follows the architecture of the transformer decoder layer in the
|
35
|
+
paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
|
36
|
+
can instantiate multiple instances of this class to stack up a decoder.
|
37
|
+
|
38
|
+
By default, this layer will apply a causal mask to the decoder attention
|
39
|
+
layer. You can also pass padding or attention masks directly to the layer
|
40
|
+
during call, e.g. with `decoder_padding_mask` or `decoder_attention_mask`.
|
41
|
+
|
42
|
+
This layer can be called with either one or two inputs. The number of inputs
|
43
|
+
must be consistent across all calls. The options are as follows:
|
44
|
+
`layer(decoder_sequence)`: no cross-attention will be built into the
|
45
|
+
decoder block. This is useful when building a "decoder-only"
|
46
|
+
transformer such as GPT-2.
|
47
|
+
`layer(decoder_sequence, encoder_sequence)`: cross-attention will be
|
48
|
+
built into the decoder block. This is useful when building an
|
49
|
+
"encoder-decoder" transformer, such as the original transformer
|
50
|
+
model described in Attention is All You Need.
|
51
|
+
|
52
|
+
Args:
|
53
|
+
intermediate_dim: int, the hidden size of feedforward network.
|
54
|
+
num_heads: int, the number of heads in MultiHeadAttention.
|
55
|
+
dropout: float. the dropout value, shared by
|
56
|
+
MultiHeadAttention and feedforward network. Defaults to `0.`.
|
57
|
+
activation: string or `keras.activations`. the
|
58
|
+
activation function of feedforward network.
|
59
|
+
Defaults to `"relu"`.
|
60
|
+
layer_norm_epsilon: float. The eps value in layer
|
61
|
+
normalization components. Defaults to `1e-5`.
|
62
|
+
kernel_initializer: string or `keras.initializers` initializer.
|
63
|
+
The kernel initializer for the dense and multiheaded
|
64
|
+
attention layers. Defaults to `"glorot_uniform"`.
|
65
|
+
bias_initializer: string or `keras.initializers` initializer.
|
66
|
+
The bias initializer for the dense and multiheaded
|
67
|
+
attention layers. Defaults to `"zeros"`.
|
68
|
+
normalize_first: bool. If True, the inputs to the
|
69
|
+
attention layer(s) and the intermediate dense layer are normalized
|
70
|
+
(similar to GPT-2). If set to False, outputs of attention layer and
|
71
|
+
intermediate dense layer are normalized (similar to BERT).
|
72
|
+
Defaults to `False`.
|
73
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
74
|
+
including `name`, `trainable`, `dtype` etc.
|
75
|
+
|
76
|
+
Example:
|
77
|
+
```python
|
78
|
+
# Create a single transformer decoder layer.
|
79
|
+
decoder = keras_hub.layers.TransformerDecoder(
|
80
|
+
intermediate_dim=64, num_heads=8)
|
81
|
+
|
82
|
+
# Create a simple model containing the decoder.
|
83
|
+
decoder_input = keras.Input(shape=(10, 64))
|
84
|
+
encoder_input = keras.Input(shape=(10, 64))
|
85
|
+
output = decoder(decoder_input, encoder_input)
|
86
|
+
model = keras.Model(
|
87
|
+
inputs=(decoder_input, encoder_input),
|
88
|
+
outputs=output,
|
89
|
+
)
|
90
|
+
|
91
|
+
# Call decoder on the inputs.
|
92
|
+
decoder_input_data = np.random.uniform(size=(2, 10, 64))
|
93
|
+
encoder_input_data = np.random.uniform(size=(2, 10, 64))
|
94
|
+
decoder_output = model((decoder_input_data, encoder_input_data))
|
95
|
+
```
|
96
|
+
|
97
|
+
References:
|
98
|
+
- [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
|
99
|
+
|
100
|
+
"""
|
101
|
+
|
102
|
+
def __init__(
    self,
    intermediate_dim,
    num_heads,
    dropout=0,
    activation="relu",
    layer_norm_epsilon=1e-05,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    normalize_first=False,
    **kwargs,
):
    """Store the decoder hyperparameters and optionally build eagerly.

    `decoder_sequence_shape` and `encoder_sequence_shape` may arrive through
    `**kwargs`, e.g. when the layer is re-created from a saved config. They
    are removed before the remaining kwargs reach `keras.layers.Layer`. When
    a decoder shape was supplied, the layer is built immediately.
    """
    # Work around for model saving: these keys are not base-`Layer`
    # arguments, so pop them before calling the parent constructor. Building
    # right away ensures a layer restored from config is already built.
    restored_decoder_shape = kwargs.pop("decoder_sequence_shape", None)
    restored_encoder_shape = kwargs.pop("encoder_sequence_shape", None)

    super().__init__(**kwargs)

    # Hyperparameters. The activation and initializers may be given as
    # string identifiers; resolve them to objects here.
    self.intermediate_dim = intermediate_dim
    self.num_heads = num_heads
    self.dropout = dropout
    self.activation = keras.activations.get(activation)
    self.layer_norm_epsilon = layer_norm_epsilon
    self.kernel_initializer = keras.initializers.get(kernel_initializer)
    self.bias_initializer = keras.initializers.get(bias_initializer)
    self.normalize_first = normalize_first
    self.supports_masking = True

    # Populated by `build`; both remain `None` until the layer is built.
    self._decoder_sequence_shape = self._encoder_sequence_shape = None

    if restored_decoder_shape:
        self.build(restored_decoder_shape, restored_encoder_shape)
|
134
|
+
|
135
|
+
def build(
    self,
    decoder_sequence_shape,
    encoder_sequence_shape=None,
):
    """Create and build every sublayer from the input shapes.

    Sublayers are created in a fixed order: self-attention, then the
    optional cross-attention, then the feedforward block. Keep this order;
    it presumably determines the order in which Keras tracks sublayer
    weights.

    Args:
        decoder_sequence_shape: shape of the decoder input. Its last entry
            is used as the hidden (feature) dimension.
        encoder_sequence_shape: optional shape of the encoder input. When
            truthy, cross-attention sublayers are created as well.
            Otherwise `self._cross_attention_layer` is left as `None`.

    Raises:
        ValueError: if the hidden dimension is undefined (`None`), or is
            smaller than `num_heads` so that the per-head size would be zero.
    """
    # Record the build shapes. Presumably they are reported by `get_config`
    # (not visible here) and fed back through the kwargs pops in `__init__`.
    self._decoder_sequence_shape = decoder_sequence_shape
    self._encoder_sequence_shape = encoder_sequence_shape

    # Infer the dimension of our hidden feature size from the build shape.
    hidden_dim = decoder_sequence_shape[-1]
    if hidden_dim is None:
        # Without this check, `None // num_heads` below fails with an
        # opaque TypeError.
        raise ValueError(
            "The last dimension of `decoder_sequence_shape` must be "
            "defined to build a `TransformerDecoder`. Received: "
            f"decoder_sequence_shape={decoder_sequence_shape}."
        )
    # Attention head size is `hidden_dim` over the number of heads.
    head_dim = int(hidden_dim // self.num_heads)
    if head_dim == 0:
        raise ValueError(
            "Attention `head_dim` computed cannot be zero. "
            f"The `hidden_dim` value of {hidden_dim} has to be equal to "
            f"or greater than `num_heads` value of {self.num_heads}."
        )

    self._build_self_attention_sublayers(decoder_sequence_shape, head_dim)

    # Cross attention layers are optional.
    self._cross_attention_layer = None
    if encoder_sequence_shape:
        self._build_cross_attention_sublayers(
            decoder_sequence_shape, encoder_sequence_shape, head_dim
        )

    self._build_feedforward_sublayers(decoder_sequence_shape, hidden_dim)
    self.built = True

@staticmethod
def _build_attention_sublayer(attention_layer, query_shape, value_shape):
    """Build an attention sublayer from its query and value shapes.

    Some attention implementations expose the private hook
    `_build_from_signature(query=..., value=...)`, while others only offer
    `build(query_shape=..., value_shape=...)`. Both are supported.
    """
    if hasattr(attention_layer, "_build_from_signature"):
        attention_layer._build_from_signature(
            query=query_shape,
            value=value_shape,
        )
    else:
        attention_layer.build(
            query_shape=query_shape,
            value_shape=value_shape,
        )

def _build_self_attention_sublayers(self, decoder_sequence_shape, head_dim):
    """Create the self-attention layer with its layer norm and dropout."""
    self._self_attention_layer = CachedMultiHeadAttention(
        num_heads=self.num_heads,
        key_dim=head_dim,
        dropout=self.dropout,
        kernel_initializer=clone_initializer(self.kernel_initializer),
        bias_initializer=clone_initializer(self.bias_initializer),
        dtype=self.dtype_policy,
        name="self_attention",
    )
    # Self-attention: queries and values both come from the decoder sequence.
    self._build_attention_sublayer(
        self._self_attention_layer,
        decoder_sequence_shape,
        decoder_sequence_shape,
    )
    self._self_attention_layer_norm = keras.layers.LayerNormalization(
        epsilon=self.layer_norm_epsilon,
        dtype=self.dtype_policy,
        name="self_attention_layer_norm",
    )
    self._self_attention_layer_norm.build(decoder_sequence_shape)
    self._self_attention_dropout = keras.layers.Dropout(
        rate=self.dropout,
        dtype=self.dtype_policy,
        name="self_attention_dropout",
    )

def _build_cross_attention_sublayers(
    self, decoder_sequence_shape, encoder_sequence_shape, head_dim
):
    """Create the cross-attention layer with its layer norm and dropout."""
    self._cross_attention_layer = CachedMultiHeadAttention(
        num_heads=self.num_heads,
        key_dim=head_dim,
        value_dim=head_dim,
        dropout=self.dropout,
        kernel_initializer=clone_initializer(self.kernel_initializer),
        bias_initializer=clone_initializer(self.bias_initializer),
        dtype=self.dtype_policy,
        name="cross_attention",
    )
    # Cross-attention: decoder queries attend over encoder values.
    self._build_attention_sublayer(
        self._cross_attention_layer,
        decoder_sequence_shape,
        encoder_sequence_shape,
    )
    self._cross_attention_layer_norm = keras.layers.LayerNormalization(
        epsilon=self.layer_norm_epsilon,
        dtype=self.dtype_policy,
        name="cross_attention_layer_norm",
    )
    self._cross_attention_layer_norm.build(decoder_sequence_shape)
    self._cross_attention_dropout = keras.layers.Dropout(
        rate=self.dropout,
        dtype=self.dtype_policy,
        name="cross_attention_dropout",
    )

def _build_feedforward_sublayers(self, decoder_sequence_shape, hidden_dim):
    """Create the two-layer feedforward block with its norm and dropout."""
    self._feedforward_intermediate_dense = keras.layers.Dense(
        self.intermediate_dim,
        activation=self.activation,
        kernel_initializer=clone_initializer(self.kernel_initializer),
        bias_initializer=clone_initializer(self.bias_initializer),
        dtype=self.dtype_policy,
        name="feedforward_intermediate_dense",
    )
    self._feedforward_intermediate_dense.build(decoder_sequence_shape)
    self._feedforward_output_dense = keras.layers.Dense(
        hidden_dim,
        kernel_initializer=clone_initializer(self.kernel_initializer),
        bias_initializer=clone_initializer(self.bias_initializer),
        dtype=self.dtype_policy,
        name="feedforward_output_dense",
    )
    # The output projection consumes the intermediate activations, so its
    # input shape is the decoder shape with the last axis swapped for
    # `intermediate_dim`.
    intermediate_shape = list(decoder_sequence_shape)
    intermediate_shape[-1] = self.intermediate_dim
    self._feedforward_output_dense.build(tuple(intermediate_shape))
    self._feedforward_layer_norm = keras.layers.LayerNormalization(
        epsilon=self.layer_norm_epsilon,
        dtype=self.dtype_policy,
        name="feedforward_layer_norm",
    )
    self._feedforward_layer_norm.build(decoder_sequence_shape)
    self._feedforward_dropout = keras.layers.Dropout(
        rate=self.dropout,
        dtype=self.dtype_policy,
        name="feedforward_dropout",
    )
|
253
|
+
|
254
|
+
def call(
    self,
    decoder_sequence,
    encoder_sequence=None,
    decoder_padding_mask=None,
    decoder_attention_mask=None,
    encoder_padding_mask=None,
    encoder_attention_mask=None,
    self_attention_cache=None,
    self_attention_cache_update_index=None,
    cross_attention_cache=None,
    cross_attention_cache_update_index=None,
    use_causal_mask=True,
    training=None,
):
    """Forward pass of the TransformerDecoder.

    Args:
        decoder_sequence: a Tensor. The decoder input sequence.
        encoder_sequence: a Tensor. The encoder input sequence. For
            decoder-only models (like GPT2), this should be left `None`.
            Whether the layer has cross-attention is fixed when it is
            built: a layer built without an `encoder_sequence` cannot
            later be called with one, and a layer built with one must
            always receive it.
        decoder_padding_mask: a boolean Tensor, the padding mask of decoder
            sequence, must be of shape
            `[batch_size, decoder_sequence_length]`.
        decoder_attention_mask: a boolean Tensor. Customized decoder
            sequence mask, must be of shape
            `[batch_size, decoder_sequence_length, decoder_sequence_length]`.
        encoder_padding_mask: a boolean Tensor, the padding mask of encoder
            sequence, must be of shape
            `[batch_size, encoder_sequence_length]`.
        encoder_attention_mask: a boolean Tensor. Customized
            cross-attention mask, must be of shape
            `[batch_size, decoder_sequence_length, encoder_sequence_length]`
            (queries come from the decoder sequence, keys/values from the
            encoder sequence).
        self_attention_cache: a dense float Tensor. The cache of key/value
            pairs in the self-attention layer. Has shape
            `[batch_size, 2, max_seq_len, num_heads, key_dims]`.
        self_attention_cache_update_index: an int or int Tensor, the index
            at which to update the `self_attention_cache`. Usually, this is
            the index of the current token being processed during decoding.
        cross_attention_cache: a dense float Tensor. The cache of
            key/value pairs in the cross-attention layer. Has shape
            `[batch_size, 2, S, num_heads, key_dims]`.
        cross_attention_cache_update_index: an int or int Tensor, the index
            at which to update the `cross_attention_cache`. Usually, this is
            either `0` (compute the entire `cross_attention_cache`), or
            `None` (reuse a previously computed `cross_attention_cache`).
        use_causal_mask: bool, defaults to `True`. If true, a causal mask
            (masking out future input) is applied on the decoder sequence.
        training: a boolean indicating whether the layer should behave in
            training mode or in inference mode.

    Returns:
        One of three things, depending on call arguments:
        - `outputs`, if `self_attention_cache` is `None`.
        - `(outputs, self_attention_cache)`, if `self_attention_cache` is
          set and the layer has no cross-attention.
        - `(outputs, self_attention_cache, cross_attention_cache)`, if
          `self_attention_cache` and `cross_attention_cache` are set and
          the layer has cross-attention.

    Raises:
        ValueError: if `encoder_sequence` is passed to a layer built
            without cross-attention, or omitted for a layer built with it.
        ValueError: if the layer has cross-attention and exactly one of
            `self_attention_cache` / `cross_attention_cache` is set.
    """

    has_encoder_sequence = encoder_sequence is not None

    # Cross-attention exists only if `build` created the sublayer; the
    # call signature must agree with how the layer was built.
    has_cross_attention = self._cross_attention_layer is not None
    if not has_cross_attention and has_encoder_sequence:
        raise ValueError(
            "The number of call arguments to "
            "`keras_hub.layers.TransformerDecoder` should not change. "
            "Use `layer(decoder_sequence, encoder_sequence)` to "
            "build a layer with cross attention, or "
            "`layer(decoder_sequence)` to build a layer without. "
            "This layer has been built without cross attention, but "
            "you are trying to call it with encoder_sequence."
        )
    elif has_cross_attention and not has_encoder_sequence:
        raise ValueError(
            "The number of call arguments to "
            "`keras_hub.layers.TransformerDecoder` should not change. "
            "Use `layer(decoder_sequence, encoder_sequence)` to "
            "build a layer with cross attention, or "
            "`layer(decoder_sequence)` to build a layer without. "
            "This layer has been built with cross attention, but "
            "you did not provide encoder_sequence."
        )

    # With cross-attention, cached decoding needs both caches together.
    has_self_attention_cache = self_attention_cache is not None
    has_cross_attention_cache = cross_attention_cache is not None
    if has_cross_attention and (
        has_self_attention_cache != has_cross_attention_cache
    ):
        raise ValueError(
            "When calling `keras_hub.layers.TransformerDecoder` with "
            "cross-attention (with both `encoder_sequence` and "
            "`decoder_sequence`), `self_attention_cache` and "
            "`cross_attention_cache` should both be set or both be `None`. "
            "One cannot be `None` while the other is not. Received: "
            f"self_attention_cache={self_attention_cache}, "
            f"cross_attention_cache={cross_attention_cache}."
        )

    self_attention_mask = self._compute_self_attention_mask(
        decoder_sequence=decoder_sequence,
        decoder_padding_mask=decoder_padding_mask,
        decoder_attention_mask=decoder_attention_mask,
        use_causal_mask=use_causal_mask,
        self_attention_cache=self_attention_cache,
        self_attention_cache_update_index=self_attention_cache_update_index,
    )

    x = decoder_sequence  # Intermediate result.

    # Self attention block.
    residual = x
    if self.normalize_first:
        x = self._self_attention_layer_norm(x)
    attention_output = self._self_attention_layer(
        query=x,
        value=x,
        attention_mask=self_attention_mask,
        cache=self_attention_cache,
        cache_update_index=self_attention_cache_update_index,
        training=training,
    )
    # When a cache is passed, the attention layer returns
    # `(output, updated_cache)` instead of just the output.
    if self_attention_cache is None:
        x = attention_output
    else:
        x, self_attention_cache = attention_output
    x = self._self_attention_dropout(x, training=training)
    x = x + residual
    if not self.normalize_first:
        x = self._self_attention_layer_norm(x)

    # Cross attention is optional.
    if has_cross_attention:
        # Compute cross attention mask.
        cross_attention_mask = merge_padding_and_attention_mask(
            encoder_sequence, encoder_padding_mask, encoder_attention_mask
        )

        # Cross attention block.
        residual = x
        if self.normalize_first:
            x = self._cross_attention_layer_norm(x)
        attention_output = self._cross_attention_layer(
            query=x,
            value=encoder_sequence,
            attention_mask=cross_attention_mask,
            cache=cross_attention_cache,
            cache_update_index=cross_attention_cache_update_index,
            training=training,
        )
        if cross_attention_cache is None:
            x = attention_output
        else:
            x, cross_attention_cache = attention_output
        x = self._cross_attention_dropout(x, training=training)
        x = x + residual
        if not self.normalize_first:
            x = self._cross_attention_layer_norm(x)

    # Feedforward block.
    residual = x
    if self.normalize_first:
        x = self._feedforward_layer_norm(x)
    x = self._feedforward_intermediate_dense(x)
    x = self._feedforward_output_dense(x)
    x = self._feedforward_dropout(x, training=training)
    x = x + residual
    if not self.normalize_first:
        x = self._feedforward_layer_norm(x)

    # Return shape depends on whether caching is active (see Returns).
    if self_attention_cache is not None:
        if has_cross_attention:
            return (x, self_attention_cache, cross_attention_cache)
        else:
            return (x, self_attention_cache)
    else:
        return x
|
434
|
+
|
435
|
+
def _compute_self_attention_mask(
    self,
    decoder_sequence,
    decoder_padding_mask,
    decoder_attention_mask,
    use_causal_mask,
    self_attention_cache,
    self_attention_cache_update_index,
):
    """Build the attention mask used by the decoder self-attention block.

    The padding and custom attention masks are first merged with
    `merge_padding_and_attention_mask`. When `use_causal_mask` is false,
    that merged mask is returned as-is. Otherwise a causal mask is built
    and combined with the merged mask by elementwise minimum. If no merged
    mask exists, the causal mask alone is returned.

    During cached decoding the causal mask is rectangular. Queries span
    the current `decoder_sequence` length. Keys span the cache length,
    `ops.shape(self_attention_cache)[2]`. Query positions are offset by
    `self_attention_cache_update_index`, which is treated as `0` when
    `None`.
    """
    decoder_mask = merge_padding_and_attention_mask(
        decoder_sequence, decoder_padding_mask, decoder_attention_mask
    )
    if not use_causal_mask:
        return decoder_mask

    sequence_shape = ops.shape(decoder_sequence)
    batch_size = sequence_shape[0]
    query_length = sequence_shape[1]

    # For generative inference `decoder_sequence` is typically a single
    # step while the cache holds the full generation length, so the key
    # axis must come from the cache.
    if self_attention_cache is None:
        key_length = query_length
    else:
        key_length = ops.shape(self_attention_cache)[2]

    if self_attention_cache_update_index is None:
        start_index = 0
    else:
        start_index = self_attention_cache_update_index

    causal_mask = compute_causal_mask(
        batch_size, key_length, query_length, start_index
    )
    if decoder_mask is None:
        return causal_mask
    return ops.minimum(decoder_mask, causal_mask)
|
472
|
+
|
473
|
+
def get_config(self):
    """Return the serializable configuration of this decoder layer.

    Extends the base layer config with the constructor hyperparameters.
    The activation and initializers are stored in Keras-serialized form.
    The config also records the decoder and encoder sequence shapes the
    layer was built with.
    """
    config = super().get_config()
    decoder_config = {
        "intermediate_dim": self.intermediate_dim,
        "num_heads": self.num_heads,
        "dropout": self.dropout,
        "activation": keras.activations.serialize(self.activation),
        "layer_norm_epsilon": self.layer_norm_epsilon,
        "kernel_initializer": keras.initializers.serialize(
            self.kernel_initializer
        ),
        "bias_initializer": keras.initializers.serialize(
            self.bias_initializer
        ),
        "normalize_first": self.normalize_first,
        "decoder_sequence_shape": self._decoder_sequence_shape,
        "encoder_sequence_shape": self._encoder_sequence_shape,
    }
    config.update(decoder_config)
    return config
|
494
|
+
|
495
|
+
def compute_output_shape(self, decoder_sequence_shape):
    """Return the output shape for a given decoder input shape.

    The decoder output has the same shape as `decoder_sequence`, so the
    input shape is returned unchanged (the same object is passed through).
    """
    output_shape = decoder_sequence_shape
    return output_shape
|