keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,186 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
|
17
|
+
from keras_hub.src.utils.preset_utils import get_file
|
18
|
+
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
|
19
|
+
from keras_hub.src.utils.preset_utils import load_config
|
20
|
+
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
21
|
+
|
22
|
+
|
23
|
+
def convert_backbone_config(transformers_config):
    """Map a Hugging Face GPT-2 config dict to KerasHub backbone kwargs."""
    cfg = transformers_config
    embed_dim = cfg["n_embd"]
    return {
        "vocabulary_size": cfg["vocab_size"],
        "num_layers": cfg["n_layer"],
        "num_heads": cfg["n_head"],
        "hidden_dim": embed_dim,
        # GPT-2's MLP width is always 4x the embedding width.
        "intermediate_dim": embed_dim * 4,
        "dropout": cfg["resid_pdrop"],
        "max_sequence_length": cfg["n_positions"],
    }
|
33
|
+
|
34
|
+
|
35
|
+
def convert_weights(backbone, loader, transformers_config):
    """Port Hugging Face GPT-2 safetensor weights into a KerasHub backbone.

    Args:
        backbone: A KerasHub GPT-2 backbone whose variables are assigned.
        loader: A `SafetensorLoader` opened on the Hugging Face preset.
        transformers_config: The parsed Hugging Face `config.json` dict.

    Returns:
        The same `backbone`, with converted weights assigned in place.
    """
    # Embeddings
    loader.port_weight(
        keras_variable=backbone.token_embedding.embeddings,
        hf_weight_key="wte.weight",
    )
    loader.port_weight(
        keras_variable=backbone.position_embedding.position_embeddings,
        hf_weight_key="wpe.weight",
    )

    # HF GPT-2 fuses query/key/value into a single `c_attn` tensor; the
    # hooks below slice out each projection. `n_embd` is loop-invariant,
    # so read it once here instead of once per layer.
    n_embd = transformers_config["n_embd"]

    # Attention blocks
    for index in range(backbone.num_layers):
        decoder_layer = backbone.transformer_layers[index]

        # Layer norms (HF `weight`/`bias` map to Keras `gamma`/`beta`).
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer_norm.gamma,
            hf_weight_key=f"h.{index}.ln_1.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer_norm.beta,
            hf_weight_key=f"h.{index}.ln_1.bias",
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_layer_norm.gamma,
            hf_weight_key=f"h.{index}.ln_2.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_layer_norm.beta,
            hf_weight_key=f"h.{index}.ln_2.bias",
        )

        # Query: first n_embd columns/entries of the fused c_attn tensor.
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.query_dense.kernel,
            hf_weight_key=f"h.{index}.attn.c_attn.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[:, :n_embd], keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.query_dense.bias,
            hf_weight_key=f"h.{index}.attn.c_attn.bias",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[:n_embd], keras_shape
            ),
        )

        # Key: middle n_embd columns/entries.
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.key_dense.kernel,
            hf_weight_key=f"h.{index}.attn.c_attn.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[:, n_embd : 2 * n_embd], keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.key_dense.bias,
            hf_weight_key=f"h.{index}.attn.c_attn.bias",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[n_embd : 2 * n_embd], keras_shape
            ),
        )

        # Value: last n_embd columns/entries.
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.value_dense.kernel,
            hf_weight_key=f"h.{index}.attn.c_attn.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[:, 2 * n_embd :], keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.value_dense.bias,
            hf_weight_key=f"h.{index}.attn.c_attn.bias",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[2 * n_embd :], keras_shape
            ),
        )

        # Attention output projection.
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.output_dense.kernel,
            hf_weight_key=f"h.{index}.attn.c_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor, keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._self_attention_layer.output_dense.bias,
            hf_weight_key=f"h.{index}.attn.c_proj.bias",
        )

        # MLP layers
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
            hf_weight_key=f"h.{index}.mlp.c_fc.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor, keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_intermediate_dense.bias,
            hf_weight_key=f"h.{index}.mlp.c_fc.bias",
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_output_dense.kernel,
            hf_weight_key=f"h.{index}.mlp.c_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor, keras_shape
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer._feedforward_output_dense.bias,
            hf_weight_key=f"h.{index}.mlp.c_proj.bias",
        )

    # Final layer normalization.
    loader.port_weight(
        keras_variable=backbone.layer_norm.gamma,
        hf_weight_key="ln_f.weight",
    )
    loader.port_weight(
        keras_variable=backbone.layer_norm.beta,
        hf_weight_key="ln_f.bias",
    )

    return backbone
