keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,373 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import numpy as np
|
15
|
+
|
16
|
+
from keras_hub.src.models.bart.bart_backbone import BartBackbone
|
17
|
+
from keras_hub.src.utils.preset_utils import get_file
|
18
|
+
|
19
|
+
# Registry hook: tells the transformers preset loader which KerasHub
# backbone class this converter module targets.
backbone_cls = BartBackbone
|
20
|
+
|
21
|
+
|
22
|
+
def convert_backbone_config(transformers_config):
    """Translate a Hugging Face BART config dict into KerasHub backbone kwargs.

    Args:
        transformers_config: dict parsed from a Hugging Face `config.json`.

    Returns:
        A dict of keyword arguments for constructing a `BartBackbone`.
    """
    # Keras kwarg name -> Hugging Face config key.
    key_map = {
        "vocabulary_size": "vocab_size",
        "num_layers": "num_hidden_layers",
        "num_heads": "encoder_attention_heads",
        "hidden_dim": "d_model",
        "intermediate_dim": "encoder_ffn_dim",
        "dropout": "dropout",
        "max_sequence_length": "max_position_embeddings",
    }
    return {
        keras_key: transformers_config[hf_key]
        for keras_key, hf_key in key_map.items()
    }
|
32
|
+
|
33
|
+
|
34
|
+
def _transpose_and_reshape(hf_tensor, keras_shape):
    """Transpose an HF (out, in) dense weight and reshape to the Keras shape."""
    return np.reshape(np.transpose(hf_tensor), keras_shape)


def _port_layer_norm(loader, layer_norm, hf_prefix):
    """Port one layer norm: gamma/beta <- `{hf_prefix}.weight` / `.bias`."""
    loader.port_weight(
        keras_variable=layer_norm.gamma,
        hf_weight_key=f"{hf_prefix}.weight",
    )
    loader.port_weight(
        keras_variable=layer_norm.beta,
        hf_weight_key=f"{hf_prefix}.bias",
    )


def _port_attention(loader, attention, hf_prefix):
    """Port the q/k/v/out projections of one multi-head attention layer."""
    projections = (
        ("q_proj", attention.query_dense),
        ("k_proj", attention.key_dense),
        ("v_proj", attention.value_dense),
        ("out_proj", attention.output_dense),
    )
    for hf_name, dense in projections:
        loader.port_weight(
            keras_variable=dense.kernel,
            hf_weight_key=f"{hf_prefix}.{hf_name}.weight",
            hook_fn=_transpose_and_reshape,
        )
        # Biases also go through the reshape hook so they match the
        # (num_heads, head_dim) layout of the Keras attention variables.
        loader.port_weight(
            keras_variable=dense.bias,
            hf_weight_key=f"{hf_prefix}.{hf_name}.bias",
            hook_fn=_transpose_and_reshape,
        )


def _port_ffn(loader, layer, hf_prefix):
    """Port the feed-forward block (fc1 -> intermediate, fc2 -> output)."""
    ffn_layers = (
        ("fc1", layer._feedforward_intermediate_dense),
        ("fc2", layer._feedforward_output_dense),
    )
    for hf_name, dense in ffn_layers:
        loader.port_weight(
            keras_variable=dense.kernel,
            hf_weight_key=f"{hf_prefix}.{hf_name}.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=dense.bias,
            hf_weight_key=f"{hf_prefix}.{hf_name}.bias",
        )


