keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,311 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import os
|
16
|
+
|
17
|
+
import keras
|
18
|
+
|
19
|
+
from keras_hub.src.api_export import keras_hub_export
|
20
|
+
from keras_hub.src.utils.keras_utils import assert_quantization_support
|
21
|
+
from keras_hub.src.utils.preset_utils import CONFIG_FILE
|
22
|
+
from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
|
23
|
+
from keras_hub.src.utils.preset_utils import check_config_class
|
24
|
+
from keras_hub.src.utils.preset_utils import check_format
|
25
|
+
from keras_hub.src.utils.preset_utils import get_file
|
26
|
+
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
|
27
|
+
from keras_hub.src.utils.preset_utils import list_presets
|
28
|
+
from keras_hub.src.utils.preset_utils import list_subclasses
|
29
|
+
from keras_hub.src.utils.preset_utils import load_serialized_object
|
30
|
+
from keras_hub.src.utils.preset_utils import save_metadata
|
31
|
+
from keras_hub.src.utils.preset_utils import save_serialized_object
|
32
|
+
from keras_hub.src.utils.python_utils import classproperty
|
33
|
+
from keras_hub.src.utils.timm.convert import load_timm_backbone
|
34
|
+
from keras_hub.src.utils.transformers.convert import load_transformers_backbone
|
35
|
+
|
36
|
+
|
37
|
+
@keras_hub_export("keras_hub.models.Backbone")
class Backbone(keras.Model):
    """Base class for all `Backbone` models.

    A `Backbone` is the basic architecture for a given NLP model. Unlike a
    `keras_hub.models.Task`, a `Backbone` is not tailored to any specific loss
    function and training setup. A `Backbone` generally outputs the last hidden
    states of an architecture before any output predictions.

    A `Backbone` can be used in one of two ways:

    1. Through a `Task` class, which will wrap and extend a `Backbone` so it
       can be used with high level Keras functions like `fit()`, `predict()` or
       `evaluate()`. `Task` classes are built with a particular training
       objective in mind (e.g. classification or language modeling).
    2. Directly, by extending underlying functional model with additional
       outputs and training setup. This is the most flexible approach, and can
       allow for any outputs, loss, or custom training loop.

    All backbones include a `from_preset()` constructor which can be used to
    load a pre-trained config and weights.

    Example:
    ```python
    # Load a BERT backbone with pre-trained weights.
    backbone = keras_hub.models.Backbone.from_preset(
        "bert_base_en",
    )
    # Load a GPT2 backbone with pre-trained weights at bfloat16 precision.
    backbone = keras_hub.models.Backbone.from_preset(
        "gpt2_base_en",
        dtype="bfloat16",
        trainable=False,
    )
    ```
    """

    def __init__(self, *args, dtype=None, **kwargs):
        """Build the functional model, then optionally override its dtype policy.

        Args:
            dtype: Optional dtype (string, policy, or anything accepted by
                `keras.dtype_policies.get`) applied to the whole model after
                construction. If `None`, the Keras default policy is kept.
        """
        super().__init__(*args, **kwargs)
        # Remember which layers belong to the functional graph so they can be
        # distinguished from attributes assigned later.
        self._functional_layer_ids = set(
            id(layer) for layer in self._flatten_layers()
        )
        # Sentinel checked by `__setattr__` below to detect construction time.
        self._initialized = True
        if dtype is not None:
            try:
                self.dtype_policy = keras.dtype_policies.get(dtype)
            # Before Keras 3.2, there is no `keras.dtype_policies.get`.
            except AttributeError:
                if isinstance(dtype, keras.DTypePolicy):
                    dtype = dtype.name
                self.dtype_policy = keras.DTypePolicy(dtype)

    def __setattr__(self, name, value):
        # Work around setattr issues for Keras 2 and Keras 3 torch backend.
        # Since all our state is covered by functional model we can route
        # around custom setattr calls.
        is_property = isinstance(getattr(type(self), name, None), property)
        is_unitialized = not hasattr(self, "_initialized")
        simple_setattr = keras.config.backend() == "torch"
        if simple_setattr and (is_property or is_unitialized):
            # Bypass Keras/torch attribute tracking entirely.
            return object.__setattr__(self, name, value)
        return super().__setattr__(name, value)

    @property
    def token_embedding(self):
        """A `keras.layers.Embedding` instance for embedding token ids.

        This layer embeds integer token ids to the hidden dim of the model.
        """
        # Returns `None` when no embedding has been assigned yet.
        return getattr(self, "_token_embedding", None)

    @token_embedding.setter
    def token_embedding(self, value):
        self._token_embedding = value

    def quantize(self, mode, **kwargs):
        # Fail fast with a clear error if the installed Keras version lacks
        # quantization support, before delegating to `keras.Model.quantize`.
        assert_quantization_support()
        return super().quantize(mode, **kwargs)

    def get_config(self):
        # Don't chain to super here. `get_config()` for functional models is
        # a nested layer config and cannot be passed to Backbone constructors.
        config = {
            "name": self.name,
            "trainable": self.trainable,
        }

        # Add quantization support by utilizing `DTypePolicyMap`
        try:
            if isinstance(
                self.dtype_policy, keras.dtype_policies.DTypePolicyMap
            ):
                config.update({"dtype": self.dtype_policy})
            else:
                # Collect per-layer policies for any quantized sublayers so a
                # round-trip through `from_config` restores quantization.
                policy_map = keras.dtype_policies.DTypePolicyMap()
                for layer in self._flatten_layers():
                    if layer.quantization_mode is not None:
                        policy_map[layer.path] = layer.dtype_policy
                # Only serialize a dtype map when quantization is in use.
                if len(policy_map) > 0:
                    config.update({"dtype": policy_map})
        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
        except AttributeError:
            pass
        return config

    @classmethod
    def from_config(cls, config):
        # The default `from_config()` for functional models will return a
        # vanilla `keras.Model`. We override it to get a subclass instance back.
        return cls(**config)

    @classproperty
    def presets(cls):
        """List built-in presets for a `Task` subclass."""
        presets = list_presets(cls)
        # Include presets registered on any subclass so the base class sees
        # the union of all available presets.
        for subclass in list_subclasses(cls):
            presets.update(subclass.presets)
        return presets

    @classmethod
    def from_preset(
        cls,
        preset,
        load_weights=True,
        **kwargs,
    ):
        """Instantiate a `keras_hub.models.Backbone` from a model preset.

        A preset is a directory of configs, weights and other file assets used
        to save and load a pre-trained model. The `preset` can be passed as a
        one of:

        1. a built in preset identifier like `'bert_base_en'`
        2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
        3. a Hugging Face handle like `'hf://user/bert_base_en'`
        4. a path to a local preset directory like `'./bert_base_en'`

        This constructor can be called in one of two ways. Either from the base
        class like `keras_hub.models.Backbone.from_preset()`, or from
        a model class like `keras_hub.models.GemmaBackbone.from_preset()`.
        If calling from the base class, the subclass of the returning object
        will be inferred from the config in the preset directory.

        For any `Backbone` subclass, you can run `cls.presets.keys()` to list
        all built-in presets available on the class.

        Args:
            preset: string. A built in preset identifier, a Kaggle Models
                handle, a Hugging Face handle, or a path to a local directory.
            load_weights: bool. If `True`, the weights will be loaded into the
                model architecture. If `False`, the weights will be randomly
                initialized.

        Examples:
        ```python
        # Load a Gemma backbone with pre-trained weights.
        model = keras_hub.models.Backbone.from_preset(
            "gemma_2b_en",
        )

        # Load a Bert backbone with a pre-trained config and random weights.
        model = keras_hub.models.Backbone.from_preset(
            "bert_base_en",
            load_weights=False,
        )
        ```
        """
        format = check_format(preset)

        # Non-native preset formats are delegated to dedicated converters.
        if format == "transformers":
            return load_transformers_backbone(cls, preset, load_weights)
        elif format == "timm":
            return load_timm_backbone(cls, preset, load_weights, **kwargs)

        preset_cls = check_config_class(preset)
        if not issubclass(preset_cls, cls):
            # NOTE(review): error message reads "not a / a subclass" — the
            # doubled article looks like a typo in the upstream string.
            raise ValueError(
                f"Preset has type `{preset_cls.__name__}` which is not a "
                f"a subclass of calling class `{cls.__name__}`. Call "
                f"`from_preset` directly on `{preset_cls.__name__}` instead."
            )

        backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
        if load_weights:
            # Free randomly-initialized buffers before loading real weights
            # (relevant on the jax backend; no-op elsewhere, presumably).
            jax_memory_cleanup(backbone)
            backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))

        return backbone

    def save_to_preset(self, preset_dir):
        """Save backbone to a preset directory.

        Args:
            preset_dir: The path to the local model preset directory.
        """
        save_serialized_object(self, preset_dir, config_file=CONFIG_FILE)
        self.save_weights(os.path.join(preset_dir, MODEL_WEIGHTS_FILE))
        save_metadata(self, preset_dir)

    def enable_lora(self, rank):
        """Enable Lora on the backbone.

        Calling this method will freeze all weights on the backbone,
        while enabling Lora on the query & value `EinsumDense` layers
        of the attention layers.
        """
        target_names = ["query_dense", "value_dense", "query", "value"]
        self.trainable = True
        self._lora_enabled_layers = []
        self._lora_rank = rank
        # Freeze everything first; LoRA-targeted layers are re-enabled below.
        for layer in self._flatten_layers(include_self=False):
            layer.trainable = False
        all_layers = self._flatten_layers(include_self=False)
        all_layers = [lyr for lyr in all_layers if lyr.weights]
        for i, layer in enumerate(all_layers):
            for name in target_names:
                if layer.name == name:
                    if hasattr(layer, "enable_lora"):
                        layer.trainable = True
                        layer.enable_lora(rank)
                        # Record by index into the weighted-layer list; see
                        # `save_lora_weights` for why names are not used.
                        self._lora_enabled_layers.append(i)

    def save_lora_weights(self, filepath):
        """Save only the LoRA factor weights to an `.lora.h5` file.

        Raises:
            ValueError: If `enable_lora()` was never called, or if `filepath`
                does not end in `.lora.h5`.
        """
        if not getattr(self, "_lora_enabled_layers", []):
            raise ValueError(
                "There are no lora-enabled layers in this model. "
                "Make sure to call `.enable_lora(rank)` first."
            )
        if not str(filepath).endswith(".lora.h5"):
            raise ValueError(
                "The filename must end in `.lora.h5`. "
                f"Received: filepath={filepath}"
            )

        # NOTE: uses the private `keras.src` saving internals — no public
        # H5 store API exists for this as of this Keras version.
        store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="w")
        lora_store = store.make("lora")
        lora_store["rank"] = self._lora_rank
        # We cannot identify layers by name since names are non-unique,
        # so we identify them by index in the topologically sorted list
        # of layers that have weights.
        all_layers = self._flatten_layers(include_self=False)
        all_layers = [lyr for lyr in all_layers if lyr.weights]
        for layer_index in self._lora_enabled_layers:
            # We only lora the einsumdense layers,
            # so the factored weights are always named `kernel`
            layer = all_layers[layer_index]
            inner_store = store.make(f"lora/{layer_index}")
            inner_store["lora_kernel_a"] = layer.lora_kernel_a
            inner_store["lora_kernel_b"] = layer.lora_kernel_b
        store.close()

    def load_lora_weights(self, filepath):
        """Load LoRA factor weights saved by `save_lora_weights`.

        If LoRA is not yet enabled, enables it with the rank recorded in the
        file; otherwise the file's rank must match the current one.

        Raises:
            ValueError: If LoRA is already enabled with a different rank.
        """
        store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="r")
        lora_store = store.get("lora")
        rank = int(lora_store["rank"][()])

        if not getattr(self, "_lora_enabled_layers", []):
            self.enable_lora(rank)
        else:
            if self._lora_rank != rank:
                raise ValueError(
                    f"The Lora rank expected by file '{filepath}' "
                    f"is rank={rank}, but the model was called with "
                    f"`.enable_lora(rank={self._lora_rank})`. "
                    "Both ranks must match."
                )
        # Same index-based layer identification as in `save_lora_weights`;
        # relies on the layer order being identical to save time.
        all_layers = self._flatten_layers(include_self=False)
        all_layers = [lyr for lyr in all_layers if lyr.weights]
        for layer_index in self._lora_enabled_layers:
            layer = all_layers[layer_index]
            lora_kernel_a = store.get(f"lora/{layer_index}")["lora_kernel_a"]
            lora_kernel_b = store.get(f"lora/{layer_index}")["lora_kernel_b"]
            layer.lora_kernel_a.assign(lora_kernel_a)
            layer.lora_kernel_b.assign(lora_kernel_b)
        store.close()