|
167
|
+
|
168
|
+
|
169
|
+
def load_gpt2_backbone(cls, preset, load_weights):
    """Build a KerasHub GPT-2 backbone from a Hugging Face preset.

    Args:
        cls: The backbone class to instantiate.
        preset: The Hugging Face preset identifier or path.
        load_weights: Whether to also port the pretrained weights.

    Returns:
        An instance of `cls`, with ported weights when requested.
    """
    hf_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(hf_config))
    if not load_weights:
        return backbone
    # Presumably frees freshly-initialized buffers before porting weights
    # (see `jax_memory_cleanup` in preset_utils) — confirm.
    jax_memory_cleanup(backbone)
    with SafetensorLoader(preset) as loader:
        convert_weights(backbone, loader, hf_config)
    return backbone
|
178
|
+
|
179
|
+
|
180
|
+
def load_gpt2_tokenizer(cls, preset):
    """Build a GPT-2 BPE tokenizer from a preset's vocab and merges files."""
    return cls(
        vocabulary=get_file(preset, "vocab.json"),
        merges=get_file(preset, "merges.txt"),
    )
|
@@ -0,0 +1,136 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
|
17
|
+
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
|
18
|
+
from keras_hub.src.utils.preset_utils import load_config
|
19
|
+
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
20
|
+
|
21
|
+
|
22
|
+
def convert_backbone_config(transformers_config):
    """Map a Hugging Face Llama 3 config dict to KerasHub backbone kwargs."""
    # Keras kwarg name -> Hugging Face config key; insertion order matches
    # the original literal so the returned dict is identical.
    key_map = {
        "vocabulary_size": "vocab_size",
        "num_layers": "num_hidden_layers",
        "num_query_heads": "num_attention_heads",
        "hidden_dim": "hidden_size",
        "intermediate_dim": "intermediate_size",
        "num_key_value_heads": "num_key_value_heads",
    }
    return {
        keras_key: transformers_config[hf_key]
        for keras_key, hf_key in key_map.items()
    }
|
31
|
+
|
32
|
+
|
33
|
+
def convert_weights(backbone, loader, transformers_config):
    """Port Hugging Face Llama 3 safetensor weights into a KerasHub backbone.

    Args:
        backbone: A KerasHub Llama 3 backbone whose variables are assigned.
        loader: A `SafetensorLoader` opened on the Hugging Face preset.
        transformers_config: The parsed Hugging Face `config.json` dict
            (unused here, kept for interface parity with other converters).

    Returns:
        The same `backbone`, with converted weights assigned in place.
    """

    def transpose_2d(hf_tensor, _):
        # Swap the two axes of a dense kernel (equivalent to "b a -> a b").
        return np.transpose(hf_tensor, axes=(1, 0))

    def transpose_and_reshape(x, shape):
        # Transpose, then reshape into the per-head Keras kernel layout.
        return np.reshape(np.transpose(x), shape)

    # Token embeddings: forward table plus the separate output projection.
    token_embedding = backbone.get_layer("token_embedding")
    loader.port_weight(
        keras_variable=token_embedding.embeddings,
        hf_weight_key="model.embed_tokens.weight",
    )
    loader.port_weight(
        keras_variable=token_embedding.reverse_embeddings,
        hf_weight_key="lm_head.weight",
        hook_fn=transpose_2d,
    )

    # Per-layer attention and MLP weights.
    for layer_index in range(backbone.num_layers):
        block = backbone.get_layer(f"transformer_layer_{layer_index}")
        attention = block._self_attention_layer

        # Pre-attention and pre-MLP RMS norm scales.
        loader.port_weight(
            keras_variable=block._self_attention_layernorm.scale,
            hf_weight_key=f"model.layers.{layer_index}.input_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=block._feedforward_layernorm.scale,
            hf_weight_key=f"model.layers.{layer_index}.post_attention_layernorm.weight",
        )

        # Attention projections.
        loader.port_weight(
            keras_variable=attention._query_dense.kernel,
            hf_weight_key=f"model.layers.{layer_index}.self_attn.q_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        loader.port_weight(
            keras_variable=attention._key_dense.kernel,
            hf_weight_key=f"model.layers.{layer_index}.self_attn.k_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        loader.port_weight(
            keras_variable=attention._value_dense.kernel,
            hf_weight_key=f"model.layers.{layer_index}.self_attn.v_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        loader.port_weight(
            keras_variable=attention._output_dense.kernel,
            hf_weight_key=f"model.layers.{layer_index}.self_attn.o_proj.weight",
            hook_fn=transpose_and_reshape,
        )

        # Gated MLP projections.
        loader.port_weight(
            keras_variable=block._feedforward_gate_dense.kernel,
            hf_weight_key=f"model.layers.{layer_index}.mlp.gate_proj.weight",
            hook_fn=transpose_2d,
        )
        loader.port_weight(
            keras_variable=block._feedforward_intermediate_dense.kernel,
            hf_weight_key=f"model.layers.{layer_index}.mlp.up_proj.weight",
            hook_fn=transpose_2d,
        )
        loader.port_weight(
            keras_variable=block._feedforward_output_dense.kernel,
            hf_weight_key=f"model.layers.{layer_index}.mlp.down_proj.weight",
            hook_fn=transpose_2d,
        )

    # Final normalization layer before the output projection.
    loader.port_weight(
        keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
        hf_weight_key="model.norm.weight",
    )

    return backbone