def convert_weights(backbone, loader, transformers_config):
    """Port Hugging Face BART weights into a KerasHub `BartBackbone`.

    Args:
        backbone: the target `BartBackbone` instance.
        loader: a safetensor loader exposing `port_weight(...)`.
        transformers_config: the Hugging Face config dict (unused here, kept
            for the converter interface shared across models).
    """
    # Token embedding shared between encoder and decoder.
    loader.port_weight(
        keras_variable=backbone.token_embedding.embeddings,
        hf_weight_key="shared.weight",
    )
    # HF BART position embeddings reserve the first two rows (a padding
    # offset); drop them before reshaping to the Keras variable.
    position_embeddings = (
        (backbone.encoder_position_embedding, "encoder.embed_positions.weight"),
        (backbone.decoder_position_embedding, "decoder.embed_positions.weight"),
    )
    for keras_embedding, hf_key in position_embeddings:
        loader.port_weight(
            keras_variable=keras_embedding.position_embeddings,
            hf_weight_key=hf_key,
            hook_fn=lambda hf_tensor, keras_shape: np.reshape(
                hf_tensor[2:, :], keras_shape
            ),
        )

    # Encoder blocks.
    for index in range(backbone.num_layers):
        encoder_layer = backbone.encoder_transformer_layers[index]
        prefix = f"encoder.layers.{index}"
        _port_layer_norm(
            loader,
            encoder_layer._self_attention_layer_norm,
            f"{prefix}.self_attn_layer_norm",
        )
        _port_layer_norm(
            loader,
            encoder_layer._feedforward_layer_norm,
            f"{prefix}.final_layer_norm",
        )
        _port_attention(
            loader, encoder_layer._self_attention_layer, f"{prefix}.self_attn"
        )
        _port_ffn(loader, encoder_layer, prefix)

    # Decoder blocks (add a cross-attention layer and its norm).
    for index in range(backbone.num_layers):
        decoder_layer = backbone.decoder_transformer_layers[index]
        prefix = f"decoder.layers.{index}"
        _port_layer_norm(
            loader,
            decoder_layer._self_attention_layer_norm,
            f"{prefix}.self_attn_layer_norm",
        )
        _port_layer_norm(
            loader,
            decoder_layer._feedforward_layer_norm,
            f"{prefix}.final_layer_norm",
        )
        _port_layer_norm(
            loader,
            decoder_layer._cross_attention_layer_norm,
            f"{prefix}.encoder_attn_layer_norm",
        )
        _port_attention(
            loader, decoder_layer._self_attention_layer, f"{prefix}.self_attn"
        )
        _port_ffn(loader, decoder_layer, prefix)
        _port_attention(
            loader,
            decoder_layer._cross_attention_layer,
            f"{prefix}.encoder_attn",
        )

    # Embedding layer norms.
    _port_layer_norm(
        loader,
        backbone.encoder_embeddings_layer_norm,
        "encoder.layernorm_embedding",
    )
    _port_layer_norm(
        loader,
        backbone.decoder_embeddings_layer_norm,
        "decoder.layernorm_embedding",
    )
|
364
|
+
|
365
|
+
|
366
|
+
def convert_tokenizer(cls, preset, **kwargs):
    """Build a KerasHub BART tokenizer from a Hugging Face preset.

    Args:
        cls: the KerasHub tokenizer class to instantiate.
        preset: preset identifier/path passed through to `get_file`.
        **kwargs: forwarded to the tokenizer constructor.

    Returns:
        An instance of `cls` configured with the preset's BPE vocabulary
        and merge rules.
    """
    return cls(
        vocabulary=get_file(preset, "vocab.json"),
        merges=get_file(preset, "merges.txt"),
        **kwargs,
    )
