keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/utils/transformers/convert_sam3.py
@@ -0,0 +1,472 @@
+ import warnings
+
+ import numpy as np
+ from keras import layers
+
+ from keras_hub.src.models.sam3.sam3_detr_decoder import SAM3DetrDecoder
+ from keras_hub.src.models.sam3.sam3_detr_encoder import SAM3DetrEncoder
+ from keras_hub.src.models.sam3.sam3_geometry_encoder import SAM3GeometryEncoder
+ from keras_hub.src.models.sam3.sam3_mask_decoder import SAM3MaskDecoder
+ from keras_hub.src.models.sam3.sam3_pc_backbone import (
+     SAM3PromptableConceptBackbone,
+ )
+ from keras_hub.src.models.sam3.sam3_text_encoder import SAM3TextEncoder
+ from keras_hub.src.models.sam3.sam3_vision_encoder import SAM3VisionEncoder
+ from keras_hub.src.utils.preset_utils import load_json
+
+ backbone_cls = SAM3PromptableConceptBackbone
+
+
+ def convert_backbone_config(transformers_config, cls, **kwargs):
+     # detector_config: Promptable Concept Segmentation (PCS)
+     # tracker_config: Promptable Visual Segmentation (PVS)
+     if issubclass(cls, SAM3PromptableConceptBackbone):
+         # Extract sub-configurations.
+         transformers_config = transformers_config["detector_config"]
+
+         vision_config = transformers_config["vision_config"]
+         backbone_config = vision_config["backbone_config"]
+         text_config = transformers_config["text_config"]
+         geom_config = transformers_config["geometry_encoder_config"]
+         detr_enc_config = transformers_config["detr_encoder_config"]
+         detr_dec_config = transformers_config["detr_decoder_config"]
+         mask_dec_config = transformers_config["mask_decoder_config"]
+         dtype = kwargs.pop("dtype", None)
+         image_shape = kwargs.pop("image_shape", None)
+         if image_shape is None:
+             image_shape = (
+                 backbone_config["image_size"],
+                 backbone_config["image_size"],
+                 3,
+             )
+
+         # Vision Encoder.
+         vision_encoder_config = {
+             "image_shape": image_shape,
+             "patch_size": backbone_config["patch_size"],
+             "num_layers": backbone_config["num_hidden_layers"],
+             "hidden_dim": backbone_config["hidden_size"],
+             "intermediate_dim": backbone_config["intermediate_size"],
+             "num_heads": backbone_config["num_attention_heads"],
+             "fpn_hidden_dim": vision_config["fpn_hidden_size"],
+             "fpn_scale_factors": vision_config["scale_factors"],
+             "pretrain_image_shape": (
+                 backbone_config["pretrain_image_size"],
+                 backbone_config["pretrain_image_size"],
+                 3,
+             ),
+             "hidden_activation": backbone_config["hidden_act"],
+             "rope_theta": backbone_config["rope_theta"],
+             "window_size": backbone_config["window_size"],
+             "global_attn_indexes": backbone_config["global_attn_indexes"],
+             "attention_dropout_rate": backbone_config["attention_dropout"],
+             "hidden_dropout_rate": backbone_config["hidden_dropout"],
+             "layer_norm_epsilon": backbone_config["layer_norm_eps"],
+             "dtype": dtype,
+         }
+         vision_encoder = SAM3VisionEncoder(**vision_encoder_config)
+
+         # Text Encoder.
+         text_encoder_config = {
+             "vocabulary_size": text_config["vocab_size"],
+             "embedding_dim": text_config["hidden_size"],
+             "hidden_dim": text_config["hidden_size"],
+             "num_layers": text_config["num_hidden_layers"],
+             "num_heads": text_config["num_attention_heads"],
+             "intermediate_dim": text_config["intermediate_size"],
+             "intermediate_activation": text_config["hidden_act"],
+             "max_sequence_length": text_config["max_position_embeddings"],
+             "layer_norm_epsilon": text_config["layer_norm_eps"],
+             "dtype": dtype,
+         }
+         text_encoder = SAM3TextEncoder(**text_encoder_config)
+
+         # Geometry Encoder.
+         geometry_encoder_config = {
+             "num_layers": geom_config["num_layers"],
+             "hidden_dim": geom_config["hidden_size"],
+             "intermediate_dim": geom_config["intermediate_size"],
+             "num_heads": geom_config["num_attention_heads"],
+             "roi_size": geom_config["roi_size"],
+             "hidden_activation": geom_config["hidden_act"],
+             "dropout_rate": geom_config["hidden_dropout"],
+             "layer_norm_epsilon": geom_config["layer_norm_eps"],
+             "dtype": dtype,
+         }
+         geometry_encoder = SAM3GeometryEncoder(**geometry_encoder_config)
+
+         # DETR Encoder.
+         detr_encoder_config = {
+             "num_layers": detr_enc_config["num_layers"],
+             "hidden_dim": detr_enc_config["hidden_size"],
+             "intermediate_dim": detr_enc_config["intermediate_size"],
+             "num_heads": detr_enc_config["num_attention_heads"],
+             "hidden_activation": detr_enc_config["hidden_act"],
+             "dropout_rate": detr_enc_config["dropout"],
+             "layer_norm_epsilon": detr_enc_config["layer_norm_eps"],
+             "dtype": dtype,
+         }
+         detr_encoder = SAM3DetrEncoder(**detr_encoder_config)
+
+         # DETR Decoder.
+         detr_decoder_config = {
+             "image_shape": image_shape,
+             "patch_size": backbone_config["patch_size"],
+             "num_layers": detr_dec_config["num_layers"],
+             "hidden_dim": detr_dec_config["hidden_size"],
+             "intermediate_dim": detr_dec_config["intermediate_size"],
+             "num_heads": detr_dec_config["num_attention_heads"],
+             "num_queries": detr_dec_config["num_queries"],
+             "hidden_activation": detr_dec_config["hidden_act"],
+             "dropout_rate": detr_dec_config["dropout"],
+             "layer_norm_epsilon": detr_dec_config["layer_norm_eps"],
+             "dtype": dtype,
+         }
+         detr_decoder = SAM3DetrDecoder(**detr_decoder_config)
+
+         # Mask Decoder.
+         mask_decoder_config = {
+             "num_upsampling_stages": mask_dec_config["num_upsampling_stages"],
+             "hidden_dim": mask_dec_config["hidden_size"],
+             "num_heads": mask_dec_config["num_attention_heads"],
+             "dropout_rate": 0.0,
+             "layer_norm_epsilon": mask_dec_config["layer_norm_eps"],
+             "dtype": dtype,
+         }
+         mask_decoder = SAM3MaskDecoder(**mask_decoder_config)
+
+         return {
+             "vision_encoder": vision_encoder,
+             "text_encoder": text_encoder,
+             "geometry_encoder": geometry_encoder,
+             "detr_encoder": detr_encoder,
+             "detr_decoder": detr_decoder,
+             "mask_decoder": mask_decoder,
+         }
+     else:
+         # TODO: Add SAM3Tracker support.
+         raise ValueError(
+             "The provided class is not a subclass of "
+             f"SAM3PromptableConceptBackbone. Received: {cls}"
+         )
+
+
+ def convert_weights(backbone, loader, transformers_config):
+     if not isinstance(backbone, SAM3PromptableConceptBackbone):
+         raise ValueError(
+             "The provided backbone must be an instance of "
+             f"SAM3PromptableConceptBackbone. Received: {type(backbone)}"
+         )
+
+     def port_dense(keras_dense, hf_name):
+         loader.port_weight(
+             keras_dense.kernel, f"{hf_name}.weight", hook_fn=lambda x, _: x.T
+         )
+         if keras_dense.bias is not None:
+             loader.port_weight(keras_dense.bias, f"{hf_name}.bias")
+
+     def port_ln(keras_ln, hf_name):
+         loader.port_weight(keras_ln.gamma, f"{hf_name}.weight")
+         loader.port_weight(keras_ln.beta, f"{hf_name}.bias")
+
+     def port_conv(keras_conv, hf_name):
+         if not keras_conv.built:
+             # https://github.com/huggingface/transformers/issues/43065
+             warnings.warn(f"Skipping {hf_name}")
+             return
+         loader.port_weight(
+             keras_conv.kernel,
+             f"{hf_name}.weight",
+             hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+         )
+         if keras_conv.bias is not None:
+             loader.port_weight(keras_conv.bias, f"{hf_name}.bias")
+
+     def port_gn(keras_gn, hf_name):
+         if not keras_gn.built:
+             # https://github.com/huggingface/transformers/issues/43065
+             warnings.warn(f"Skipping {hf_name}")
+             return
+         loader.port_weight(keras_gn.gamma, f"{hf_name}.weight")
+         loader.port_weight(keras_gn.beta, f"{hf_name}.bias")
+
+     def port_attention(keras_attn, hf_name):
+         port_dense(keras_attn.q_proj, f"{hf_name}.q_proj")
+         port_dense(keras_attn.k_proj, f"{hf_name}.k_proj")
+         port_dense(keras_attn.v_proj, f"{hf_name}.v_proj")
+         port_dense(keras_attn.o_proj, f"{hf_name}.o_proj")
+
+     def port_mlp(keras_mlp, hf_name):
+         port_dense(keras_mlp.fc1, f"{hf_name}.fc1")
+         port_dense(keras_mlp.fc2, f"{hf_name}.fc2")
+
+     def port_decoder_mlp(keras_mlp, hf_name):
+         port_dense(keras_mlp.layer1, f"{hf_name}.layer1")
+         port_dense(keras_mlp.layer2, f"{hf_name}.layer2")
+         if hasattr(keras_mlp, "layer3") and keras_mlp.layer3 is not None:
+             port_dense(keras_mlp.layer3, f"{hf_name}.layer3")
+
+     # Vision Encoder.
+     vision_prefix = "vision_encoder"
+     backbone_prefix = f"{vision_prefix}.backbone"
+     emb = backbone.vision_encoder.backbone.embeddings
+     port_conv(
+         emb.patch_embeddings.projection,
+         f"{backbone_prefix}.embeddings.patch_embeddings.projection",
+     )
+     loader.port_weight(
+         emb.position_embeddings,
+         f"{backbone_prefix}.embeddings.position_embeddings",
+     )
+     emb.tiled_position_embeddings.assign(
+         emb._tile_position_embeddings(
+             emb.position_embeddings,
+             patch_size=emb.patch_size,
+             source_shape=emb.pretrain_image_shape,
+             target_shape=emb.image_shape,
+         )
+     )
+     port_ln(
+         backbone.vision_encoder.backbone.layer_norm,
+         f"{backbone_prefix}.layer_norm",
+     )
+     for i, layer in enumerate(backbone.vision_encoder.backbone.layers):
+         p = f"{backbone_prefix}.layers.{i}"
+         port_ln(layer.layer_norm1, f"{p}.layer_norm1")
+         port_attention(layer.attention, f"{p}.attention")
+         port_ln(layer.layer_norm2, f"{p}.layer_norm2")
+         port_mlp(layer.mlp, f"{p}.mlp")
+
+     neck_prefix = f"{vision_prefix}.neck"
+     for i, layer in enumerate(backbone.vision_encoder.vision_neck.fpn_layers):
+         p = f"{neck_prefix}.fpn_layers.{i}"
+         # FPN scale layers
+         for j, scale_layer in enumerate(layer.scale_layers):
+             if isinstance(scale_layer, (layers.Conv2DTranspose, layers.Conv2D)):
+                 port_conv(scale_layer, f"{p}.scale_layers.{j}")
+
+         port_conv(layer.proj1, f"{p}.proj1")
+         port_conv(layer.proj2, f"{p}.proj2")
+
+     # Text Encoder.
+     text_prefix = "text_encoder.text_model"
+     loader.port_weight(
+         backbone.text_encoder.embedding.token_embedding.embeddings,
+         f"{text_prefix}.embeddings.token_embedding.weight",
+     )
+     loader.port_weight(
+         backbone.text_encoder.embedding.position_embedding.position_embeddings,
+         f"{text_prefix}.embeddings.position_embedding.weight",
+     )
+     for i, layer in enumerate(backbone.text_encoder.encoder_layers):
+         p = f"{text_prefix}.encoder.layers.{i}"
+         port_ln(layer.layer_norm_1, f"{p}.layer_norm1")
+         num_heads = backbone.text_encoder.num_heads
+         hidden_dim = backbone.text_encoder.hidden_dim
+         head_dim = hidden_dim // num_heads
+
+         def port_mha_weight(keras_dense, hf_name, is_output=False):
+             def hook(x, _):
+                 w = x.T
+                 if is_output:
+                     return w.reshape(num_heads, head_dim, hidden_dim)
+                 else:
+                     return w.reshape(hidden_dim, num_heads, head_dim)
+
+             loader.port_weight(
+                 keras_dense.kernel,
+                 f"{hf_name}.weight",
+                 hook_fn=hook,
+             )
+             if keras_dense.bias is not None:
+
+                 def bias_hook(x, _):
+                     if is_output:
+                         return x  # (hidden,)
+                     else:
+                         return x.reshape(num_heads, head_dim)
+
+                 loader.port_weight(
+                     keras_dense.bias, f"{hf_name}.bias", hook_fn=bias_hook
+                 )
+
+         port_mha_weight(layer.attention._query_dense, f"{p}.self_attn.q_proj")
+         port_mha_weight(layer.attention._key_dense, f"{p}.self_attn.k_proj")
+         port_mha_weight(layer.attention._value_dense, f"{p}.self_attn.v_proj")
+         port_mha_weight(
+             layer.attention._output_dense,
+             f"{p}.self_attn.out_proj",
+             is_output=True,
+         )
+         port_ln(layer.layer_norm_2, f"{p}.layer_norm2")
+         port_dense(layer.dense_1, f"{p}.mlp.fc1")
+         port_dense(layer.dense_2, f"{p}.mlp.fc2")
+
+     port_ln(backbone.text_encoder.layer_norm, f"{text_prefix}.final_layer_norm")
+     port_dense(backbone.text_projection, "text_projection")
+
+     # Geometry Encoder.
+     geo_prefix = "geometry_encoder"
+     loader.port_weight(
+         backbone.geometry_encoder.label_embed.embeddings,
+         f"{geo_prefix}.label_embed.weight",
+     )
+     loader.port_weight(
+         backbone.geometry_encoder.cls_embed.embeddings,
+         f"{geo_prefix}.cls_embed.weight",
+     )
+     port_dense(
+         backbone.geometry_encoder.boxes_direct_project,
+         f"{geo_prefix}.boxes_direct_project",
+     )
+     port_conv(
+         backbone.geometry_encoder.boxes_pool_project,
+         f"{geo_prefix}.boxes_pool_project",
+     )
+     port_dense(
+         backbone.geometry_encoder.boxes_pos_enc_project,
+         f"{geo_prefix}.boxes_pos_enc_project",
+     )
+     port_ln(
+         backbone.geometry_encoder.vision_layer_norm,
+         f"{geo_prefix}.vision_layer_norm",
+     )
+     port_dense(backbone.geometry_encoder.final_proj, f"{geo_prefix}.final_proj")
+     port_ln(
+         backbone.geometry_encoder.prompt_layer_norm,
+         f"{geo_prefix}.prompt_layer_norm",
+     )
+     for i, layer in enumerate(backbone.geometry_encoder.layers):
+         p = f"{geo_prefix}.layers.{i}"
+         port_ln(layer.layer_norm1, f"{p}.layer_norm1")
+         port_attention(layer.self_attn, f"{p}.self_attn")
+         port_ln(layer.layer_norm2, f"{p}.layer_norm2")
+         port_attention(layer.cross_attn, f"{p}.cross_attn")
+         port_ln(layer.layer_norm3, f"{p}.layer_norm3")
+         port_mlp(layer.mlp, f"{p}.mlp")
+     port_ln(
+         backbone.geometry_encoder.output_layer_norm,
+         f"{geo_prefix}.output_layer_norm",
+     )
+
+     # DETR Encoder.
+     detr_enc_prefix = "detr_encoder"
+     for i, layer in enumerate(backbone.detr_encoder.layers):
+         p = f"{detr_enc_prefix}.layers.{i}"
+         port_ln(layer.layer_norm1, f"{p}.layer_norm1")
+         port_attention(layer.self_attn, f"{p}.self_attn")
+         port_attention(layer.cross_attn, f"{p}.cross_attn")
+         port_ln(layer.layer_norm2, f"{p}.layer_norm2")
+         port_mlp(layer.mlp, f"{p}.mlp")
+         port_ln(layer.layer_norm3, f"{p}.layer_norm3")
+
+     # DETR Decoder.
+     detr_dec_prefix = "detr_decoder"
+     port_ln(
+         backbone.detr_decoder.output_layer_norm,
+         f"{detr_dec_prefix}.output_layer_norm",
+     )
+     port_decoder_mlp(
+         backbone.detr_decoder.box_head, f"{detr_dec_prefix}.box_head"
+     )
+     loader.port_weight(
+         backbone.detr_decoder.query_embed.embeddings,
+         f"{detr_dec_prefix}.query_embed.weight",
+     )
+     loader.port_weight(
+         backbone.detr_decoder.reference_points.embeddings,
+         f"{detr_dec_prefix}.reference_points.weight",
+     )
+     loader.port_weight(
+         backbone.detr_decoder.presence_token.embeddings,
+         f"{detr_dec_prefix}.presence_token.weight",
+     )
+     port_decoder_mlp(
+         backbone.detr_decoder.presence_head, f"{detr_dec_prefix}.presence_head"
+     )
+     port_ln(
+         backbone.detr_decoder.presence_layer_norm,
+         f"{detr_dec_prefix}.presence_layer_norm",
+     )
+     port_decoder_mlp(
+         backbone.detr_decoder.ref_point_head,
+         f"{detr_dec_prefix}.ref_point_head",
+     )
+     port_decoder_mlp(
+         backbone.detr_decoder.box_rpb_embed_x,
+         f"{detr_dec_prefix}.box_rpb_embed_x",
+     )
+     port_decoder_mlp(
+         backbone.detr_decoder.box_rpb_embed_y,
+         f"{detr_dec_prefix}.box_rpb_embed_y",
+     )
+     for i, layer in enumerate(backbone.detr_decoder.layers):
+         p = f"{detr_dec_prefix}.layers.{i}"
+         port_attention(layer.self_attn, f"{p}.self_attn")
+         port_ln(layer.self_attn_layer_norm, f"{p}.self_attn_layer_norm")
+         port_attention(layer.text_cross_attn, f"{p}.text_cross_attn")
+         port_ln(
+             layer.text_cross_attn_layer_norm, f"{p}.text_cross_attn_layer_norm"
+         )
+         port_attention(layer.vision_cross_attn, f"{p}.vision_cross_attn")
+         port_ln(
+             layer.vision_cross_attn_layer_norm,
+             f"{p}.vision_cross_attn_layer_norm",
+         )
+         port_mlp(layer.mlp, f"{p}.mlp")
+         port_ln(layer.mlp_layer_norm, f"{p}.mlp_layer_norm")
+
+     # Mask Decoder.
+     mask_prefix = "mask_decoder"
+     for i in range(len(backbone.mask_decoder.pixel_decoder.conv_layers)):
+         p = f"{mask_prefix}.pixel_decoder"
+         port_conv(
+             backbone.mask_decoder.pixel_decoder.conv_layers[i],
+             f"{p}.conv_layers.{i}",
+         )
+         port_gn(backbone.mask_decoder.pixel_decoder.norms[i], f"{p}.norms.{i}")
+     for i in range(len(backbone.mask_decoder.mask_embedder.layers)):
+         port_dense(
+             backbone.mask_decoder.mask_embedder.layers[i],
+             f"{mask_prefix}.mask_embedder.layers.{i}",
+         )
+     port_conv(
+         backbone.mask_decoder.instance_projection,
+         f"{mask_prefix}.instance_projection",
+     )
+     port_conv(
+         backbone.mask_decoder.semantic_projection,
+         f"{mask_prefix}.semantic_projection",
+     )
+     port_attention(
+         backbone.mask_decoder.prompt_cross_attn,
+         f"{mask_prefix}.prompt_cross_attn",
+     )
+     port_ln(
+         backbone.mask_decoder.prompt_cross_attn_norm,
+         f"{mask_prefix}.prompt_cross_attn_norm",
+     )
+
+     # Top Level Backbone Layers.
+     scoring_prefix = "dot_product_scoring"
+     port_decoder_mlp(
+         backbone.dot_product_scoring.text_mlp, f"{scoring_prefix}.text_mlp"
+     )
+     port_ln(
+         backbone.dot_product_scoring.text_mlp_out_norm,
+         f"{scoring_prefix}.text_mlp_out_norm",
+     )
+     port_dense(
+         backbone.dot_product_scoring.text_proj, f"{scoring_prefix}.text_proj"
+     )
+     port_dense(
+         backbone.dot_product_scoring.query_proj, f"{scoring_prefix}.query_proj"
+     )
+
+
+ def convert_tokenizer(cls, preset, **kwargs):
+     tokenizer_config = load_json(preset, "tokenizer.json")
+     vocab = tokenizer_config["model"]["vocab"]
+     merges = tokenizer_config["model"]["merges"]
+     merges = [" ".join(item) for item in merges]
+     return cls(vocabulary=vocab, merges=merges, **kwargs)
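Note: these converters are not called directly; they are dispatched through keras_hub/src/utils/transformers/preset_loader.py (also touched in this release) when a SAM3 checkpoint is loaded from the Hugging Face Hub. A minimal usage sketch, assuming a placeholder checkpoint handle (real preset names are registered in sam3_presets.py, not shown in this hunk):

import keras_hub

# An "hf://" handle routes through the transformers preset loader, which
# calls convert_backbone_config() to instantiate the Keras sub-modules and
# convert_weights() to port each checkpoint tensor. The handle below is a
# placeholder, not a published preset name.
backbone = keras_hub.models.Backbone.from_preset("hf://<org>/<sam3-checkpoint>")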
keras_hub/src/utils/transformers/export/gemma3.py
@@ -0,0 +1,196 @@
+ import keras.ops as ops
+
+
+ def get_gemma3_config(backbone):
+     """Convert Keras Gemma3 config to Hugging Face config dictionary."""
+
+     layer_types = []
+     for i in range(backbone.num_layers):
+         if backbone.use_sliding_window_attention and (i % 6 < 5):
+             layer_types.append("sliding_attention")
+         else:
+             layer_types.append("full_attention")
+
+     hf_config = {
+         "architectures": ["Gemma3ForCausalLM"],
+         "model_type": "gemma3_text",
+         "vocab_size": backbone.vocabulary_size,
+         "num_hidden_layers": backbone.num_layers,
+         "num_attention_heads": backbone.num_query_heads,
+         "num_key_value_heads": backbone.num_key_value_heads,
+         "hidden_size": backbone.hidden_dim,
+         "intermediate_size": backbone.intermediate_dim,
+         "head_dim": backbone.head_dim,
+         "rms_norm_eps": backbone.layer_norm_epsilon,
+         "rope_theta": 1000000.0,
+         "attention_bias": False,
+         "attention_dropout": backbone.dropout,
+         "hidden_activation": "gelu_pytorch_tanh",
+         # Added missing keys to match official config
+         "sliding_window": backbone.sliding_window_size,
+         "_sliding_window_pattern": 6,
+         "use_cache": True,
+         "torch_dtype": backbone.dtype_policy.name,
+         "layer_types": layer_types,
+         "query_pre_attn_scalar": backbone.head_dim
+         if backbone.query_head_dim_normalize
+         else backbone.hidden_dim // backbone.num_query_heads,
+     }
+
+     return hf_config
+
+
+ def get_gemma3_weights_map(backbone, include_lm_head=False):
+     """Convert a Keras Gemma3 model to Hugging Face format.
+
+     include_lm_head: If True, exports for CausalLM (with "model." prefix).
+         If False, exports for backbone only (without prefix).
+     """
+
+     def _convert_qkv_kernel(kernel, hidden_dim):
+         """Helper to convert Q/K/V projection kernels to HF format.
+
+         Args:
+             kernel: The kernel weight tensor to convert.
+             hidden_dim: The hidden dimension size for reshaping.
+
+         Returns:
+             Converted kernel in HF format.
+         """
+         kernel = ops.transpose(kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
+         kernel = ops.reshape(kernel, (hidden_dim, -1))
+         kernel = ops.transpose(kernel)  # .T
+         return kernel
+
+     weights_dict = {}
+
+     # For CausalLM export, use "model." prefix
+     # For backbone export, use no prefix
+     prefix = "model." if include_lm_head else ""
+
+     # Token embeddings - use .weights[0] to get backend tensor
+     token_embedding_layer = backbone.get_layer("token_embedding")
+     token_embedding = token_embedding_layer.weights[0]
+     weights_dict[f"{prefix}embed_tokens.weight"] = token_embedding
+
+     for i in range(backbone.num_layers):
+         block = backbone.get_layer(f"decoder_block_{i}")
+
+         # Attention query projection
+         q_kernel = _convert_qkv_kernel(
+             block.attention.query_dense.weights[0], backbone.hidden_dim
+         )
+         weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel
+
+         # Attention key projection
+         k_kernel = _convert_qkv_kernel(
+             block.attention.key_dense.weights[0], backbone.hidden_dim
+         )
+         weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel
+
+         # Attention value projection
+         v_kernel = _convert_qkv_kernel(
+             block.attention.value_dense.weights[0], backbone.hidden_dim
+         )
+         weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel
+
+         # Attention output projection
+         o_kernel = block.attention.output_dense.weights[0]
+         o_kernel = ops.transpose(o_kernel, axes=(2, 0, 1))  # permute(2, 0, 1)
+         o_kernel = ops.reshape(o_kernel, (backbone.hidden_dim, -1))
+         weights_dict[f"{prefix}layers.{i}.self_attn.o_proj.weight"] = o_kernel
+
+         # Query and key normalization
+         q_norm = block.attention.query_norm.weights[0]
+         weights_dict[f"{prefix}layers.{i}.self_attn.q_norm.weight"] = q_norm
+
+         k_norm = block.attention.key_norm.weights[0]
+         weights_dict[f"{prefix}layers.{i}.self_attn.k_norm.weight"] = k_norm
+
+         # MLP gate projection
+         gate_kernel = block.gating_ffw.weights[0]
+         gate_kernel = ops.transpose(gate_kernel)  # .T
+         weights_dict[f"{prefix}layers.{i}.mlp.gate_proj.weight"] = gate_kernel
+
+         # MLP up projection
+         up_kernel = block.gating_ffw_2.weights[0]
+         up_kernel = ops.transpose(up_kernel)  # .T
+         weights_dict[f"{prefix}layers.{i}.mlp.up_proj.weight"] = up_kernel
+
+         # MLP down projection
+         down_kernel = block.ffw_linear.weights[0]
+         down_kernel = ops.transpose(down_kernel)  # .T
+         weights_dict[f"{prefix}layers.{i}.mlp.down_proj.weight"] = down_kernel
+
+         # Pre-attention normalization
+         input_layer_norm = block.pre_attention_norm.weights[0]
+         weights_dict[f"{prefix}layers.{i}.input_layernorm.weight"] = (
+             input_layer_norm
+         )
+
+         # Post-attention normalization
+         if hasattr(block, "post_attention_norm"):
+             post_attn_norm = block.post_attention_norm.weights[0]
+             weights_dict[
+                 f"{prefix}layers.{i}.post_attention_layernorm.weight"
+             ] = post_attn_norm
+         # Pre-feedforward normalization
+         pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
+         weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
+             pre_feedforward_layernorm
+         )
+         # Post-feedforward normalization
+         if hasattr(block, "post_ffw_norm"):
+             post_feedforward_layernorm = block.post_ffw_norm.weights[0]
+             weights_dict[
+                 f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
+             ] = post_feedforward_layernorm
+
+     # Final normalization
+     final_norm = backbone.get_layer("final_normalization").weights[0]
+     weights_dict[f"{prefix}norm.weight"] = final_norm
+
+     if include_lm_head and not token_embedding_layer.tie_weights:
+         weights_dict["lm_head.weight"] = ops.transpose(
+             token_embedding_layer.reverse_embeddings
+         )
+
+     return weights_dict
+
+
+ def get_gemma3_tokenizer_config(tokenizer):
+     tokenizer_config = {
+         "tokenizer_class": "GemmaTokenizer",
+         "clean_up_tokenization_spaces": False,
+         "bos_token": "<bos>",
+         "eos_token": "<eos>",
+         "pad_token": "<pad>",
+         "unk_token": "<unk>",
+         "add_bos_token": True,
+         "add_eos_token": False,
+         "model_max_length": 1000000000000000019884624838656,
+     }
+     # Add added_tokens_decoder
+     added_tokens_decoder = {}
+     special_tokens = [
+         "<pad>",
+         "<bos>",
+         "<eos>",
+         "<unk>",
+         "<mask>",
+         "[multimodal]",
+         "<img>",
+     ]
+     for token in special_tokens:
+         token_id = tokenizer.token_to_id(token)
+         if token_id is not None:
+             added_tokens_decoder[str(token_id)] = {
+                 "content": token,
+                 "special": True,
+                 "single_word": False,
+                 "lstrip": False,
+                 "rstrip": False,
+                 "normalized": False,
+             }
+     tokenizer_config["added_tokens_decoder"] = added_tokens_decoder
+     return tokenizer_config
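For intuition on the kernel reshapes in _convert_qkv_kernel above, here is a short standalone sketch. It assumes the Keras EinsumDense kernel layout (num_heads, hidden_dim, head_dim) that the permute/reshape sequence implies, and checks that the result matches the (num_heads * head_dim, hidden_dim) layout Hugging Face uses for q_proj/k_proj/v_proj weights; the toy dimensions are hypothetical.

import numpy as np

# Hypothetical toy dimensions, for illustration only.
num_heads, hidden_dim, head_dim = 4, 8, 2

# Assumed Keras kernel layout: (num_heads, hidden_dim, head_dim).
kernel = np.arange(num_heads * hidden_dim * head_dim, dtype="float32")
kernel = kernel.reshape(num_heads, hidden_dim, head_dim)

# Mirrors _convert_qkv_kernel: permute(1, 0, 2), flatten the head axes, then .T.
hf_weight = kernel.transpose(1, 0, 2).reshape(hidden_dim, -1).T

# Hugging Face stores projection weights as (out_features, in_features).
assert hf_weight.shape == (num_heads * head_dim, hidden_dim)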