@@ -0,0 +1,20 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from keras_hub.src.models.bart.bart_backbone import BartBackbone
|
16
|
+
from keras_hub.src.models.bart.bart_presets import backbone_presets
|
17
|
+
from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
|
18
|
+
from keras_hub.src.utils.preset_utils import register_presets
|
19
|
+
|
20
|
+
# Register the BART preset checkpoints so they can be loaded via
# `from_preset` on both the backbone and tokenizer classes.
register_presets(backbone_presets, (BartBackbone, BartTokenizer))
|
@@ -0,0 +1,261 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
|
17
|
+
from keras_hub.src.api_export import keras_hub_export
|
18
|
+
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
|
19
|
+
from keras_hub.src.layers.modeling.reversible_embedding import (
|
20
|
+
ReversibleEmbedding,
|
21
|
+
)
|
22
|
+
from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
|
23
|
+
from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
|
24
|
+
from keras_hub.src.models.backbone import Backbone
|
25
|
+
|
26
|
+
|
27
|
+
def bart_kernel_initializer(stddev=0.02):
    """Return the truncated-normal initializer used for all BART kernels.

    Args:
        stddev: float. Standard deviation of the truncated normal
            distribution. Defaults to `0.02`.
    """
    initializer = keras.initializers.TruncatedNormal(stddev=stddev)
    return initializer