|
112
|
+
|
113
|
+
|
114
|
+
def load_llama3_backbone(cls, preset, load_weights):
    """Build a KerasHub Llama 3 backbone from a Hugging Face preset.

    Args:
        cls: The backbone class to instantiate.
        preset: The Hugging Face preset identifier or path.
        load_weights: Whether to also port the pretrained weights.

    Returns:
        An instance of `cls`, with ported weights when requested.
    """
    hf_config = load_config(preset, HF_CONFIG_FILE)
    backbone = cls(**convert_backbone_config(hf_config))
    if not load_weights:
        return backbone
    # Presumably frees freshly-initialized buffers before porting weights
    # (see `jax_memory_cleanup` in preset_utils) — confirm.
    jax_memory_cleanup(backbone)
    with SafetensorLoader(preset) as loader:
        convert_weights(backbone, loader, hf_config)
    return backbone
|
123
|
+
|
124
|
+
|
125
|
+
def load_llama3_tokenizer(cls, preset):
    """Build a Llama 3 BPE tokenizer from a preset's `tokenizer.json`."""
    config = load_config(preset, "tokenizer.json")
    model = config["model"]
    vocab = model["vocab"]
    merges = model["merges"]

    # NOTE(review): assumes added_tokens[0] is begin-of-text and
    # added_tokens[1] is end-of-text — verify this ordering holds for the
    # preset's tokenizer.json.
    begin_of_text, end_of_text = config["added_tokens"][:2]
    for special_token in (begin_of_text, end_of_text):
        vocab[special_token["content"]] = special_token["id"]

    return cls(vocabulary=vocab, merges=merges)
|
@@ -0,0 +1,303 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
|
17
|
+
from keras_hub.src.utils.preset_utils import get_file
|
18
|
+
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
|
19
|
+
from keras_hub.src.utils.preset_utils import load_config
|
20
|
+
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
21
|
+
|
22
|
+
|
23
|
+
def convert_backbone_config(transformers_config):
    """Map a Hugging Face PaliGemma config to KerasHub backbone kwargs.

    Args:
        transformers_config: Dict parsed from the Hugging Face
            `config.json`; must contain `text_config` and
            `vision_config` sub-dicts plus a top-level
            `image_token_index`.

    Returns:
        Dict of keyword arguments for the KerasHub backbone constructor.
    """
    text_config = transformers_config["text_config"]
    vision_config = transformers_config["vision_config"]
    return {
        "vocabulary_size": transformers_config["image_token_index"],
        # Some HF configs omit `image_size`; fall back to 224.
        "image_size": vision_config.get("image_size", 224),
        "num_layers": text_config["num_hidden_layers"],
        "num_query_heads": text_config["num_attention_heads"],
        "num_key_value_heads": text_config["num_key_value_heads"],
        "hidden_dim": text_config["hidden_size"],
        # KerasHub's `intermediate_dim` covers the fused gate+up
        # projection, hence twice the HF `intermediate_size`.
        "intermediate_dim": text_config["intermediate_size"] * 2,
        # NOTE(review): kept as in the original — head_dim is sourced
        # from `num_image_tokens`; confirm this matches the text
        # tower's per-head dimension for all presets.
        "head_dim": text_config["num_image_tokens"],
        "vit_patch_size": vision_config["patch_size"],
        "vit_num_heads": vision_config["num_attention_heads"],
        "vit_hidden_dim": vision_config["hidden_size"],
        "vit_num_layers": vision_config["num_hidden_layers"],
        "vit_intermediate_dim": vision_config["intermediate_size"],
    }