|
@@ -13,12 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
import numpy as np
|
15
15
|
|
16
|
-
from keras_hub.src.
|
16
|
+
from keras_hub.src.models.bert.bert_backbone import BertBackbone
|
17
17
|
from keras_hub.src.utils.preset_utils import HF_TOKENIZER_CONFIG_FILE
|
18
18
|
from keras_hub.src.utils.preset_utils import get_file
|
19
|
-
from keras_hub.src.utils.preset_utils import
|
20
|
-
|
21
|
-
|
19
|
+
from keras_hub.src.utils.preset_utils import load_json
|
20
|
+
|
21
|
+
backbone_cls = BertBackbone
|
22
22
|
|
23
23
|
|
24
24
|
def convert_backbone_config(transformers_config):
|
@@ -154,20 +154,10 @@ def convert_weights(backbone, loader, transformers_config):
|
|
154
154
|
)
|
155
155
|
|
156
156
|
|
157
|
-
def
|
158
|
-
transformers_config =
|
159
|
-
keras_config = convert_backbone_config(transformers_config)
|
160
|
-
backbone = cls(**keras_config)
|
161
|
-
if load_weights:
|
162
|
-
jax_memory_cleanup(backbone)
|
163
|
-
with SafetensorLoader(preset) as loader:
|
164
|
-
convert_weights(backbone, loader, transformers_config)
|
165
|
-
return backbone
|
166
|
-
|
167
|
-
|
168
|
-
def load_bert_tokenizer(cls, preset):
|
169
|
-
transformers_config = load_config(preset, HF_TOKENIZER_CONFIG_FILE)
|
157
|
+
def convert_tokenizer(cls, preset, **kwargs):
|
158
|
+
transformers_config = load_json(preset, HF_TOKENIZER_CONFIG_FILE)
|
170
159
|
return cls(
|
171
160
|
get_file(preset, "vocab.txt"),
|
172
161
|
lowercase=transformers_config["do_lower_case"],
|
162
|
+
**kwargs,
|
173
163
|
)
|
@@ -13,12 +13,14 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
import numpy as np
|
15
15
|
|
16
|
-
from keras_hub.src.
|
16
|
+
from keras_hub.src.models.distil_bert.distil_bert_backbone import (
|
17
|
+
DistilBertBackbone,
|
18
|
+
)
|
17
19
|
from keras_hub.src.utils.preset_utils import HF_TOKENIZER_CONFIG_FILE
|
18
20
|
from keras_hub.src.utils.preset_utils import get_file
|
19
|
-
from keras_hub.src.utils.preset_utils import
|
20
|
-
|
21
|
-
|
21
|
+
from keras_hub.src.utils.preset_utils import load_json
|
22
|
+
|
23
|
+
backbone_cls = DistilBertBackbone
|
22
24
|
|
23
25
|
|
24
26
|
def convert_backbone_config(transformers_config):
|
@@ -33,7 +35,7 @@ def convert_backbone_config(transformers_config):
|
|
33
35
|
}
|
34
36
|
|
35
37
|
|
36
|
-
def convert_weights(backbone, loader):
|
38
|
+
def convert_weights(backbone, loader, transformers_config):
|
37
39
|
# Embeddings
|
38
40
|
loader.port_weight(
|
39
41
|
keras_variable=backbone.get_layer(
|
@@ -162,23 +164,11 @@ def convert_weights(backbone, loader):
|
|
162
164
|
hf_weight_key="distilbert.embeddings.LayerNorm.bias",
|
163
165
|
)
|
164
166
|
|
165
|
-
return backbone
|
166
|
-
|
167
|
-
|
168
|
-
def load_distilbert_backbone(cls, preset, load_weights):
|
169
|
-
transformers_config = load_config(preset, HF_CONFIG_FILE)
|
170
|
-
keras_config = convert_backbone_config(transformers_config)
|
171
|
-
backbone = cls(**keras_config)
|
172
|
-
if load_weights:
|
173
|
-
jax_memory_cleanup(backbone)
|
174
|
-
with SafetensorLoader(preset) as loader:
|
175
|
-
convert_weights(backbone, loader)
|
176
|
-
return backbone
|
177
|
-
|
178
167
|
|
179
|
-
def
|
180
|
-
transformers_config =
|
168
|
+
def convert_tokenizer(cls, preset, **kwargs):
|
169
|
+
transformers_config = load_json(preset, HF_TOKENIZER_CONFIG_FILE)
|
181
170
|
return cls(
|
182
171
|
get_file(preset, "vocab.txt"),
|
183
172
|
lowercase=transformers_config["do_lower_case"],
|
173
|
+
**kwargs,
|
184
174
|
)
|
@@ -13,11 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
import numpy as np
|
15
15
|
|
16
|
-
from keras_hub.src.
|
16
|
+
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
|
17
17
|
from keras_hub.src.utils.preset_utils import get_file
|
18
|
-
|
19
|
-
|
20
|
-
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
18
|
+
|
19
|
+
backbone_cls = GemmaBackbone
|
21
20
|
|
22
21
|
|
23
22
|
def convert_backbone_config(transformers_config):
|
@@ -169,19 +168,6 @@ def convert_weights(backbone, loader, transformers_config):
|
|
169
168
|
hf_weight_key="model.norm.weight",
|
170
169
|
)
|
171
170
|
|
172
|
-
return backbone
|
173
|
-
|
174
|
-
|
175
|
-
def load_gemma_backbone(cls, preset, load_weights):
|
176
|
-
transformers_config = load_config(preset, HF_CONFIG_FILE)
|
177
|
-
keras_config = convert_backbone_config(transformers_config)
|
178
|
-
backbone = cls(**keras_config)
|
179
|
-
if load_weights:
|
180
|
-
jax_memory_cleanup(backbone)
|
181
|
-
with SafetensorLoader(preset) as loader:
|
182
|
-
convert_weights(backbone, loader, transformers_config)
|
183
|
-
return backbone
|
184
|
-
|
185
171
|
|
186
|
-
def
|
187
|
-
return cls(get_file(preset, "tokenizer.model"))
|
172
|
+
def convert_tokenizer(cls, preset, **kwargs):
|
173
|
+
return cls(get_file(preset, "tokenizer.model"), **kwargs)
|
@@ -13,11 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
import numpy as np
|
15
15
|
|
16
|
-
from keras_hub.src.
|
16
|
+
from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
|
17
17
|
from keras_hub.src.utils.preset_utils import get_file
|
18
|
-
|
19
|
-
|
20
|
-
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
18
|
+
|
19
|
+
backbone_cls = GPT2Backbone
|
21
20
|
|
22
21
|
|
23
22
|
def convert_backbone_config(transformers_config):
|
@@ -163,24 +162,12 @@ def convert_weights(backbone, loader, transformers_config):
|
|
163
162
|
hf_weight_key="ln_f.bias",
|
164
163
|
)
|
165
164
|
|
166
|
-
return backbone
|
167
|
-
|
168
|
-
|
169
|
-
def load_gpt2_backbone(cls, preset, load_weights):
|
170
|
-
transformers_config = load_config(preset, HF_CONFIG_FILE)
|
171
|
-
keras_config = convert_backbone_config(transformers_config)
|
172
|
-
backbone = cls(**keras_config)
|
173
|
-
if load_weights:
|
174
|
-
jax_memory_cleanup(backbone)
|
175
|
-
with SafetensorLoader(preset) as loader:
|
176
|
-
convert_weights(backbone, loader, transformers_config)
|
177
|
-
return backbone
|
178
|
-
|
179
165
|
|
180
|
-
def
|
166
|
+
def convert_tokenizer(cls, preset, **kwargs):
|
181
167
|
vocab_file = get_file(preset, "vocab.json")
|
182
168
|
merges_file = get_file(preset, "merges.txt")
|
183
169
|
return cls(
|
184
170
|
vocabulary=vocab_file,
|
185
171
|
merges=merges_file,
|
172
|
+
**kwargs,
|
186
173
|
)
|
@@ -13,10 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
import numpy as np
|
15
15
|
|
16
|
-
from keras_hub.src.
|
17
|
-
from keras_hub.src.utils.preset_utils import
|
18
|
-
|
19
|
-
|
16
|
+
from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
|
17
|
+
from keras_hub.src.utils.preset_utils import load_json
|
18
|
+
|
19
|
+
backbone_cls = Llama3Backbone
|
20
20
|
|
21
21
|
|
22
22
|
def convert_backbone_config(transformers_config):
|
@@ -111,19 +111,8 @@ def convert_weights(backbone, loader, transformers_config):
|
|
111
111
|
return backbone
|
112
112
|
|
113
113
|
|
114
|
-
def
|
115
|
-
|
116
|
-
keras_config = convert_backbone_config(transformers_config)
|
117
|
-
backbone = cls(**keras_config)
|
118
|
-
if load_weights:
|
119
|
-
jax_memory_cleanup(backbone)
|
120
|
-
with SafetensorLoader(preset) as loader:
|
121
|
-
convert_weights(backbone, loader, transformers_config)
|
122
|
-
return backbone
|
123
|
-
|
124
|
-
|
125
|
-
def load_llama3_tokenizer(cls, preset):
|
126
|
-
tokenizer_config = load_config(preset, "tokenizer.json")
|
114
|
+
def convert_tokenizer(cls, preset, **kwargs):
|
115
|
+
tokenizer_config = load_json(preset, "tokenizer.json")
|
127
116
|
vocab = tokenizer_config["model"]["vocab"]
|
128
117
|
merges = tokenizer_config["model"]["merges"]
|
129
118
|
|
@@ -133,4 +122,4 @@ def load_llama3_tokenizer(cls, preset):
|
|
133
122
|
vocab[bot["content"]] = bot["id"]
|
134
123
|
vocab[eot["content"]] = eot["id"]
|
135
124
|
|
136
|
-
return cls(vocabulary=vocab, merges=merges)
|
125
|
+
return cls(vocabulary=vocab, merges=merges, **kwargs)
|