keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/utils/transformers/convert_sam3.py (new file)
@@ -0,0 +1,472 @@
import warnings

import numpy as np
from keras import layers

from keras_hub.src.models.sam3.sam3_detr_decoder import SAM3DetrDecoder
from keras_hub.src.models.sam3.sam3_detr_encoder import SAM3DetrEncoder
from keras_hub.src.models.sam3.sam3_geometry_encoder import SAM3GeometryEncoder
from keras_hub.src.models.sam3.sam3_mask_decoder import SAM3MaskDecoder
from keras_hub.src.models.sam3.sam3_pc_backbone import (
    SAM3PromptableConceptBackbone,
)
from keras_hub.src.models.sam3.sam3_text_encoder import SAM3TextEncoder
from keras_hub.src.models.sam3.sam3_vision_encoder import SAM3VisionEncoder
from keras_hub.src.utils.preset_utils import load_json

backbone_cls = SAM3PromptableConceptBackbone


def convert_backbone_config(transformers_config, cls, **kwargs):
    # detector_config: Promptable Concept Segmentation (PCS)
    # tracker_config: Promptable Visual Segmentation (PVS)
    if issubclass(cls, SAM3PromptableConceptBackbone):
        # Extract sub-configurations.
        transformers_config = transformers_config["detector_config"]

        vision_config = transformers_config["vision_config"]
        backbone_config = vision_config["backbone_config"]
        text_config = transformers_config["text_config"]
        geom_config = transformers_config["geometry_encoder_config"]
        detr_enc_config = transformers_config["detr_encoder_config"]
        detr_dec_config = transformers_config["detr_decoder_config"]
        mask_dec_config = transformers_config["mask_decoder_config"]
        dtype = kwargs.pop("dtype", None)
        image_shape = kwargs.pop("image_shape", None)
        if image_shape is None:
            image_shape = (
                backbone_config["image_size"],
                backbone_config["image_size"],
                3,
            )

        # Vision Encoder.
        vision_encoder_config = {
            "image_shape": image_shape,
            "patch_size": backbone_config["patch_size"],
            "num_layers": backbone_config["num_hidden_layers"],
            "hidden_dim": backbone_config["hidden_size"],
            "intermediate_dim": backbone_config["intermediate_size"],
            "num_heads": backbone_config["num_attention_heads"],
            "fpn_hidden_dim": vision_config["fpn_hidden_size"],
            "fpn_scale_factors": vision_config["scale_factors"],
            "pretrain_image_shape": (
                backbone_config["pretrain_image_size"],
                backbone_config["pretrain_image_size"],
                3,
            ),
            "hidden_activation": backbone_config["hidden_act"],
            "rope_theta": backbone_config["rope_theta"],
            "window_size": backbone_config["window_size"],
            "global_attn_indexes": backbone_config["global_attn_indexes"],
            "attention_dropout_rate": backbone_config["attention_dropout"],
            "hidden_dropout_rate": backbone_config["hidden_dropout"],
            "layer_norm_epsilon": backbone_config["layer_norm_eps"],
            "dtype": dtype,
        }
        vision_encoder = SAM3VisionEncoder(**vision_encoder_config)

        # Text Encoder.
        text_encoder_config = {
            "vocabulary_size": text_config["vocab_size"],
            "embedding_dim": text_config["hidden_size"],
            "hidden_dim": text_config["hidden_size"],
            "num_layers": text_config["num_hidden_layers"],
            "num_heads": text_config["num_attention_heads"],
            "intermediate_dim": text_config["intermediate_size"],
            "intermediate_activation": text_config["hidden_act"],
            "max_sequence_length": text_config["max_position_embeddings"],
            "layer_norm_epsilon": text_config["layer_norm_eps"],
            "dtype": dtype,
        }
        text_encoder = SAM3TextEncoder(**text_encoder_config)

        # Geometry Encoder.
        geometry_encoder_config = {
            "num_layers": geom_config["num_layers"],
            "hidden_dim": geom_config["hidden_size"],
            "intermediate_dim": geom_config["intermediate_size"],
            "num_heads": geom_config["num_attention_heads"],
            "roi_size": geom_config["roi_size"],
            "hidden_activation": geom_config["hidden_act"],
            "dropout_rate": geom_config["hidden_dropout"],
            "layer_norm_epsilon": geom_config["layer_norm_eps"],
            "dtype": dtype,
        }
        geometry_encoder = SAM3GeometryEncoder(**geometry_encoder_config)

        # DETR Encoder.
        detr_encoder_config = {
            "num_layers": detr_enc_config["num_layers"],
            "hidden_dim": detr_enc_config["hidden_size"],
            "intermediate_dim": detr_enc_config["intermediate_size"],
            "num_heads": detr_enc_config["num_attention_heads"],
            "hidden_activation": detr_enc_config["hidden_act"],
            "dropout_rate": detr_enc_config["dropout"],
            "layer_norm_epsilon": detr_enc_config["layer_norm_eps"],
            "dtype": dtype,
        }
        detr_encoder = SAM3DetrEncoder(**detr_encoder_config)

        # DETR Decoder.
        detr_decoder_config = {
            "image_shape": image_shape,
            "patch_size": backbone_config["patch_size"],
            "num_layers": detr_dec_config["num_layers"],
            "hidden_dim": detr_dec_config["hidden_size"],
            "intermediate_dim": detr_dec_config["intermediate_size"],
            "num_heads": detr_dec_config["num_attention_heads"],
            "num_queries": detr_dec_config["num_queries"],
            "hidden_activation": detr_dec_config["hidden_act"],
            "dropout_rate": detr_dec_config["dropout"],
            "layer_norm_epsilon": detr_dec_config["layer_norm_eps"],
            "dtype": dtype,
        }
        detr_decoder = SAM3DetrDecoder(**detr_decoder_config)

        # Mask Decoder.
        mask_decoder_config = {
            "num_upsampling_stages": mask_dec_config["num_upsampling_stages"],
            "hidden_dim": mask_dec_config["hidden_size"],
            "num_heads": mask_dec_config["num_attention_heads"],
            "dropout_rate": 0.0,
            "layer_norm_epsilon": mask_dec_config["layer_norm_eps"],
            "dtype": dtype,
        }
        mask_decoder = SAM3MaskDecoder(**mask_decoder_config)

        return {
            "vision_encoder": vision_encoder,
            "text_encoder": text_encoder,
            "geometry_encoder": geometry_encoder,
            "detr_encoder": detr_encoder,
            "detr_decoder": detr_decoder,
            "mask_decoder": mask_decoder,
        }
    else:
        # TODO: Add SAM3Tracker support.
        raise ValueError(
            "The provided class is not a subclass of "
            f"SAM3PromptableConceptBackbone. Received: {cls}"
        )


def convert_weights(backbone, loader, transformers_config):
    if not isinstance(backbone, SAM3PromptableConceptBackbone):
        raise ValueError(
            "The provided backbone must be an instance of "
            f"SAM3PromptableConceptBackbone. Received: {type(backbone)}"
        )

    def port_dense(keras_dense, hf_name):
        loader.port_weight(
            keras_dense.kernel, f"{hf_name}.weight", hook_fn=lambda x, _: x.T
        )
        if keras_dense.bias is not None:
            loader.port_weight(keras_dense.bias, f"{hf_name}.bias")

    def port_ln(keras_ln, hf_name):
        loader.port_weight(keras_ln.gamma, f"{hf_name}.weight")
        loader.port_weight(keras_ln.beta, f"{hf_name}.bias")

    def port_conv(keras_conv, hf_name):
        if not keras_conv.built:
            # https://github.com/huggingface/transformers/issues/43065
            warnings.warn(f"Skipping {hf_name}")
            return
        loader.port_weight(
            keras_conv.kernel,
            f"{hf_name}.weight",
            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
        )
        if keras_conv.bias is not None:
            loader.port_weight(keras_conv.bias, f"{hf_name}.bias")

    def port_gn(keras_gn, hf_name):
        if not keras_gn.built:
            # https://github.com/huggingface/transformers/issues/43065
            warnings.warn(f"Skipping {hf_name}")
            return
        loader.port_weight(keras_gn.gamma, f"{hf_name}.weight")
        loader.port_weight(keras_gn.beta, f"{hf_name}.bias")

    def port_attention(keras_attn, hf_name):
        port_dense(keras_attn.q_proj, f"{hf_name}.q_proj")
        port_dense(keras_attn.k_proj, f"{hf_name}.k_proj")
        port_dense(keras_attn.v_proj, f"{hf_name}.v_proj")
        port_dense(keras_attn.o_proj, f"{hf_name}.o_proj")

    def port_mlp(keras_mlp, hf_name):
        port_dense(keras_mlp.fc1, f"{hf_name}.fc1")
        port_dense(keras_mlp.fc2, f"{hf_name}.fc2")

    def port_decoder_mlp(keras_mlp, hf_name):
        port_dense(keras_mlp.layer1, f"{hf_name}.layer1")
        port_dense(keras_mlp.layer2, f"{hf_name}.layer2")
        if hasattr(keras_mlp, "layer3") and keras_mlp.layer3 is not None:
            port_dense(keras_mlp.layer3, f"{hf_name}.layer3")

    # Vision Encoder.
    vision_prefix = "vision_encoder"
    backbone_prefix = f"{vision_prefix}.backbone"
    emb = backbone.vision_encoder.backbone.embeddings
    port_conv(
        emb.patch_embeddings.projection,
        f"{backbone_prefix}.embeddings.patch_embeddings.projection",
    )
    loader.port_weight(
        emb.position_embeddings,
        f"{backbone_prefix}.embeddings.position_embeddings",
    )
    emb.tiled_position_embeddings.assign(
        emb._tile_position_embeddings(
            emb.position_embeddings,
            patch_size=emb.patch_size,
            source_shape=emb.pretrain_image_shape,
            target_shape=emb.image_shape,
        )
    )
    port_ln(
        backbone.vision_encoder.backbone.layer_norm,
        f"{backbone_prefix}.layer_norm",
    )
    for i, layer in enumerate(backbone.vision_encoder.backbone.layers):
        p = f"{backbone_prefix}.layers.{i}"
        port_ln(layer.layer_norm1, f"{p}.layer_norm1")
        port_attention(layer.attention, f"{p}.attention")
        port_ln(layer.layer_norm2, f"{p}.layer_norm2")
        port_mlp(layer.mlp, f"{p}.mlp")

    neck_prefix = f"{vision_prefix}.neck"
    for i, layer in enumerate(backbone.vision_encoder.vision_neck.fpn_layers):
        p = f"{neck_prefix}.fpn_layers.{i}"
        # FPN scale layers
        for j, scale_layer in enumerate(layer.scale_layers):
            if isinstance(scale_layer, (layers.Conv2DTranspose, layers.Conv2D)):
                port_conv(scale_layer, f"{p}.scale_layers.{j}")

        port_conv(layer.proj1, f"{p}.proj1")
        port_conv(layer.proj2, f"{p}.proj2")

    # Text Encoder.
    text_prefix = "text_encoder.text_model"
    loader.port_weight(
        backbone.text_encoder.embedding.token_embedding.embeddings,
        f"{text_prefix}.embeddings.token_embedding.weight",
    )
    loader.port_weight(
        backbone.text_encoder.embedding.position_embedding.position_embeddings,
        f"{text_prefix}.embeddings.position_embedding.weight",
    )
    for i, layer in enumerate(backbone.text_encoder.encoder_layers):
        p = f"{text_prefix}.encoder.layers.{i}"
        port_ln(layer.layer_norm_1, f"{p}.layer_norm1")
        num_heads = backbone.text_encoder.num_heads
        hidden_dim = backbone.text_encoder.hidden_dim
        head_dim = hidden_dim // num_heads

        def port_mha_weight(keras_dense, hf_name, is_output=False):
            def hook(x, _):
                w = x.T
                if is_output:
                    return w.reshape(num_heads, head_dim, hidden_dim)
                else:
                    return w.reshape(hidden_dim, num_heads, head_dim)

            loader.port_weight(
                keras_dense.kernel,
                f"{hf_name}.weight",
                hook_fn=hook,
            )
            if keras_dense.bias is not None:

                def bias_hook(x, _):
                    if is_output:
                        return x  # (hidden,)
                    else:
                        return x.reshape(num_heads, head_dim)

                loader.port_weight(
                    keras_dense.bias, f"{hf_name}.bias", hook_fn=bias_hook
                )

        port_mha_weight(layer.attention._query_dense, f"{p}.self_attn.q_proj")
        port_mha_weight(layer.attention._key_dense, f"{p}.self_attn.k_proj")
        port_mha_weight(layer.attention._value_dense, f"{p}.self_attn.v_proj")
        port_mha_weight(
            layer.attention._output_dense,
            f"{p}.self_attn.out_proj",
            is_output=True,
        )
        port_ln(layer.layer_norm_2, f"{p}.layer_norm2")
        port_dense(layer.dense_1, f"{p}.mlp.fc1")
        port_dense(layer.dense_2, f"{p}.mlp.fc2")

    port_ln(backbone.text_encoder.layer_norm, f"{text_prefix}.final_layer_norm")
    port_dense(backbone.text_projection, "text_projection")

    # Geometry Encoder.
    geo_prefix = "geometry_encoder"
    loader.port_weight(
        backbone.geometry_encoder.label_embed.embeddings,
        f"{geo_prefix}.label_embed.weight",
    )
    loader.port_weight(
        backbone.geometry_encoder.cls_embed.embeddings,
        f"{geo_prefix}.cls_embed.weight",
    )
    port_dense(
        backbone.geometry_encoder.boxes_direct_project,
        f"{geo_prefix}.boxes_direct_project",
    )
    port_conv(
        backbone.geometry_encoder.boxes_pool_project,
        f"{geo_prefix}.boxes_pool_project",
    )
    port_dense(
        backbone.geometry_encoder.boxes_pos_enc_project,
        f"{geo_prefix}.boxes_pos_enc_project",
    )
    port_ln(
        backbone.geometry_encoder.vision_layer_norm,
        f"{geo_prefix}.vision_layer_norm",
    )
    port_dense(backbone.geometry_encoder.final_proj, f"{geo_prefix}.final_proj")
    port_ln(
        backbone.geometry_encoder.prompt_layer_norm,
        f"{geo_prefix}.prompt_layer_norm",
    )
    for i, layer in enumerate(backbone.geometry_encoder.layers):
        p = f"{geo_prefix}.layers.{i}"
        port_ln(layer.layer_norm1, f"{p}.layer_norm1")
        port_attention(layer.self_attn, f"{p}.self_attn")
        port_ln(layer.layer_norm2, f"{p}.layer_norm2")
        port_attention(layer.cross_attn, f"{p}.cross_attn")
        port_ln(layer.layer_norm3, f"{p}.layer_norm3")
        port_mlp(layer.mlp, f"{p}.mlp")
    port_ln(
        backbone.geometry_encoder.output_layer_norm,
        f"{geo_prefix}.output_layer_norm",
    )

    # DETR Encoder.
    detr_enc_prefix = "detr_encoder"
    for i, layer in enumerate(backbone.detr_encoder.layers):
        p = f"{detr_enc_prefix}.layers.{i}"
        port_ln(layer.layer_norm1, f"{p}.layer_norm1")
        port_attention(layer.self_attn, f"{p}.self_attn")
        port_attention(layer.cross_attn, f"{p}.cross_attn")
        port_ln(layer.layer_norm2, f"{p}.layer_norm2")
        port_mlp(layer.mlp, f"{p}.mlp")
        port_ln(layer.layer_norm3, f"{p}.layer_norm3")

    # DETR Decoder.
    detr_dec_prefix = "detr_decoder"
    port_ln(
        backbone.detr_decoder.output_layer_norm,
        f"{detr_dec_prefix}.output_layer_norm",
    )
    port_decoder_mlp(
        backbone.detr_decoder.box_head, f"{detr_dec_prefix}.box_head"
    )
    loader.port_weight(
        backbone.detr_decoder.query_embed.embeddings,
        f"{detr_dec_prefix}.query_embed.weight",
    )
    loader.port_weight(
        backbone.detr_decoder.reference_points.embeddings,
        f"{detr_dec_prefix}.reference_points.weight",
    )
    loader.port_weight(
        backbone.detr_decoder.presence_token.embeddings,
        f"{detr_dec_prefix}.presence_token.weight",
    )
    port_decoder_mlp(
        backbone.detr_decoder.presence_head, f"{detr_dec_prefix}.presence_head"
    )
    port_ln(
        backbone.detr_decoder.presence_layer_norm,
        f"{detr_dec_prefix}.presence_layer_norm",
    )
    port_decoder_mlp(
        backbone.detr_decoder.ref_point_head,
        f"{detr_dec_prefix}.ref_point_head",
    )
    port_decoder_mlp(
        backbone.detr_decoder.box_rpb_embed_x,
        f"{detr_dec_prefix}.box_rpb_embed_x",
    )
    port_decoder_mlp(
        backbone.detr_decoder.box_rpb_embed_y,
        f"{detr_dec_prefix}.box_rpb_embed_y",
    )
    for i, layer in enumerate(backbone.detr_decoder.layers):
        p = f"{detr_dec_prefix}.layers.{i}"
        port_attention(layer.self_attn, f"{p}.self_attn")
        port_ln(layer.self_attn_layer_norm, f"{p}.self_attn_layer_norm")
        port_attention(layer.text_cross_attn, f"{p}.text_cross_attn")
        port_ln(
            layer.text_cross_attn_layer_norm, f"{p}.text_cross_attn_layer_norm"
        )
        port_attention(layer.vision_cross_attn, f"{p}.vision_cross_attn")
        port_ln(
            layer.vision_cross_attn_layer_norm,
            f"{p}.vision_cross_attn_layer_norm",
        )
        port_mlp(layer.mlp, f"{p}.mlp")
        port_ln(layer.mlp_layer_norm, f"{p}.mlp_layer_norm")

    # Mask Decoder.
    mask_prefix = "mask_decoder"
    for i in range(len(backbone.mask_decoder.pixel_decoder.conv_layers)):
        p = f"{mask_prefix}.pixel_decoder"
        port_conv(
            backbone.mask_decoder.pixel_decoder.conv_layers[i],
            f"{p}.conv_layers.{i}",
        )
        port_gn(backbone.mask_decoder.pixel_decoder.norms[i], f"{p}.norms.{i}")
    for i in range(len(backbone.mask_decoder.mask_embedder.layers)):
        port_dense(
            backbone.mask_decoder.mask_embedder.layers[i],
            f"{mask_prefix}.mask_embedder.layers.{i}",
        )
    port_conv(
        backbone.mask_decoder.instance_projection,
        f"{mask_prefix}.instance_projection",
    )
    port_conv(
        backbone.mask_decoder.semantic_projection,
        f"{mask_prefix}.semantic_projection",
    )
    port_attention(
        backbone.mask_decoder.prompt_cross_attn,
        f"{mask_prefix}.prompt_cross_attn",
    )
    port_ln(
        backbone.mask_decoder.prompt_cross_attn_norm,
        f"{mask_prefix}.prompt_cross_attn_norm",
    )

    # Top Level Backbone Layers.
    scoring_prefix = "dot_product_scoring"
    port_decoder_mlp(
        backbone.dot_product_scoring.text_mlp, f"{scoring_prefix}.text_mlp"
    )
    port_ln(
        backbone.dot_product_scoring.text_mlp_out_norm,
        f"{scoring_prefix}.text_mlp_out_norm",
    )
    port_dense(
        backbone.dot_product_scoring.text_proj, f"{scoring_prefix}.text_proj"
    )
    port_dense(
        backbone.dot_product_scoring.query_proj, f"{scoring_prefix}.query_proj"
    )


def convert_tokenizer(cls, preset, **kwargs):
    tokenizer_config = load_json(preset, "tokenizer.json")
    vocab = tokenizer_config["model"]["vocab"]
    merges = tokenizer_config["model"]["merges"]
    merges = [" ".join(item) for item in merges]
    return cls(vocabulary=vocab, merges=merges, **kwargs)
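This converter is not normally called directly: the Transformers preset loader (keras_hub/src/utils/transformers/preset_utils and preset_loader.py, also touched in this release) dispatches to convert_backbone_config() to build the Keras sub-models and to convert_weights() to copy checkpoint tensors into them. A minimal usage sketch, assuming the SAM3 classes are publicly exported and using a placeholder Hugging Face handle rather than a real preset name:

    import keras_hub

    # Loading a Hugging Face SAM3 checkpoint would route through the converter
    # above; "hf://<org>/<sam3-checkpoint>" is a placeholder, not a published
    # preset name.
    backbone = keras_hub.models.Backbone.from_preset(
        "hf://<org>/<sam3-checkpoint>",
        dtype="float32",
    )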
keras_hub/src/utils/transformers/export/gemma3.py (new file)
@@ -0,0 +1,196 @@
import keras.ops as ops


def get_gemma3_config(backbone):
    """Convert Keras Gemma3 config to Hugging Face config dictionary."""

    layer_types = []
    for i in range(backbone.num_layers):
        if backbone.use_sliding_window_attention and (i % 6 < 5):
            layer_types.append("sliding_attention")
        else:
            layer_types.append("full_attention")

    hf_config = {
        "architectures": ["Gemma3ForCausalLM"],
        "model_type": "gemma3_text",
        "vocab_size": backbone.vocabulary_size,
        "num_hidden_layers": backbone.num_layers,
        "num_attention_heads": backbone.num_query_heads,
        "num_key_value_heads": backbone.num_key_value_heads,
        "hidden_size": backbone.hidden_dim,
        "intermediate_size": backbone.intermediate_dim,
        "head_dim": backbone.head_dim,
        "rms_norm_eps": backbone.layer_norm_epsilon,
        "rope_theta": 1000000.0,
        "attention_bias": False,
        "attention_dropout": backbone.dropout,
        "hidden_activation": "gelu_pytorch_tanh",
        # Added missing keys to match official config
        "sliding_window": backbone.sliding_window_size,
        "_sliding_window_pattern": 6,
        "use_cache": True,
        "torch_dtype": backbone.dtype_policy.name,
        "layer_types": layer_types,
        "query_pre_attn_scalar": backbone.head_dim
        if backbone.query_head_dim_normalize
        else backbone.hidden_dim // backbone.num_query_heads,
    }

    return hf_config


def get_gemma3_weights_map(backbone, include_lm_head=False):
    """Convert a Keras Gemma3 model to Hugging Face format.

    include_lm_head: If True, exports for CausalLM (with "model." prefix).
        If False, exports for backbone only (without prefix).
    """

    def _convert_qkv_kernel(kernel, hidden_dim):
        """Helper to convert Q/K/V projection kernels to HF format.

        Args:
            kernel: The kernel weight tensor to convert.
            hidden_dim: The hidden dimension size for reshaping.

        Returns:
            Converted kernel in HF format.
        """
        kernel = ops.transpose(kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
        kernel = ops.reshape(kernel, (hidden_dim, -1))
        kernel = ops.transpose(kernel)  # .T
        return kernel

    weights_dict = {}

    # For CausalLM export, use "model." prefix
    # For backbone export, use no prefix
    prefix = "model." if include_lm_head else ""

    # Token embeddings - use .weights[0] to get backend tensor
    token_embedding_layer = backbone.get_layer("token_embedding")
    token_embedding = token_embedding_layer.weights[0]
    weights_dict[f"{prefix}embed_tokens.weight"] = token_embedding

    for i in range(backbone.num_layers):
        block = backbone.get_layer(f"decoder_block_{i}")

        # Attention query projection
        q_kernel = _convert_qkv_kernel(
            block.attention.query_dense.weights[0], backbone.hidden_dim
        )
        weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel

        # Attention key projection
        k_kernel = _convert_qkv_kernel(
            block.attention.key_dense.weights[0], backbone.hidden_dim
        )
        weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel

        # Attention value projection
        v_kernel = _convert_qkv_kernel(
            block.attention.value_dense.weights[0], backbone.hidden_dim
        )
        weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel

        # Attention output projection
        o_kernel = block.attention.output_dense.weights[0]
        o_kernel = ops.transpose(o_kernel, axes=(2, 0, 1))  # permute(2, 0, 1)
        o_kernel = ops.reshape(o_kernel, (backbone.hidden_dim, -1))
        weights_dict[f"{prefix}layers.{i}.self_attn.o_proj.weight"] = o_kernel

        # Query and key normalization
        q_norm = block.attention.query_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.self_attn.q_norm.weight"] = q_norm

        k_norm = block.attention.key_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.self_attn.k_norm.weight"] = k_norm

        # MLP gate projection
        gate_kernel = block.gating_ffw.weights[0]
        gate_kernel = ops.transpose(gate_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        # MLP up projection
        up_kernel = block.gating_ffw_2.weights[0]
        up_kernel = ops.transpose(up_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.mlp.up_proj.weight"] = up_kernel

        # MLP down projection
        down_kernel = block.ffw_linear.weights[0]
        down_kernel = ops.transpose(down_kernel)  # .T
        weights_dict[f"{prefix}layers.{i}.mlp.down_proj.weight"] = down_kernel

        # Pre-attention normalization
        input_layer_norm = block.pre_attention_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.input_layernorm.weight"] = (
            input_layer_norm
        )

        # Post-attention normalization
        if hasattr(block, "post_attention_norm"):
            post_attn_norm = block.post_attention_norm.weights[0]
            weights_dict[
                f"{prefix}layers.{i}.post_attention_layernorm.weight"
            ] = post_attn_norm
        # Pre-feedforward normalization
        pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
        weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
            pre_feedforward_layernorm
        )
        # Post-feedforward normalization
        if hasattr(block, "post_ffw_norm"):
            post_feedforward_layernorm = block.post_ffw_norm.weights[0]
            weights_dict[
                f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
            ] = post_feedforward_layernorm

    # Final normalization
    final_norm = backbone.get_layer("final_normalization").weights[0]
    weights_dict[f"{prefix}norm.weight"] = final_norm

    if include_lm_head and not token_embedding_layer.tie_weights:
        weights_dict["lm_head.weight"] = ops.transpose(
            token_embedding_layer.reverse_embeddings
        )

    return weights_dict


def get_gemma3_tokenizer_config(tokenizer):
    tokenizer_config = {
        "tokenizer_class": "GemmaTokenizer",
        "clean_up_tokenization_spaces": False,
        "bos_token": "<bos>",
        "eos_token": "<eos>",
        "pad_token": "<pad>",
        "unk_token": "<unk>",
        "add_bos_token": True,
        "add_eos_token": False,
        "model_max_length": 1000000000000000019884624838656,
    }
    # Add added_tokens_decoder
    added_tokens_decoder = {}
    special_tokens = [
        "<pad>",
        "<bos>",
        "<eos>",
        "<unk>",
        "<mask>",
        "[multimodal]",
        "<img>",
    ]
    for token in special_tokens:
        token_id = tokenizer.token_to_id(token)
        if token_id is not None:
            added_tokens_decoder[str(token_id)] = {
                "content": token,
                "special": True,
                "single_word": False,
                "lstrip": False,
                "rstrip": False,
                "normalized": False,
            }
    tokenizer_config["added_tokens_decoder"] = added_tokens_decoder
    return tokenizer_config