|
45
|
+
|
46
|
+
|
47
|
+
def convert_weights(backbone, loader, transformers_config):
    """Port Hugging Face PaliGemma safetensors weights into `backbone`.

    Copies the vision tower, the multi-modal projection, and the
    language tower weight-by-weight via `loader.port_weight`,
    transposing/reshaping HF tensors into the Keras variable layouts
    where needed.

    Args:
        backbone: KerasHub PaliGemma backbone to fill in.
        loader: An open `SafetensorLoader` for the preset.
        transformers_config: Parsed HF config dict. Currently unused
            here; layer counts are read from `backbone` itself.

    Returns:
        The same `backbone` instance, with weights loaded.
    """
    ############################################################################
    # Image Tower
    ############################################################################
    image_encoder = backbone.vit_encoder.get_layer("image_encoder")

    # Embedding
    loader.port_weight(
        keras_variable=image_encoder.vision_embeddings.patch_embedding.bias,
        hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.bias",
    )

    # HF conv kernels are (out, in, h, w); transpose to the Keras
    # conv layout (h, w, in, out).
    loader.port_weight(
        keras_variable=image_encoder.vision_embeddings.patch_embedding.kernel,
        hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.weight",
        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(2, 3, 1, 0)),
    )

    # Positional Embedding
    loader.port_weight(
        keras_variable=image_encoder.vision_embeddings.position_embedding.embeddings,
        hf_weight_key="vision_tower.vision_model.embeddings.position_embedding.weight",
    )

    # Normalization (final post-encoder layer norm of the vision tower)
    loader.port_weight(
        keras_variable=image_encoder.encoder_layer_norm.gamma,
        hf_weight_key="vision_tower.vision_model.post_layernorm.weight",
    )

    loader.port_weight(
        keras_variable=image_encoder.encoder_layer_norm.beta,
        hf_weight_key="vision_tower.vision_model.post_layernorm.bias",
    )

    # ResBlocks: one transformer encoder layer per iteration. HF Linear
    # weights are (out, in); the axes=(1, 0) hooks transpose them to
    # the Keras (in, out) kernel layout.
    for index in range(image_encoder.num_layers):
        block = image_encoder.resblocks[index]

        loader.port_weight(
            keras_variable=block.layer_norm_1.beta,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm1.bias",
        )

        loader.port_weight(
            keras_variable=block.layer_norm_1.gamma,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm1.weight",
        )

        loader.port_weight(
            keras_variable=block.layer_norm_2.beta,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm2.bias",
        )

        loader.port_weight(
            keras_variable=block.layer_norm_2.gamma,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.layer_norm2.weight",
        )

        loader.port_weight(
            keras_variable=block.mlp_dense_1.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc1.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

        loader.port_weight(
            keras_variable=block.mlp_dense_1.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc1.bias",
        )

        loader.port_weight(
            keras_variable=block.mlp_dense_2.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc2.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

        loader.port_weight(
            keras_variable=block.mlp_dense_2.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.mlp.fc2.bias",
        )

        loader.port_weight(
            keras_variable=block.attn.key_proj.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.k_proj.bias",
        )

        loader.port_weight(
            keras_variable=block.attn.key_proj.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.k_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

        loader.port_weight(
            keras_variable=block.attn.out_proj.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.out_proj.bias",
        )

        loader.port_weight(
            keras_variable=block.attn.out_proj.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.out_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

        loader.port_weight(
            keras_variable=block.attn.query_proj.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.q_proj.bias",
        )

        loader.port_weight(
            keras_variable=block.attn.query_proj.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.q_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

        loader.port_weight(
            keras_variable=block.attn.value_proj.bias,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.v_proj.bias",
        )

        loader.port_weight(
            keras_variable=block.attn.value_proj.kernel,
            hf_weight_key=f"vision_tower.vision_model.encoder.layers.{index}.self_attn.v_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

    # Multi Modal Projection (maps vision features into the language
    # model's embedding space)
    loader.port_weight(
        keras_variable=backbone.vit_encoder.get_layer(
            "image_classifier"
        ).kernel,
        hf_weight_key="multi_modal_projector.linear.weight",
        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
    )

    loader.port_weight(
        keras_variable=backbone.vit_encoder.get_layer("image_classifier").bias,
        hf_weight_key="multi_modal_projector.linear.bias",
    )

    ############################################################################
    # Language Tower
    ############################################################################
    for index in range(backbone.num_layers):
        decoder_layer = backbone.transformer_layers[index]

        # Norm layers
        loader.port_weight(
            keras_variable=decoder_layer.pre_attention_norm.scale,
            hf_weight_key=f"language_model.model.layers.{index}.input_layernorm.weight",
        )
        loader.port_weight(
            keras_variable=decoder_layer.pre_ffw_norm.scale,
            hf_weight_key=f"language_model.model.layers.{index}.post_attention_layernorm.weight",
        )

        # Attention layers. The hooks un-flatten HF's 2-D (out, in)
        # projection matrices into the 3-D per-head Keras kernel shape
        # (`keras_shape`), swapping the last two axes after the
        # reshape; o_proj uses a different permutation because its
        # Keras kernel puts the output dim last.
        loader.port_weight(
            keras_variable=decoder_layer.attention.query_dense.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.self_attn.q_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer.attention.key_dense.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.self_attn.k_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer.attention.value_dense.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.self_attn.v_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[0], keras_shape[2], keras_shape[1]),
                ),
                axes=(0, 2, 1),
            ),
        )
        loader.port_weight(
            keras_variable=decoder_layer.attention.output_dense.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.self_attn.o_proj.weight",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                np.reshape(
                    hf_tensor,
                    (keras_shape[2], keras_shape[0], keras_shape[1]),
                ),
                axes=(1, 2, 0),
            ),
        )

        # MLP layers (HF (out, in) -> Keras (in, out))
        loader.port_weight(
            keras_variable=decoder_layer.gating_ffw.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.mlp.gate_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=decoder_layer.gating_ffw_2.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.mlp.up_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=decoder_layer.ffw_linear.kernel,
            hf_weight_key=f"language_model.model.layers.{index}.mlp.down_proj.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

    # Normalization (final decoder norm)
    loader.port_weight(
        keras_variable=backbone.layer_norm.scale,
        hf_weight_key="language_model.model.norm.weight",
    )

    # Embedding. The slice keeps only the first keras_shape[0] rows —
    # presumably the HF embedding table carries extra rows beyond the
    # Keras vocabulary size; TODO confirm against the checkpoint.
    loader.port_weight(
        keras_variable=backbone.token_embedding.embeddings,
        hf_weight_key="language_model.model.embed_tokens.weight",
        hook_fn=lambda hf_tensor, keras_shape: hf_tensor[: keras_shape[0]],
    )

    return backbone
|
279
|
+
|
280
|
+
|
281
|
+
def load_pali_gemma_backbone(cls, preset, load_weights):
    """Build a PaliGemma backbone from a Hugging Face preset.

    Args:
        cls: The KerasHub backbone class to instantiate.
        preset: Preset identifier pointing at the HF checkpoint.
        load_weights: If True, port the safetensors weights into the
            freshly constructed backbone.

    Returns:
        An initialized backbone instance (weights loaded only when
        `load_weights` is True).
    """
    transformers_config = load_config(preset, HF_CONFIG_FILE)
    keras_config = convert_backbone_config(transformers_config)
    backbone = cls(**keras_config)
    if load_weights:
        # NOTE(review): presumably releases the randomly-initialized
        # weight memory on the JAX backend before the real weights are
        # streamed in — confirm against preset_utils.
        jax_memory_cleanup(backbone)
        with SafetensorLoader(preset) as loader:
            convert_weights(backbone, loader, transformers_config)
    return backbone
|
290
|
+
|
291
|
+
|
292
|
+
def load_pali_gemma_tokenizer(cls, preset):
    """Load the Gemma tokenizer used by PaliGemma.

    Args:
        cls (class): Tokenizer class.
        preset (str): Preset configuration name.

    Returns:
        tokenizer: Initialized tokenizer.
    """
    proto_path = get_file(preset, "tokenizer.model")
    return cls(proto_path)
|