keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/xlnet/relative_attention.py
@@ -0,0 +1,459 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import string

import keras
from keras import ops

_CHR_IDX = string.ascii_lowercase


def _build_proj_equation(free_dims, bound_dims, output_dims):
    """
    Builds an einsum equation for projections inside multi-head attention.
    """
    input_str = ""
    kernel_str = ""
    output_str = ""
    bias_axes = ""
    letter_offset = 0
    for i in range(free_dims):
        char = _CHR_IDX[i + letter_offset]
        input_str += char
        output_str += char

    letter_offset += free_dims
    for i in range(bound_dims):
        char = _CHR_IDX[i + letter_offset]
        input_str += char
        kernel_str += char

    letter_offset += bound_dims
    for i in range(output_dims):
        char = _CHR_IDX[i + letter_offset]
        kernel_str += char
        output_str += char
        bias_axes += char
    equation = "%s,%s->%s" % (input_str, kernel_str, output_str)

    return equation, bias_axes, len(output_str)
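
# [Illustrative note, not part of the packaged file] For example,
# `_build_proj_equation(free_dims=2, bound_dims=1, output_dims=2)` returns
# `("abc,cde->abde", "de", 4)`: `a`/`b` index the free batch and sequence
# dims, `c` is the contracted input feature dim, and `d`/`e` become the
# `(num_heads, key_dim)` output dims of the projection.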


def _get_output_shape(output_rank, known_last_dims):
    return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


def _rel_shift(x, klen=-1):
    """
    Performs relative shift to form the relative attention score.
    """

    x = ops.transpose(x, [2, 3, 0, 1])
    x_size = ops.shape(x)
    x = ops.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
    x = ops.slice(
        x, [1, 0, 0, 0], [x_size[1] - 1, x_size[0], x_size[2], x_size[3]]
    )
    x = ops.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
    x = ops.slice(x, [0, 0, 0, 0], [x_size[0], klen, x_size[2], x_size[3]])

    x = ops.transpose(x, [2, 3, 0, 1])

    return x
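
# [Illustrative note, not part of the packaged file] `_rel_shift` is the
# Transformer-XL "relative shift" trick: rather than gathering a `[T, T]`
# matrix of relative-position scores, it reshapes the `[..., T, L]` score
# tensor so that dropping the first slice along `L` and reshaping back
# shifts row `i` by `i`, aligning relative-distance scores with absolute
# key positions before truncating to `klen` (which requires `L >= klen + 1`).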


class TwoStreamRelativeAttention(keras.layers.MultiHeadAttention):
    """Two-stream relative self-attention for XLNet.

    In XLNet, each token has two associated vectors at each self-attention
    layer, the content stream (h) and the query stream (g). The content
    stream is the self-attention stream as in Transformer XL and represents
    the context and content (the token itself). The query stream only has
    access to contextual information and the position, but not the content.

    This layer shares the same build signature as
    `keras.layers.MultiHeadAttention` but has different input/output
    projections.

    We use the notations `B`, `T`, `S`, `M`, `L`, `E`, `P`, `dim`,
    `num_heads` below, where `B` is the batch dimension, `T` is the target
    sequence length, `S` is the source sequence length, `M` is the length
    of the state or memory, `L` is the length of relative positional
    encoding, `E` is the last dimension of query input, `P` is the number
    of predictions, `dim` is the dimensionality of the encoder layers, and
    `num_heads` is the number of attention heads.

    Args:
        content_stream: `Tensor` of shape `[B, T, dim]`.
        content_attention_bias: Bias `Tensor` for content based attention
            of shape `[num_heads, dim]`.
        positional_attention_bias: Bias `Tensor` for position based
            attention of shape `[num_heads, dim]`.
        query_stream: `Tensor` of shape `[B, P, dim]`.
        target_mapping: `Tensor` of shape `[B, P, S]`.
        relative_position_encoding: Relative positional encoding `Tensor`
            of shape `[B, L, dim]`.
        segment_matrix: Optional `Tensor` representing segmentation IDs
            used in XLNet of shape `[B, S, S + M]`.
        segment_encoding: Optional `Tensor` representing the segmentation
            encoding as used in XLNet of shape `[2, num_heads, dim]`.
        segment_attention_bias: Optional trainable bias parameter added to
            the query head when calculating the segment-based attention
            score used in XLNet of shape `[num_heads, dim]`.
        state: Optional `Tensor` of shape `[B, M, E]`.
            If passed, this is also attended over as in Transformer XL.
        content_attention_mask: a boolean mask of shape `[B, T, S]` that
            prevents attention to certain positions for content attention
            computation.
        query_attention_mask: a boolean mask of shape `[B, T, S]` that
            prevents attention to certain positions for query attention
            computation.
    """

    def __init__(self, kernel_initializer="glorot_uniform", **kwargs):
        super().__init__(kernel_initializer=kernel_initializer, **kwargs)

    def _get_common_kwargs_for_sublayer(self):
        common_kwargs = dict(
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
        )
        return common_kwargs

    def build(self, content_stream_shape):
        self._use_bias = False

        self._query_shape = content_stream_shape
        self._key_shape = content_stream_shape
        self._value_shape = content_stream_shape

        free_dims = len(self._query_shape) - 1
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims, bound_dims=1, output_dims=2
        )
        self._query_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            dtype=self.dtype_policy,
            name="query",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._query_dense.build(self._query_shape)

        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            len(self._key_shape) - 1, bound_dims=1, output_dims=2
        )
        self._key_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            dtype=self.dtype_policy,
            name="key",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._key_dense.build(self._key_shape)

        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            len(self._value_shape) - 1, bound_dims=1, output_dims=2
        )
        self._value_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._value_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            dtype=self.dtype_policy,
            name="value",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._value_dense.build(self._value_shape)

        free_dims = len(self._query_shape) - 1
        _, _, output_rank = _build_proj_equation(
            free_dims, bound_dims=2, output_dims=1
        )
        self._output_dense = keras.layers.EinsumDense(
            "ibnd,hnd->ibh",
            output_shape=_get_output_shape(
                output_rank - 1, [self._query_shape[-1]]
            ),
            bias_axes=None,
            dtype=self.dtype_policy,
            name="attention_output",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._output_dense.build(
            self._value_dense.compute_output_shape(self._value_dim)
        )

        einsum_equation, _, output_rank = _build_proj_equation(
            len(self._key_shape) - 1, bound_dims=1, output_dims=2
        )
        self._encoding_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=None,
            dtype=self.dtype_policy,
            name="encoding",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._encoding_dense.build(self._key_shape)

        self._build_attention(output_rank)
        self.built = True

    def compute_attention(
        self,
        query,
        key,
        value,
        position,
        content_attention_bias,
        positional_attention_bias,
        segment_matrix=None,
        segment_encoding=None,
        segment_attention_bias=None,
        attention_mask=None,
    ):
        """Computes the attention.

        This function defines the computation inside `call` with projected
        multihead Q, K, V, R inputs.

        We use the notations `B`, `T`, `S`, `M`, `L`, `num_heads`,
        `key_dim` below, where `B` is the batch dimension, `T` is the
        target sequence length, `S` is the source sequence length, `M` is
        the length of the state, `L` is the length of relative positional
        encoding, `num_heads` is the number of attention heads and
        `key_dim` is the size of each attention head for query and key.

        Args:
            query: Projected query `Tensor` of shape
                `[B, T, num_heads, key_dim]`.
            key: Projected key `Tensor` of shape
                `[B, S + M, num_heads, key_dim]`.
            value: Projected value `Tensor` of shape
                `[B, S + M, num_heads, key_dim]`.
            position: Projected position `Tensor` of shape
                `[B, L, num_heads, key_dim]`.
            content_attention_bias: Trainable bias parameter added to the
                query head when calculating the content-based attention
                score.
            positional_attention_bias: Trainable bias parameter added to
                the query head when calculating the position-based
                attention score.
            segment_matrix: Optional `Tensor` representing segmentation IDs
                used in XLNet.
            segment_encoding: Optional trainable `Tensor` representing the
                segmentation encoding as used in XLNet.
            segment_attention_bias: Optional trainable bias parameter added
                to the query head when calculating the segment-based
                attention score used in XLNet.
            attention_mask: (default None) Optional mask that is added to
                attention logits. If state is not None, the mask source
                sequence dimension should be extended by `M`.

        Returns:
            attention_output: Multi-headed output of attention computation
                of shape `[B, S, num_heads, key_dim]`.
        """
        content_attention = ops.einsum(
            self._dot_product_equation, key, query + content_attention_bias
        )
        positional_attention = ops.einsum(
            self._dot_product_equation,
            position,
            query + positional_attention_bias,
        )
        positional_attention = _rel_shift(
            positional_attention, klen=ops.shape(content_attention)[3]
        )

        if segment_matrix is not None:
            segment_attention = ops.einsum(
                "bind,snd->bnis",
                query + segment_attention_bias,
                segment_encoding,
            )
            target_shape = ops.shape(positional_attention)
            segment_attention = ops.where(
                ops.broadcast_to(
                    ops.expand_dims(segment_matrix, 1), target_shape
                ),
                ops.broadcast_to(segment_attention[:, :, :, 1:], target_shape),
                ops.broadcast_to(segment_attention[:, :, :, :1], target_shape),
            )
            attention_sum = (
                content_attention + positional_attention + segment_attention
            )
        else:
            attention_sum = content_attention + positional_attention

        attention_scores = ops.multiply(
            attention_sum, 1.0 / math.sqrt(float(self._key_dim))
        )

        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )

        attention_output = self._dropout_layer(attention_scores)

        attention_output = ops.einsum(
            self._combine_equation, attention_output, value
        )

        return attention_output
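
    # [Illustrative note, not part of the packaged file] Following
    # Transformer-XL, `compute_attention` evaluates a decomposed score for
    # query position i and key position j:
    #
    #   score(i, j) = (q_i + u) . k_j        content term; bias `u` is
    #                                        `content_attention_bias`
    #               + (q_i + v) . r_{i-j}    position term (realigned by
    #                                        `_rel_shift`); bias `v` is
    #                                        `positional_attention_bias`
    #               + (q_i + s) . seg_{ij}   optional segment term
    #
    # The sum is then scaled by 1/sqrt(key_dim) and passed through the
    # masked softmax, as in the XLNet paper.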

    def call(
        self,
        content_stream,
        content_attention_bias,
        positional_attention_bias,
        relative_position_encoding,
        query_stream=None,
        target_mapping=None,
        segment_matrix=None,
        segment_encoding=None,
        segment_attention_bias=None,
        state=None,
        content_attention_mask=None,
        query_attention_mask=None,
    ):
        """Compute multi-head relative attention over inputs.

        We use the notations `B`, `T`, `M`, `E` below, where `B` is the
        batch dimension, `T` is the target sequence length, `M` is the
        length of the state or memory and `E` is the last dimension of the
        query input.

        Args:
            content_stream: The content representation, commonly referred
                to as h. This serves a similar role to the standard hidden
                states in Transformer-XL.
            content_attention_bias: A trainable bias parameter added to the
                query head when calculating the content-based attention
                score.
            positional_attention_bias: A trainable bias parameter added to
                the query head when calculating the position-based
                attention score.
            query_stream: The query representation, commonly referred to as
                g. This only has access to contextual information and
                position, but not content. If not provided, then this is
                MultiHeadRelativeAttention with self-attention.
            relative_position_encoding: relative positional encoding for
                key and value.
            target_mapping: Optional `Tensor` representing the target
                mapping used in partial prediction.
            segment_matrix: Optional `Tensor` representing segmentation IDs
                used in XLNet.
            segment_encoding: Optional `Tensor` representing the
                segmentation encoding as used in XLNet.
            segment_attention_bias: Optional trainable bias parameter added
                to the query head when calculating the segment-based
                attention score.
            state: (default None) optional state. If passed, this is also
                attended over as in TransformerXL and XLNet.
            content_attention_mask: (default None) Optional mask that is
                added to content attention logits. If state is not None,
                the mask source sequence dimension should be extended by
                `M`.
            query_attention_mask: (default None) Optional mask that is
                added to query attention logits. If state is not None, the
                mask source sequence dimension should be extended by `M`.

        Returns:
            content_attention_output, query_attention_output: the results
                of the computation, both of shape `[B, T, E]`.
        """

        if state is not None and len(state.shape) > 1:
            content_and_memory_stream = ops.concatenate(
                [state, content_stream], 1
            )
        else:
            content_and_memory_stream = content_stream

        # `query` = [B, T, N, H]
        query = self._query_dense(content_stream)

        # `key` = [B, S + M, N, H]
        key = self._key_dense(content_and_memory_stream)

        # `value` = [B, S + M, N, H]
        value = self._value_dense(content_and_memory_stream)

        # `position` = [B, L, N, H]
        position = self._encoding_dense(relative_position_encoding)

        content_attention_output = self.compute_attention(
            query=query,
            key=key,
            value=value,
            position=position,
            content_attention_bias=content_attention_bias,
            positional_attention_bias=positional_attention_bias,
            segment_matrix=segment_matrix,
            segment_encoding=segment_encoding,
            segment_attention_bias=segment_attention_bias,
            attention_mask=content_attention_mask,
        )

        # `content_attention_output` = [B, S, N, H]
        content_attention_output = self._output_dense(content_attention_output)

        query_attention_output = None
        if query_stream is not None:
            query = self._query_dense(query_stream)
            if target_mapping is not None:
                query = ops.einsum("bmnd,bml->blnd", query, target_mapping)
                query_attention_output = self.compute_attention(
                    query=query,
                    key=key,
                    value=value,
                    position=position,
                    content_attention_bias=content_attention_bias,
                    positional_attention_bias=positional_attention_bias,
                    segment_matrix=segment_matrix,
                    segment_encoding=segment_encoding,
                    segment_attention_bias=segment_attention_bias,
                    attention_mask=query_attention_mask,
                )
                query_attention_output = ops.einsum(
                    "blnd,bml->bmnd", query_attention_output, target_mapping
                )
            else:
                query_attention_output = self.compute_attention(
                    query=query,
                    key=key,
                    value=value,
                    position=position,
                    content_attention_bias=content_attention_bias,
                    positional_attention_bias=positional_attention_bias,
                    segment_matrix=segment_matrix,
                    segment_encoding=segment_encoding,
                    segment_attention_bias=segment_attention_bias,
                    attention_mask=query_attention_mask,
                )
            query_attention_output = self._output_dense(query_attention_output)

        return content_attention_output, query_attention_output
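
The layer above can be exercised standalone. The following is a minimal, hypothetical smoke test (not part of the package): the shapes follow the class docstring, the two biases are the `[num_heads, key_dim]` parameters described there, and the relative encoding length is chosen to satisfy the `L >= S + 1` constraint imposed by `_rel_shift`.

import numpy as np

from keras_hub.src.models.xlnet.relative_attention import (
    TwoStreamRelativeAttention,
)

layer = TwoStreamRelativeAttention(num_heads=2, key_dim=8)
batch, seq_len, hidden = 1, 4, 16
content = np.zeros((batch, seq_len, hidden), dtype="float32")
# The relative encoding length L must satisfy L >= S + 1 for `_rel_shift`;
# 2 * seq_len is a safe choice here.
rel_enc = np.zeros((batch, 2 * seq_len, hidden), dtype="float32")
# Per the docstring, both biases have shape [num_heads, key_dim].
content_bias = np.zeros((2, 8), dtype="float32")
position_bias = np.zeros((2, 8), dtype="float32")

content_out, query_out = layer(
    content_stream=content,
    content_attention_bias=content_bias,
    positional_attention_bias=position_bias,
    relative_position_encoding=rel_enc,
)
# With no query stream, `query_out` is None and `content_out` has shape
# (batch, seq_len, hidden).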

keras_hub/src/models/xlnet/xlnet_backbone.py
@@ -0,0 +1,222 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.xlnet.xlnet_content_and_query_embedding import (
    ContentAndQueryEmbedding,
)
from keras_hub.src.models.xlnet.xlnet_encoder import XLNetAttentionMaskLayer
from keras_hub.src.models.xlnet.xlnet_encoder import XLNetEncoder
from keras_hub.src.models.xlnet.xlnet_encoder import XLNetSegmentMatrixLayer


@keras_hub_export("keras_hub.models.XLNetBackbone")
class XLNetBackbone(Backbone):
    """XLNet encoder network.

    This class implements an XLNet Transformer.

    The default constructor gives a fully customizable, randomly
    initialized XLNet encoder with any number of layers, heads, and
    embedding dimensions. To load preset architectures and weights, use the
    `from_preset` constructor.

    Disclaimer: Pre-trained models are provided on an "as is" basis,
    without warranties or conditions of any kind.

    Attributes:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer encoder layers.
        num_heads: int. The number of heads in the
            `TwoStreamRelativeAttention` layers.
        hidden_dim: int. The size of the hidden states.
        intermediate_dim: int. The hidden size of the feedforward network.
        dropout: float. Defaults to `0.0`. The dropout value, shared by the
            `TwoStreamRelativeAttention` layers and the feedforward
            network.
        activation: string or `keras.activations`. Defaults to `"gelu"`.
            The activation function of the feedforward network.
        kernel_initializer_range: float. Defaults to `0.02`. The kernel
            initializer range for the dense and relative attention layers.
        bias_initializer: string or `keras.initializers` initializer.
            Defaults to `"zeros"`. The bias initializer for the dense and
            multiheaded relative attention layers.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to
            use for model computations and weights. Note that some
            computations, such as softmax and layer normalization, will
            always be done at float32 precision regardless of dtype.

    Call arguments:
        token_ids: Indices of input sequence tokens in the vocabulary of
            shape `[batch_size, sequence_length]`.
        segment_ids: Segment token indices to indicate first and second
            portions of the inputs of shape
            `[batch_size, sequence_length]`.
        padding_mask: Mask to avoid performing attention on padding token
            indices of shape `[batch_size, sequence_length]`.

    Example:
    ```python
    import numpy as np
    import keras_hub

    input_data = {
        "token_ids": np.array([[460, 5272, 1758, 4905, 9, 4, 3]]),
        "segment_ids": np.array([[0, 0, 0, 0, 0, 0, 2]]),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1]]),
    }

    # Randomly initialized XLNet encoder with a custom config
    model = keras_hub.models.XLNetBackbone(
        vocabulary_size=32000,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        intermediate_dim=3072,
    )
    output = model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        dropout=0.0,
        activation="gelu",
        kernel_initializer_range=0.02,
        bias_initializer="zeros",
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.content_query_embedding = ContentAndQueryEmbedding(
            vocabulary_size=vocabulary_size,
            hidden_dim=hidden_dim,
            dropout=dropout,
            dtype=dtype,
            name="content_query_embedding",
        )
        self.attn_mask_layer = XLNetAttentionMaskLayer(
            hidden_dim=hidden_dim,
            kernel_initializer_range=kernel_initializer_range,
            dtype=dtype,
            name="encoder_block_attn_mask_layer",
        )
        self.seg_mat_layer = XLNetSegmentMatrixLayer(
            dtype=dtype,
            name="encoder_block_seg_mat_layer",
        )
        head_dim = hidden_dim // num_heads
        self.transformer_layers = []
        for i in range(num_layers):
            layer = XLNetEncoder(
                num_heads=num_heads,
                hidden_dim=hidden_dim,
                head_dim=head_dim,
                intermediate_dim=intermediate_dim,
                dropout=dropout,
                activation=activation,
                layer_norm_epsilon=1e-12,
                kernel_initializer_range=kernel_initializer_range,
                bias_initializer=bias_initializer,
                dtype=dtype,
                name=f"xlnet_encoder_{i}",
            )
            self.transformer_layers.append(layer)
        self.dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="dropout",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )
        segment_id_input = keras.Input(
            shape=(None,), dtype="int32", name="segment_ids"
        )
        # Content and Query Embedding
        word_emb, pos_emb = self.content_query_embedding(token_id_input)
        # Apply XLNetAttentionMaskLayer and XLNetSegmentMatrixLayer Layers
        # to get the processed attention masks and segment matrix.
        attn_mask_content, attn_mask_query = self.attn_mask_layer(
            padding_mask_input
        )
        seg_mat = self.seg_mat_layer(segment_id_input)
        output_content = word_emb
        for transformer_layer in self.transformer_layers:
            output_content, output_query = transformer_layer(
                output_content=output_content,
                attn_mask_content=attn_mask_content,
                attn_mask_query=attn_mask_query,
                pos_emb=pos_emb,
                seg_mat=seg_mat,
            )
        output = self.dropout(output_content)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
                "segment_ids": segment_id_input,
            },
            outputs=output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        # Note: this reassigns `self.dropout` from the `Dropout` layer
        # created above (already applied in the functional graph) to the
        # float rate, so that `get_config` serializes the rate.
        self.dropout = dropout
        self.activation = activation
        self.kernel_initializer_range = kernel_initializer_range
        self.bias_initializer = bias_initializer

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "activation": self.activation,
                "kernel_initializer_range": self.kernel_initializer_range,
                "bias_initializer": self.bias_initializer,
            }
        )
        return config

    @property
    def token_embedding(self):
        return self.get_layer("content_query_embedding").word_embed
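
As a quick sanity check of the backbone's config plumbing, a sketch along these lines (hypothetical, assuming this nightly wheel is installed with a working backend) exercises the `token_embedding` property and the `get_config` round-trip:

import keras_hub

model = keras_hub.models.XLNetBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
)
# `token_embedding` resolves to the word embedding inside the
# `content_query_embedding` layer, per the property above.
embedding = model.token_embedding
# `get_config` returns the nine constructor arguments registered in the
# `=== Config ===` section, so a round-trip rebuilds the same architecture.
restored = keras_hub.models.XLNetBackbone.from_config(model.get_config())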