|
29
|
+
|
30
|
+
|
31
|
+
@keras_hub_export("keras_hub.models.BartBackbone")
class BartBackbone(Backbone):
    """BART encoder-decoder network.

    This class implements a Transformer-based encoder-decoder model as
    described in
    ["BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension"](https://arxiv.org/abs/1910.13461).

    The default constructor gives a fully customizable, randomly initialized BART
    model with any number of layers, heads, and embedding dimensions. To load
    preset architectures and weights, use the `from_preset` constructor.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by a
    third party and subject to a separate license, available
    [here](https://github.com/facebookresearch/fairseq/).

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer encoder layers and
            transformer decoder layers.
        num_heads: int. The number of attention heads for each transformer.
            The hidden size must be divisible by the number of attention heads.
        hidden_dim: int. The size of the transformer encoding and pooler layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        dropout: float. Dropout probability for the Transformer encoder.
        max_sequence_length: int. The maximum sequence length that this encoder
            can consume. If None, `max_sequence_length` uses the value from
            sequence length. This determines the variable shape for positional
            embeddings.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.

    Examples:
    ```python
    input_data = {
        "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "encoder_padding_mask": np.array(
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]
        ),
        "decoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "decoder_padding_mask": np.array(
            [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
        ),
    }

    # Pretrained BART encoder.
    model = keras_hub.models.BartBackbone.from_preset("bart_base_en")
    model(input_data)

    # Randomly initialized BART encoder-decoder model with a custom config
    model = keras_hub.models.BartBackbone(
        vocabulary_size=50265,
        num_layers=6,
        num_heads=12,
        hidden_dim=768,
        intermediate_dim=3072,
        max_sequence_length=12,
    )
    output = model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        dropout=0.1,
        max_sequence_length=1024,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        # A single token embedding is constructed here and reused below
        # for both the encoder and the decoder token inputs.
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=bart_kernel_initializer(),
            dtype=dtype,
            name="token_embedding",
        )
        # Encoder and decoder each get their own learned position
        # embedding, add/norm/dropout embedding head, and layer stack.
        self.encoder_position_embedding = PositionEmbedding(
            initializer=bart_kernel_initializer(),
            sequence_length=max_sequence_length,
            dtype=dtype,
            name="encoder_position_embedding",
        )
        self.encoder_embeddings_add = keras.layers.Add(
            dtype=dtype,
            name="encoder_embeddings_add",
        )
        self.encoder_embeddings_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            dtype=dtype,
            name="encoder_embeddings_layer_norm",
        )
        self.encoder_embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="encoder_embeddings_dropout",
        )
        self.encoder_transformer_layers = []
        for i in range(num_layers):
            layer = TransformerEncoder(
                num_heads=num_heads,
                intermediate_dim=intermediate_dim,
                activation=keras.activations.gelu,
                dropout=dropout,
                layer_norm_epsilon=1e-5,
                kernel_initializer=bart_kernel_initializer(),
                dtype=dtype,
                name=f"transformer_encoder_layer_{i}",
            )
            self.encoder_transformer_layers.append(layer)
        self.decoder_position_embedding = PositionEmbedding(
            initializer=bart_kernel_initializer(),
            sequence_length=max_sequence_length,
            dtype=dtype,
            name="decoder_position_embedding",
        )
        self.decoder_embeddings_add = keras.layers.Add(
            dtype=dtype,
            name="decoder_embeddings_add",
        )
        self.decoder_embeddings_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            dtype=dtype,
            name="decoder_embeddings_layer_norm",
        )
        self.decoder_embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="decoder_embeddings_dropout",
        )
        self.decoder_transformer_layers = []
        for i in range(num_layers):
            layer = TransformerDecoder(
                intermediate_dim=intermediate_dim,
                num_heads=num_heads,
                dropout=dropout,
                activation=keras.activations.gelu,
                layer_norm_epsilon=1e-5,
                kernel_initializer=bart_kernel_initializer(),
                dtype=dtype,
                name=f"transformer_decoder_layer_{i}",
            )
            self.decoder_transformer_layers.append(layer)

        # === Functional Model ===
        encoder_token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="encoder_token_ids"
        )
        encoder_padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="encoder_padding_mask"
        )
        decoder_token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_token_ids"
        )
        decoder_padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_padding_mask"
        )
        # Encoder.
        tokens = self.token_embedding(encoder_token_id_input)
        positions = self.encoder_position_embedding(tokens)
        x = self.encoder_embeddings_add((tokens, positions))
        x = self.encoder_embeddings_layer_norm(x)
        x = self.encoder_embeddings_dropout(x)
        for transformer_layer in self.encoder_transformer_layers:
            x = transformer_layer(x, padding_mask=encoder_padding_mask_input)
        encoder_output = x
        # Decoder.
        # Reuses `self.token_embedding`; each decoder layer also
        # cross-attends to `encoder_output`.
        tokens = self.token_embedding(decoder_token_id_input)
        positions = self.decoder_position_embedding(tokens)
        x = self.decoder_embeddings_add((tokens, positions))
        x = self.decoder_embeddings_layer_norm(x)
        x = self.decoder_embeddings_dropout(x)
        for transformer_layer in self.decoder_transformer_layers:
            x = transformer_layer(
                decoder_sequence=x,
                encoder_sequence=encoder_output,
                decoder_padding_mask=decoder_padding_mask_input,
                encoder_padding_mask=encoder_padding_mask_input,
            )
        decoder_output = x
        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs={
                "encoder_token_ids": encoder_token_id_input,
                "encoder_padding_mask": encoder_padding_mask_input,
                "decoder_token_ids": decoder_token_id_input,
                "decoder_padding_mask": decoder_padding_mask_input,
            },
            outputs={
                "encoder_sequence_output": encoder_output,
                "decoder_sequence_output": decoder_output,
            },
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        # Stored so `get_config` can serialize the constructor arguments.
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length

    def get_config(self):
        """Return the constructor arguments needed to re-create this backbone."""
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "dropout": self.dropout,
                "max_sequence_length": self.max_sequence_length,
            }
        )

        return config
|