keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +15 -0
- keras_hub/models/__init__.py +93 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +28 -16
- keras_hub/src/models/causal_lm.py +37 -0
- keras_hub/src/models/causal_lm_preprocessor.py +14 -0
- keras_hub/src/models/clip/clip_presets.py +8 -8
- keras_hub/src/models/d_fine/__init__.py +5 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
- keras_hub/src/models/depth_anything/__init__.py +9 -0
- keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
- keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
- keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
- keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
- keras_hub/src/models/depth_anything/interpolate.py +62 -0
- keras_hub/src/models/depth_estimator.py +239 -0
- keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
- keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_backbone.py +0 -1
- keras_hub/src/models/gemma/gemma_presets.py +30 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/samplers/beam_sampler.py +6 -6
- keras_hub/src/samplers/sampler.py +8 -6
- keras_hub/src/tests/test_case.py +40 -3
- keras_hub/src/tokenizers/tokenizer.py +15 -0
- keras_hub/src/utils/openvino_utils.py +141 -0
- keras_hub/src/utils/preset_utils.py +58 -2
- keras_hub/src/utils/tensor_utils.py +26 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/utils/transformers/export/gemma.py +49 -4
- keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +15 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ keras_hub/src/utils/transformers/convert_qwen3_moe.py
@@ -0,0 +1,216 @@
+import numpy as np
+
+from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone
+from keras_hub.src.utils.preset_utils import load_json
+
+backbone_cls = Qwen3MoeBackbone
+
+
+def convert_backbone_config(transformers_config):
+    return {
+        "vocabulary_size": transformers_config["vocab_size"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "head_dim": transformers_config["head_dim"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "num_query_heads": transformers_config["num_attention_heads"],
+        "num_key_value_heads": transformers_config["num_key_value_heads"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "moe_intermediate_dim": transformers_config["moe_intermediate_size"],
+        "num_experts": transformers_config["num_experts"],
+        "top_k": transformers_config["num_experts_per_tok"],
+        "norm_top_k_prob": transformers_config["norm_topk_prob"],
+        "decoder_sparse_step": transformers_config["decoder_sparse_step"],
+        "layer_norm_epsilon": transformers_config["rms_norm_eps"],
+        "rope_max_wavelength": transformers_config["rope_theta"],
+        "sliding_window_size": transformers_config["sliding_window"],
+        "router_aux_loss_coefficient": transformers_config[
+            "router_aux_loss_coef"
+        ],
+        "tie_word_embeddings": transformers_config.get(
+            "tie_word_embeddings", False
+        ),
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    loader.port_weight(
+        keras_variable=backbone.get_layer("token_embedding").embeddings,
+        hf_weight_key="model.embed_tokens.weight",
+    )
+    if not backbone.tie_word_embeddings:
+        loader.port_weight(
+            keras_variable=backbone.get_layer(
+                "token_embedding"
+            ).reverse_embeddings,
+            hf_weight_key="lm_head.weight",
+            # rearrange_pattern="b a -> a b",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
+
+    def transpose_and_reshape(x, shape):
+        return np.reshape(np.transpose(x), shape)
+
+    for i in range(backbone.num_layers):
+        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")
+
+        # Input layernorm
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layernorm.scale,
+            hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
+        )
+
+        # Attention layers
+
+        ## Query
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._query_dense_layer_norm.scale,
+            hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight",
+        )
+        ## Key
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._key_dense_layer_norm.scale,
+            hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight",
+        )
+        ## Value
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._value_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Output
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._output_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
+            # rearrange_patterns="c (a b) -> a b c",
+            # rearrange_dims={"a": backbone.num_query_heads},
+            hook_fn=transpose_and_reshape,
+        )
+
+        # MLP layers
+        if (
+            (i not in backbone.mlp_only_layers)
+            and backbone.num_experts > 0
+            and ((i + 1) % backbone.decoder_sparse_step == 0)
+        ):
+            # MoE layers
+            loader.port_weight(
+                keras_variable=decoder_layer.mlp._sparse_feedforward_gate_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.gate.weight",
+                # rearrange_patterns="b a -> a b",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            # Batched experts: gate_up_proj and down_proj
+            gate_up_proj_list = []
+            down_proj_list = []
+            for expert_idx in range(backbone.num_experts):
+                # Load gate_proj and up_proj for each expert
+                gate_proj = loader.get_tensor(
+                    f"model.layers.{i}.mlp.experts.{expert_idx}.gate_proj.weight"
+                )
+                up_proj = loader.get_tensor(
+                    f"model.layers.{i}.mlp.experts.{expert_idx}.up_proj.weight"
+                )
+                # Transpose to (hidden_dim, intermediate_dim)
+                gate_proj = np.transpose(gate_proj, axes=(1, 0))
+                up_proj = np.transpose(up_proj, axes=(1, 0))
+                # Concatenate gate_proj and up_proj along the last dimension
+                gate_up_proj = np.concatenate([gate_proj, up_proj], axis=-1)
+                gate_up_proj_list.append(gate_up_proj)
+
+                # Load down_proj for each expert
+                down_proj = loader.get_tensor(
+                    f"model.layers.{i}.mlp.experts.{expert_idx}.down_proj.weight"
+                )
+                down_proj = np.transpose(
+                    down_proj, axes=(1, 0)
+                )  # (intermediate_dim, hidden_dim)
+                down_proj_list.append(down_proj)
+
+            # Stack the lists to create batched weights
+            gate_up_proj_batched = np.stack(
+                gate_up_proj_list, axis=0
+            )  # (num_experts, hidden_dim, 2 * intermediate_dim)
+            down_proj_batched = np.stack(
+                down_proj_list, axis=0
+            )  # (num_experts, intermediate_dim, hidden_dim)
+
+            # Assign batched weights to expert_bank
+            decoder_layer.mlp.expert_bank._expert_feedforward_gate_dense.assign(
+                gate_up_proj_batched
+            )
+            decoder_layer.mlp.expert_bank._expert_feedforward_output_dense.assign(
+                down_proj_batched
+            )
+        else:
+            loader.port_weight(
+                keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight",
+                # rearrange_patterns="b a -> a b",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            loader.port_weight(
+                keras_variable=decoder_layer._feedforward_output_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight",
+                # rearrange_patterns="b a -> a b",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            loader.port_weight(
+                keras_variable=decoder_layer._feedforward_gate_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight",
+                # rearrange_patterns="b a -> a b",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+
+        # Feedforward layernorm
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_layernorm.scale,
+            hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
+        )
+
+    # Final normalization layer
+    loader.port_weight(
+        keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
+        hf_weight_key="model.norm.weight",
+    )
+
+    return backbone
+
+
+def convert_tokenizer(cls, preset, **kwargs):
+    tokenizer_config = load_json(preset, "tokenizer.json")
+    vocab = tokenizer_config["model"]["vocab"]
+    merges = tokenizer_config["model"]["merges"]
+    merges = [" ".join(item) for item in merges]
+
+    # Load all special tokens with the exception of "reserved" ones.
+    special_tokens = set()
+    for token in tokenizer_config["added_tokens"]:
+        if not token["content"].startswith("<|reserved_special_token_"):
+            vocab[token["content"]] = token["id"]
+            special_tokens.add(token["content"])
+
+    kwargs.update(
+        {
+            "unsplittable_tokens": list(special_tokens),
+        }
+    )
+
+    return cls(vocabulary=vocab, merges=merges, **kwargs)
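For context, converter modules like convert_qwen3_moe.py are not imported directly by user code; the Transformers preset loader (keras_hub/src/utils/transformers/preset_loader.py, also updated in this release) dispatches to them when a Hugging Face checkpoint handle is passed to from_preset. A minimal usage sketch, assuming a top-level keras_hub.models.Qwen3MoeBackbone export; the "hf://Qwen/..." handle below is illustrative, not a verified preset name:

import keras_hub

# Assumed flow: an "hf://" handle routes through the Transformers preset
# loader, which calls convert_backbone_config() on the checkpoint's
# config.json and convert_weights() on its weight shards.
backbone = keras_hub.models.Qwen3MoeBackbone.from_preset(
    "hf://Qwen/Qwen3-30B-A3B"  # illustrative checkpoint handle
)
print(backbone.num_layers, backbone.num_experts)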
--- /dev/null
+++ keras_hub/src/utils/transformers/convert_smollm3.py
@@ -0,0 +1,139 @@
+import numpy as np
+
+from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone
+from keras_hub.src.utils.preset_utils import load_json
+
+backbone_cls = SmolLM3Backbone
+
+
+def convert_backbone_config(transformers_config):
+    return {
+        "vocabulary_size": transformers_config["vocab_size"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "num_attention_heads": transformers_config["num_attention_heads"],
+        "num_key_value_heads": transformers_config["num_key_value_heads"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "layer_norm_epsilon": transformers_config[
+            "rms_norm_eps"
+        ],  # Using rms_norm_eps as layer_norm_epsilon
+        "max_position_embeddings": transformers_config[
+            "max_position_embeddings"
+        ],
+        "rope_theta": transformers_config["rope_theta"],
+        # partial_rotary_factor is not explicitly in config.json
+        # but is inherited from the default value in the
+        # `_compute_default_rope_parameters()` function
+        "partial_rotary_factor": 1.0,
+        "attention_bias": transformers_config["attention_bias"],
+        "attention_dropout": transformers_config["attention_dropout"],
+        # Despite the name, no_rope_layers: 1 = HAS RoPE, 0 = NO RoPE
+        "rope_layer_enabled_list": [
+            bool(x) for x in transformers_config["no_rope_layers"]
+        ],
+        "layer_types": transformers_config["layer_types"],
+        "mlp_bias": transformers_config["mlp_bias"],
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    loader.port_weight(
+        keras_variable=backbone.get_layer("token_embedding").embeddings,
+        hf_weight_key="model.embed_tokens.weight",
+    )
+
+    def transpose_and_reshape(x, shape):
+        return np.reshape(np.transpose(x), shape)
+
+    for i in range(backbone.num_layers):
+        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")
+
+        # Input layernorm
+        loader.port_weight(
+            keras_variable=decoder_layer.input_layernorm.scale,
+            hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
+        )
+
+        # Attention layers
+        ## Query
+        loader.port_weight(
+            keras_variable=decoder_layer.self_attn.q_proj.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Key
+        loader.port_weight(
+            keras_variable=decoder_layer.self_attn.k_proj.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Value
+        loader.port_weight(
+            keras_variable=decoder_layer.self_attn.v_proj.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Output
+        loader.port_weight(
+            keras_variable=decoder_layer.self_attn.o_proj.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+
+        # MLP layers
+        loader.port_weight(
+            keras_variable=decoder_layer.mlp.up_proj.kernel,
+            hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight",
+            # rearrange_patterns="b a -> a b",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer.mlp.down_proj.kernel,
+            hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight",
+            # rearrange_patterns="b a -> a b",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer.mlp.gate_proj.kernel,
+            hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight",
+            # rearrange_patterns="b a -> a b",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
+
+        # Feedforward layernorm
+        loader.port_weight(
+            keras_variable=decoder_layer.post_attention_layernorm.scale,
+            hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
+        )
+
+    # Final normalization layer
+    loader.port_weight(
+        keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
+        hf_weight_key="model.norm.weight",
+    )
+
+    backbone.training = False
+
+    return backbone
+
+
+def convert_tokenizer(cls, preset, **kwargs):
+    tokenizer_config = load_json(preset, "tokenizer.json")
+    vocab = tokenizer_config["model"]["vocab"]
+    merges = tokenizer_config["model"]["merges"]
+    merges = [" ".join(item) for item in merges]
+
+    # Load all special tokens with the exception of "reserved" ones.
+    special_tokens = set()
+    for token in tokenizer_config["added_tokens"]:
+        if not token["content"].startswith("<|reserved_special_token_"):
+            vocab[token["content"]] = token["id"]
+            special_tokens.add(token["content"])
+
+    kwargs.update(
+        {
+            "unsplittable_tokens": list(special_tokens),
+        }
+    )
+
+    return cls(vocabulary=vocab, merges=merges, **kwargs)
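Both converters above lean on the same layout fix: a torch nn.Linear weight is stored as (out_features, in_features), while the target Keras kernels expect (in_features, out_features), and the attention projections are additionally reshaped to a per-head kernel shape. A standalone numpy sketch of what transpose_and_reshape and the lambda hook_fns do, with purely illustrative dimensions:

import numpy as np

hidden_dim, num_heads, head_dim = 2048, 16, 128  # illustrative sizes

# Hugging Face stores q_proj.weight as (num_heads * head_dim, hidden_dim).
hf_q = np.zeros((num_heads * head_dim, hidden_dim), dtype="float32")

# transpose_and_reshape: transpose to (hidden_dim, num_heads * head_dim),
# then reshape to the per-head kernel shape of the Keras attention layer.
keras_q = np.reshape(np.transpose(hf_q), (hidden_dim, num_heads, head_dim))
assert keras_q.shape == (2048, 16, 128)

# MLP kernels only need the plain transpose performed by the lambda hooks.
hf_up = np.zeros((4 * hidden_dim, hidden_dim), dtype="float32")
keras_up = np.transpose(hf_up, axes=(1, 0))
assert keras_up.shape == (hidden_dim, 4 * hidden_dim)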
--- /dev/null
+++ keras_hub/src/utils/transformers/convert_t5gemma.py
@@ -0,0 +1,229 @@
+from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
+from keras_hub.src.utils.preset_utils import get_file
+
+backbone_cls = T5GemmaBackbone
+
+
+def convert_backbone_config(transformers_config):
+    """Convert a Hugging Face T5Gemma config to a KerasHub backbone config."""
+    encoder_config = transformers_config["encoder"]
+    decoder_config = transformers_config["decoder"]
+
+    if decoder_config.get("hidden_activation") == "gelu_pytorch_tanh":
+        decoder_config["hidden_activation"] = "gelu_approximate"
+    if encoder_config.get("hidden_activation") == "gelu_pytorch_tanh":
+        encoder_config["hidden_activation"] = "gelu_approximate"
+
+    backbone_config = {
+        "vocabulary_size": decoder_config["vocab_size"],
+        "encoder_hidden_dim": encoder_config["hidden_size"],
+        "encoder_intermediate_dim": encoder_config["intermediate_size"],
+        "encoder_num_layers": encoder_config["num_hidden_layers"],
+        "encoder_num_attention_heads": encoder_config["num_attention_heads"],
+        "encoder_num_key_value_heads": encoder_config["num_key_value_heads"],
+        "encoder_head_dim": encoder_config["head_dim"],
+        "encoder_layer_types": encoder_config["layer_types"],
+        "decoder_hidden_dim": decoder_config["hidden_size"],
+        "decoder_intermediate_dim": decoder_config["intermediate_size"],
+        "decoder_num_layers": decoder_config["num_hidden_layers"],
+        "decoder_num_attention_heads": decoder_config["num_attention_heads"],
+        "decoder_num_key_value_heads": decoder_config["num_key_value_heads"],
+        "decoder_head_dim": decoder_config["head_dim"],
+        "decoder_layer_types": decoder_config["layer_types"],
+        "dropout_rate": decoder_config["dropout_rate"],
+        "rms_norm_eps": decoder_config["rms_norm_eps"],
+        "query_pre_attn_scalar": decoder_config["query_pre_attn_scalar"],
+        "tie_word_embeddings": transformers_config.get(
+            "tie_word_embeddings", True
+        ),
+        "attention_bias": decoder_config["attention_bias"],
+        "hidden_activation": decoder_config["hidden_activation"],
+        "initializer_range": decoder_config["initializer_range"],
+        "attention_dropout": decoder_config["attention_dropout"],
+        "sliding_window": decoder_config["sliding_window"],
+        "cross_attention_hidden_size": encoder_config["hidden_size"],
+        "attn_logit_softcapping": decoder_config["attn_logit_softcapping"],
+        "final_logit_softcapping": decoder_config["final_logit_softcapping"],
+        "rope_max_wavelength": decoder_config["rope_theta"],
+    }
+    return backbone_config
+
+
+def convert_weights(backbone, loader, transformers_config):
+    """Convert T5Gemma from Hugging Face to KerasHub."""
+    # Token embeddings.
+    loader.port_weight(
+        keras_variable=backbone.token_embedding.embeddings,
+        hf_weight_key="encoder.embed_tokens.weight",
+    )
+    loader.port_weight(
+        keras_variable=backbone.decoder_token_embedding.embeddings,
+        hf_weight_key="decoder.embed_tokens.weight",
+    )
+
+    # Encoder.
+    loader.port_weight(
+        keras_variable=backbone.encoder_norm.scale,
+        hf_weight_key="encoder.norm.weight",
+    )
+    for i in range(backbone.encoder_num_layers):
+        layer = backbone.get_layer(f"encoder_layer_{i}")
+        hf_prefix = f"encoder.layers.{i}"
+
+        # Self-attention.
+        loader.port_weight(
+            keras_variable=layer.self_attn.query_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+        loader.port_weight(
+            keras_variable=layer.self_attn.key_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+        loader.port_weight(
+            keras_variable=layer.self_attn.value_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+        loader.port_weight(
+            keras_variable=layer.self_attn.output_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+
+        # MLP.
+        loader.port_weight(
+            keras_variable=layer.mlp.gate_proj.kernel,
+            hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight",
+            hook_fn=lambda w, s: w.T,
+        )
+        loader.port_weight(
+            keras_variable=layer.mlp.up_proj.kernel,
+            hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight",
+            hook_fn=lambda w, s: w.T,
+        )
+        loader.port_weight(
+            keras_variable=layer.mlp.down_proj.kernel,
+            hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight",
+            hook_fn=lambda w, s: w.T,
+        )
+
+        # Layer norm.
+        loader.port_weight(
+            keras_variable=layer.pre_self_attn_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.pre_self_attn_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=layer.post_self_attn_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.post_self_attn_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=layer.pre_feedforward_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.pre_feedforward_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=layer.post_feedforward_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.post_feedforward_layernorm.weight",
+        )
+
+    # Decoder.
+    loader.port_weight(
+        keras_variable=backbone.decoder_norm.scale,
+        hf_weight_key="decoder.norm.weight",
+    )
+    for i in range(backbone.decoder_num_layers):
+        layer = backbone.get_layer(f"decoder_layer_{i}")
+        hf_prefix = f"decoder.layers.{i}"
+
+        # Self-attention.
+        loader.port_weight(
+            keras_variable=layer.self_attn.query_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.self_attn.q_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+        loader.port_weight(
+            keras_variable=layer.self_attn.key_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.self_attn.k_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+        loader.port_weight(
+            keras_variable=layer.self_attn.value_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.self_attn.v_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+        loader.port_weight(
+            keras_variable=layer.self_attn.output_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.self_attn.o_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+
+        # Cross-attention.
+        loader.port_weight(
+            keras_variable=layer.cross_attn.query_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.cross_attn.q_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+        loader.port_weight(
+            keras_variable=layer.cross_attn.key_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.cross_attn.k_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+        loader.port_weight(
+            keras_variable=layer.cross_attn.value_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.cross_attn.v_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+        loader.port_weight(
+            keras_variable=layer.cross_attn.output_dense.kernel,
+            hf_weight_key=f"{hf_prefix}.cross_attn.o_proj.weight",
+            hook_fn=lambda w, s: w.T.reshape(s),
+        )
+
+        # MLP.
+        loader.port_weight(
+            keras_variable=layer.mlp.gate_proj.kernel,
+            hf_weight_key=f"{hf_prefix}.mlp.gate_proj.weight",
+            hook_fn=lambda w, s: w.T,
+        )
+        loader.port_weight(
+            keras_variable=layer.mlp.up_proj.kernel,
+            hf_weight_key=f"{hf_prefix}.mlp.up_proj.weight",
+            hook_fn=lambda w, s: w.T,
+        )
+        loader.port_weight(
+            keras_variable=layer.mlp.down_proj.kernel,
+            hf_weight_key=f"{hf_prefix}.mlp.down_proj.weight",
+            hook_fn=lambda w, s: w.T,
+        )
+
+        # Layer norm.
+        loader.port_weight(
+            keras_variable=layer.pre_self_attn_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.pre_self_attn_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=layer.post_self_attn_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.post_self_attn_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=layer.pre_cross_attn_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.pre_cross_attn_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=layer.post_cross_attn_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.post_cross_attn_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=layer.pre_feedforward_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.pre_feedforward_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=layer.post_feedforward_layernorm.scale,
+            hf_weight_key=f"{hf_prefix}.post_feedforward_layernorm.weight",
+        )
+
+
+def convert_tokenizer(cls, preset, **kwargs):
+    """Convert a T5Gemma tokenizer."""
+    return cls(get_file(preset, "tokenizer.model"), **kwargs)
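Unlike the decoder-only converters, convert_backbone_config here reads a nested T5Gemma layout in which the Hugging Face config.json carries separate "encoder" and "decoder" sub-configs. A trimmed, illustrative example of the structure being consumed; all values are placeholders, not taken from a real checkpoint:

transformers_config = {
    "tie_word_embeddings": True,
    "encoder": {
        "vocab_size": 256000,          # placeholder values throughout
        "hidden_size": 1024,
        "intermediate_size": 4096,
        "num_hidden_layers": 12,
        "num_attention_heads": 8,
        "num_key_value_heads": 4,
        "head_dim": 256,
        "layer_types": [],             # per-layer attention types
        "hidden_activation": "gelu_pytorch_tanh",
    },
    "decoder": {
        # mirrors the encoder keys, plus the decoder-only ones read above,
        # e.g. "dropout_rate", "rms_norm_eps", "query_pre_attn_scalar",
        # "sliding_window", "attn_logit_softcapping",
        # "final_logit_softcapping", "rope_theta".
    },
}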
--- keras_hub/src/utils/transformers/convert_vit.py
+++ keras_hub/src/utils/transformers/convert_vit.py
@@ -9,7 +9,10 @@ def convert_backbone_config(transformers_config):
     image_size = transformers_config["image_size"]
     return {
         "image_shape": (image_size, image_size, 3),
-        "patch_size": transformers_config["patch_size"],
+        "patch_size": (
+            transformers_config["patch_size"],
+            transformers_config["patch_size"],
+        ),
         "num_layers": transformers_config["num_hidden_layers"],
         "num_heads": transformers_config["num_attention_heads"],
         "hidden_dim": transformers_config["hidden_size"],
--- keras_hub/src/utils/transformers/export/gemma.py
+++ keras_hub/src/utils/transformers/export/gemma.py
@@ -2,6 +2,7 @@ import keras.ops as ops
 
 
 def get_gemma_config(backbone):
+    token_embedding_layer = backbone.get_layer("token_embedding")
     hf_config = {
         "vocab_size": backbone.vocabulary_size,
         "num_hidden_layers": backbone.num_layers,
@@ -11,11 +12,16 @@ def get_gemma_config(backbone):
         "intermediate_size": backbone.intermediate_dim // 2,
         "head_dim": backbone.head_dim,
         "max_position_embeddings": 8192,
+        "tie_word_embeddings": token_embedding_layer.tie_weights,
+        "pad_token_id": 0,
+        "bos_token_id": 2,
+        "eos_token_id": 1,
+        "model_type": "gemma",
     }
     return hf_config
 
 
-def get_gemma_weights_map(backbone):
+def get_gemma_weights_map(backbone, include_lm_head=False):
     weights_dict = {}
 
     # Map token embedding
@@ -83,7 +89,46 @@ def get_gemma_weights_map(backbone):
         "final_normalization"
     ).weights[0]
 
-    #
-
-
+    # Map lm_head if embeddings are not tied
+    if include_lm_head and not token_embedding_layer.tie_weights:
+        weights_dict["lm_head.weight"] = ops.transpose(
+            token_embedding_layer.reverse_embeddings
+        )
     return weights_dict
+
+
+def get_gemma_tokenizer_config(tokenizer):
+    tokenizer_config = {
+        "tokenizer_class": "GemmaTokenizer",
+        "clean_up_tokenization_spaces": False,
+        "bos_token": "<bos>",
+        "eos_token": "<eos>",
+        "pad_token": "<pad>",
+        "unk_token": "<unk>",
+        "add_bos_token": True,
+        "add_eos_token": False,
+        "model_max_length": 8192,
+    }
+    # Add added_tokens_decoder
+    added_tokens_decoder = {}
+    special_tokens = [
+        "<pad>",
+        "<bos>",
+        "<eos>",
+        "<unk>",
+        "<start_of_turn>",
+        "<end_of_turn>",
+    ]
+    for token in special_tokens:
+        token_id = tokenizer.token_to_id(token)
+        if token_id is not None:
+            added_tokens_decoder[str(token_id)] = {
+                "content": token,
+                "special": True,
+                "single_word": False,
+                "lstrip": False,
+                "rstrip": False,
+                "normalized": False,
+            }
+    tokenizer_config["added_tokens_decoder"] = added_tokens_decoder
+    return tokenizer_config
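The new get_gemma_tokenizer_config() builds the fields Hugging Face tooling expects in tokenizer_config.json; the accompanying exporter changes (hf_exporter.py, +71 -25 in the file list) are presumably where that dict gets written to disk. A hedged sketch of that serialization step; the helper name and output path are illustrative, not the exporter's actual API:

import json

def write_tokenizer_config(tokenizer, export_dir):
    # get_gemma_tokenizer_config() is defined in export/gemma.py above;
    # wiring it to a file like this is an assumption about hf_exporter.py.
    config = get_gemma_tokenizer_config(tokenizer)
    with open(f"{export_dir}/tokenizer_config.json", "w") as f:
        json.dump(config, f, indent=2)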