keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
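The api/ directory appears to re-export the public symbols implemented under src/ (wired up through api_export.py). As a hedged sketch of how the wheel is consumed, the snippet below loads one of the classifiers listed above; the preset id follows KerasNLP conventions and is an assumption, not something recorded in this diff.

    import keras_hub

    # Hypothetical smoke test: `BertClassifier` is defined in
    # keras_hub/src/models/bert/bert_classifier.py above and assumed to be
    # re-exported under keras_hub.models; the preset id is an assumption.
    classifier = keras_hub.models.BertClassifier.from_preset(
        "bert_tiny_en_uncased",
        num_classes=2,
    )
    predictions = classifier.predict(["The wheel imports and runs."])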
keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py
@@ -0,0 +1,133 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import ops


class ContentAndQueryEmbedding(keras.layers.Layer):
    """
    Content and Query Embedding.

    This class creates the content and query embeddings for the XLNet model,
    which are later used in the XLNet encoder.

    Args:
        vocabulary_size: int, number of tokens in the vocabulary.
        hidden_dim: int, the size of the hidden states.
        dropout: float, defaults to 0. The dropout rate applied to the word
            embeddings and the positional encodings.
        name: string, defaults to None. The name of the layer.
        **kwargs: other keyword arguments.

    References:
     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
     (https://arxiv.org/abs/1906.08237)
    """

    def __init__(
        self, vocabulary_size, hidden_dim, dropout, name=None, **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.dropout = dropout

    def positional_embedding(self, pos_seq, inv_freq, bsz=None):
        sinusoid_inp = ops.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = ops.concatenate(
            [ops.sin(sinusoid_inp), ops.cos(sinusoid_inp)], axis=-1
        )
        pos_emb = ops.expand_dims(pos_emb, 1)
        pos_emb = (
            ops.ones(
                [
                    ops.shape(pos_emb)[0],
                    ops.shape(pos_emb)[1] * bsz,
                    ops.shape(pos_emb)[2],
                ],
                dtype=self.compute_dtype,
            )
            * pos_emb
        )

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None, clamp_len=-1):
        """Create the relative positional encoding."""
        freq_seq = ops.arange(0, self.hidden_dim, 2.0, dtype="float32")
        freq_seq = ops.cast(freq_seq, self.compute_dtype)
        inv_freq = 1 / (10000 ** (freq_seq / self.hidden_dim))

        beg, end = klen, -qlen

        fwd_pos_seq = ops.arange(beg, end, -1.0, dtype="float32")
        fwd_pos_seq = ops.cast(fwd_pos_seq, self.compute_dtype)
        if clamp_len > 0:
            fwd_pos_seq = ops.clip(
                fwd_pos_seq, x_min=-clamp_len, x_max=clamp_len
            )
        pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        return pos_emb

    def build(self, input_shape):
        self.word_embed = keras.layers.Embedding(
            input_dim=self.vocabulary_size,
            output_dim=self.hidden_dim,
            dtype=self.dtype_policy,
            name="word_embedding",
        )
        self.word_embed.build(input_shape)
        self.dropout_layer = keras.layers.Dropout(
            self.dropout,
            dtype=self.dtype_policy,
        )
        super().build(input_shape)

    def call(
        self,
        token_id_input,
        mlen=None,
    ):
        mlen = 0 if mlen is None else mlen

        bsz, qlen = ops.shape(token_id_input)[0], ops.shape(token_id_input)[1]
        klen = mlen + qlen

        # Word embeddings and prepare h & g hidden states.
        word_emb = self.word_embed(token_id_input)
        word_emb = self.dropout_layer(word_emb)

        # Positional encoding.
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout_layer(pos_emb)
        pos_emb = ops.reshape(
            pos_emb,
            [
                ops.shape(pos_emb)[1],
                ops.shape(pos_emb)[0],
                ops.shape(pos_emb)[2],
            ],
        )

        return word_emb, pos_emb

    def compute_output_shape(self, token_id_input_shape):
        return [
            token_id_input_shape + (self.hidden_dim,),
            (token_id_input_shape[0], 1, self.hidden_dim),
        ]
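For orientation, here is a minimal sketch of calling this layer on a batch of token ids. The sizes are made up for illustration, and the internal keras_hub.src import path is taken from the file listing above; as a private module it may change between nightlies.

    from keras import ops

    from keras_hub.src.models.xlnet.xlnet_content_and_query_embedding import (
        ContentAndQueryEmbedding,
    )

    # Illustrative sizes: vocab of 1000, hidden dim of 64.
    embedding = ContentAndQueryEmbedding(
        vocabulary_size=1000, hidden_dim=64, dropout=0.0
    )
    token_ids = ops.ones((2, 8), dtype="int32")  # (batch, sequence)
    word_emb, pos_emb = embedding(token_ids)
    # word_emb: (2, 8, 64) token embeddings (the content stream).
    # pos_emb: (2, 16, 64) relative positional encodings, spanning the
    # 2 * sequence_length relative positions from klen down to -qlen.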
keras_hub/src/models/xlnet/xlnet_encoder.py
@@ -0,0 +1,378 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import ops

from keras_hub.src.models.xlnet.relative_attention import (
    TwoStreamRelativeAttention,
)


def xlnet_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)


class XLNetEncoder(keras.layers.Layer):
    """
    XLNet Encoder.

    This class follows the architecture of the transformer encoder layer in
    the paper [Attention is All You Need](https://arxiv.org/abs/1706.03762).
    Users can instantiate multiple instances of this class to stack up an
    encoder.

    In contrast to the single hidden state used in the paper above, this
    encoder maintains two hidden states, the content state and the query
    state, and computes two-stream relative attention over both. See the
    reference for more details.

    Args:
        num_heads: int, the number of heads in the
            `keras.layers.TwoStreamRelativeAttention` layer.
        hidden_dim: int, the size of the hidden states.
        head_dim: int, the size of each attention head.
        intermediate_dim: int, the hidden size of the feedforward network.
        dropout: float, defaults to 0.0. The dropout value, shared by
            `keras.layers.TwoStreamRelativeAttention` and the feedforward
            network.
        activation: string or `keras.activations`, defaults to "gelu". The
            activation function of the feedforward network.
        layer_norm_epsilon: float, defaults to 1e-12. The epsilon value in
            layer normalization components.
        kernel_initializer_range: int, defaults to 0.02. The kernel
            initializer range for the dense and relative attention layers.
        bias_initializer: string or `keras.initializers` initializer,
            defaults to "zeros". The bias initializer for the dense and
            multi-headed relative attention layers.
        name: string, defaults to None. The name of the layer.
        **kwargs: other keyword arguments.

    References:
     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
     (https://arxiv.org/abs/1906.08237)
    """

    def __init__(
        self,
        num_heads,
        hidden_dim,
        head_dim,
        intermediate_dim,
        dropout=0.0,
        activation="gelu",
        layer_norm_epsilon=1e-12,
        kernel_initializer_range=0.02,
        bias_initializer="zeros",
        name=None,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = head_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer_range = kernel_initializer_range
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_initializer = xlnet_kernel_initializer(
            self.kernel_initializer_range
        )

    def build(self, input_shape):
        # Attention part.
        self.relative_attention = TwoStreamRelativeAttention(
            num_heads=self.num_heads,
            key_dim=self.head_dim,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            dtype=self.dtype_policy,
            name="rel_attn",
        )
        self.relative_attention.build(input_shape)

        self.layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm_rel_attn",
        )
        self.layer_norm.build(input_shape)

        self.dropout_attn = keras.layers.Dropout(
            self.dropout,
            dtype=self.dtype_policy,
        )

        # Feed-forward part.
        self.layer_norm_ff = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm_ff",
        )
        self.layer_norm_ff.build(input_shape)

        self.feedforward_intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="feedforward_intermediate_dense",
        )
        self.feedforward_intermediate_dense.build(input_shape)

        self.feedforward_output_dense = keras.layers.Dense(
            self.hidden_dim,
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="feedforward_output_dense",
        )
        self.feedforward_output_dense.build(
            self.feedforward_intermediate_dense.compute_output_shape(
                input_shape
            )
        )

        self.dropout_ff = keras.layers.Dropout(
            self.dropout,
            dtype=self.dtype_policy,
        )

        self.activation_function_ff = keras.activations.get(self.activation)

        self.content_attention_bias = self.add_weight(
            shape=(self.num_heads, self.head_dim),
            initializer=self.bias_initializer,
            trainable=True,
            name="content_attention_bias",
        )

        self.positional_attention_bias = self.add_weight(
            shape=(self.num_heads, self.head_dim),
            initializer=self.bias_initializer,
            trainable=True,
            name="positional_attention_bias",
        )

        self.segment_attention_bias = self.add_weight(
            shape=(self.num_heads, self.head_dim),
            initializer=self.bias_initializer,
            trainable=True,
            name="segment_attention_bias",
        )

        self.segment_encoding = self.add_weight(
            shape=(2, self.num_heads, self.head_dim),
            initializer=self.kernel_initializer,
            trainable=True,
            name="segment_encoding",
        )

        super().build(input_shape)

    def call(
        self,
        output_content,
        attn_mask_content,
        attn_mask_query,
        pos_emb,
        seg_mat,
        output_query=None,
        mems=None,
        target_mapping=None,
    ):
        # Relative attention.
        attn_out_content, attn_out_query = self.relative_attention(
            content_stream=output_content,
            query_stream=output_query,
            content_attention_mask=attn_mask_content,
            query_attention_mask=attn_mask_query,
            relative_position_encoding=pos_emb,
            content_attention_bias=self.content_attention_bias,
            positional_attention_bias=self.positional_attention_bias,
            segment_attention_bias=self.segment_attention_bias,
            segment_matrix=seg_mat,
            segment_encoding=self.segment_encoding,
            target_mapping=target_mapping,
            state=mems,
        )

        attn_out_content = self.dropout_attn(attn_out_content)
        attn_out_content = attn_out_content + output_content
        attn_out_content = self.layer_norm(attn_out_content)

        if attn_out_query is not None:
            attn_out_query = self.dropout_attn(attn_out_query)
            attn_out_query = attn_out_query + output_query
            attn_out_query = self.layer_norm(attn_out_query)

        # Feed-forward.
        ff_out_content = attn_out_content
        ff_out_content = self.feedforward_intermediate_dense(ff_out_content)
        ff_out_content = self.activation_function_ff(ff_out_content)
        ff_out_content = self.dropout_ff(ff_out_content)
        ff_out_content = self.feedforward_output_dense(ff_out_content)
        ff_out_content = self.dropout_ff(ff_out_content)
        ff_out_content = self.layer_norm_ff(ff_out_content + attn_out_content)

        if attn_out_query is not None:
            ff_out_query = attn_out_query
            ff_out_query = self.feedforward_intermediate_dense(ff_out_query)
            ff_out_query = self.activation_function_ff(ff_out_query)
            ff_out_query = self.dropout_ff(ff_out_query)
            ff_out_query = self.feedforward_output_dense(ff_out_query)
            ff_out_query = self.dropout_ff(ff_out_query)
            ff_out_query = self.layer_norm_ff(ff_out_query + attn_out_query)

            return ff_out_content, ff_out_query

        return ff_out_content, None

    def compute_output_shape(
        self,
        output_content_shape,
        pos_emb_shape,
        attn_mask_content_shape,
        attn_mask_query_shape,
        seg_mat_shape,
        output_query_shape=None,
    ):
        return [output_content_shape, output_content_shape]


class XLNetAttentionMaskLayer(keras.layers.Layer):
    """
    Attention Mask Layer for XLNet Encoder Block.

    This layer processes attention masks for both the content state and the
    query state during the forward pass.

    Args:
        hidden_dim: int, the size of the hidden states.
        kernel_initializer_range: int, defaults to 0.02. The kernel
            initializer range for the dense and relative attention layers.
        **kwargs: other keyword arguments.
    """

    def __init__(self, hidden_dim, kernel_initializer_range, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.kernel_initializer_range = kernel_initializer_range
        self.kernel_initializer = xlnet_kernel_initializer(
            self.kernel_initializer_range
        )

    def build(self, inputs_shape):
        self.mask_emb = self.add_weight(
            shape=(1, 1, self.hidden_dim),
            initializer=self.kernel_initializer,
            trainable=True,
            name="mask_emb",
        )
        self.built = True

    def call(self, inputs, mlen=None):
        bsz, qlen = ops.shape(inputs)[0], ops.shape(inputs)[1]
        mlen = 0 if mlen is None else mlen

        inputs = 1 - inputs
        inputs = ops.reshape(
            inputs,
            [ops.shape(inputs)[1], ops.shape(inputs)[0]],
        )

        data_mask = ops.expand_dims(inputs, 0)

        if mlen > 0:
            mems_mask = ops.zeros([ops.shape(data_mask)[0], mlen, bsz])
            data_mask = ops.concatenate(
                [ops.cast(mems_mask, dtype="int32"), data_mask], axis=1
            )
        attn_mask_query = ops.expand_dims(data_mask, -1)

        attn_mask_query = ops.cast(
            attn_mask_query > 0, dtype=attn_mask_query.dtype
        )

        # Since `ops.eye` doesn't support a tensorflow Tensor as input, we
        # build the diagonal mask manually here.
        n = ops.expand_dims(ops.arange(qlen), -1)
        m = ops.arange(qlen)
        attn_mask_content = -ops.cast(
            ops.where(n == m, 1, 0), attn_mask_query.dtype
        )

        if mlen > 0:
            attn_mask_content = ops.concatenate(
                [
                    ops.zeros([qlen, mlen], dtype=attn_mask_content.dtype),
                    attn_mask_content,
                ],
                axis=-1,
            )

        attn_mask_content = ops.cast(
            (
                attn_mask_query
                + ops.expand_dims(ops.expand_dims(attn_mask_content, -1), -1)
            )
            > 0,
            dtype=attn_mask_content.dtype,
        )

        # Make sure the inputs are suitable for TwoStreamRelativeAttention.
        attn_mask_content = 1.0 - ops.cast(
            ops.transpose(ops.squeeze(attn_mask_content, -1), [2, 0, 1]),
            "float32",
        )
        attn_mask_query = 1.0 - ops.cast(
            ops.transpose(ops.squeeze(attn_mask_query, -1), [2, 0, 1]),
            "float32",
        )

        return attn_mask_content, attn_mask_query

    def compute_output_shape(self, padding_mask_shape):
        return [padding_mask_shape, padding_mask_shape]


class XLNetSegmentMatrixLayer(keras.layers.Layer):
    """
    This layer creates the segment matrix for the XLNet encoder.
    """

    def call(self, segment_ids, mlen=None):
        bsz = ops.shape(segment_ids)[0]
        mlen = 0 if mlen is None else mlen

        # Prepare seg_mat.
        segment_ids = ops.transpose(segment_ids, [1, 0])

        if mlen > 0:
            mem_pad = ops.zeros([mlen, bsz], dtype=segment_ids.dtype)
            cat_ids = ops.concatenate([mem_pad, segment_ids], 0)
        else:
            cat_ids = segment_ids

        # `1` indicates not in the same segment [qlen x klen x bsz].
        seg_mat = ops.cast(
            ops.logical_not(ops.equal(segment_ids[:, None], cat_ids[None, :])),
            dtype=segment_ids.dtype,
        )

        # Make sure the inputs are suitable for TwoStreamRelativeAttention.
        seg_mat = ops.cast(ops.transpose(seg_mat, [2, 0, 1]), dtype="bool")

        return seg_mat

    def compute_output_shape(self, segment_ids_shape):
        return segment_ids_shape
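To see how these pieces compose, the sketch below mirrors the wiring in xlnet_backbone.py (listed above): the embedding layer produces the content stream and relative positional encodings, the mask and segment-matrix layers reshape the padding and segment inputs, and the encoder block consumes all of them. Shapes and sizes are illustrative, and the private keras_hub.src import paths come from the file listing; treat this as a hedged sketch, not canonical usage.

    from keras import ops

    from keras_hub.src.models.xlnet.xlnet_content_and_query_embedding import (
        ContentAndQueryEmbedding,
    )
    from keras_hub.src.models.xlnet.xlnet_encoder import (
        XLNetAttentionMaskLayer,
        XLNetEncoder,
        XLNetSegmentMatrixLayer,
    )

    batch, seq_len, hidden = 2, 8, 64
    token_ids = ops.ones((batch, seq_len), dtype="int32")
    padding_mask = ops.ones((batch, seq_len), dtype="int32")
    segment_ids = ops.zeros((batch, seq_len), dtype="int32")

    # Content stream plus relative positional encodings.
    word_emb, pos_emb = ContentAndQueryEmbedding(
        vocabulary_size=1000, hidden_dim=hidden, dropout=0.0
    )(token_ids)

    # Content/query attention masks and the segment matrix.
    attn_mask_content, attn_mask_query = XLNetAttentionMaskLayer(
        hidden_dim=hidden, kernel_initializer_range=0.02
    )(padding_mask)
    seg_mat = XLNetSegmentMatrixLayer()(segment_ids)

    # One encoder block; a full XLNet encoder stacks several of these.
    output_content, _ = XLNetEncoder(
        num_heads=2, hidden_dim=hidden, head_dim=32, intermediate_dim=128
    )(
        word_emb,
        attn_mask_content=attn_mask_content,
        attn_mask_query=attn_mask_query,
        pos_emb=pos_emb,
        seg_mat=seg_mat,
    )
    # output_content: (batch, seq_len, hidden)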
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.