keras-hub 0.22.1__py3-none-any.whl → 0.23.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. keras_hub/layers/__init__.py +12 -0
  2. keras_hub/models/__init__.py +90 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
  5. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  6. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  7. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  8. keras_hub/src/models/backbone.py +28 -16
  9. keras_hub/src/models/causal_lm.py +37 -0
  10. keras_hub/src/models/causal_lm_preprocessor.py +14 -0
  11. keras_hub/src/models/clip/clip_presets.py +8 -8
  12. keras_hub/src/models/d_fine/__init__.py +5 -0
  13. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  14. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  15. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  16. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  17. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  18. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  19. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  20. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  21. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  22. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  23. keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
  24. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  25. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
  26. keras_hub/src/models/depth_anything/__init__.py +9 -0
  27. keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
  28. keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
  29. keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
  30. keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
  31. keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
  32. keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
  33. keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
  34. keras_hub/src/models/depth_anything/interpolate.py +62 -0
  35. keras_hub/src/models/depth_estimator.py +239 -0
  36. keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
  37. keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
  38. keras_hub/src/models/dinov2/dinov2_layers.py +13 -3
  39. keras_hub/src/models/gemma/gemma_backbone.py +0 -1
  40. keras_hub/src/models/gemma/gemma_presets.py +30 -0
  41. keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
  42. keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
  43. keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
  44. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  45. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  46. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  47. keras_hub/src/models/image_to_image.py +5 -0
  48. keras_hub/src/models/inpaint.py +5 -0
  49. keras_hub/src/models/mobilenetv5/__init__.py +9 -0
  50. keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
  51. keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
  52. keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
  53. keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
  54. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
  55. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
  56. keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
  57. keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
  58. keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
  59. keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
  60. keras_hub/src/models/parseq/__init__.py +5 -0
  61. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  62. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  63. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  64. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  65. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  66. keras_hub/src/models/parseq/parseq_presets.py +15 -0
  67. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  68. keras_hub/src/models/qwen3_moe/__init__.py +5 -0
  69. keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
  70. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
  71. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
  72. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
  73. keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
  74. keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
  75. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
  76. keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
  77. keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
  78. keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
  79. keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
  80. keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
  81. keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
  82. keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
  83. keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
  84. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  85. keras_hub/src/models/t5gemma/__init__.py +5 -0
  86. keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
  87. keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
  88. keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
  89. keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
  90. keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
  91. keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
  92. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
  93. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
  94. keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
  95. keras_hub/src/models/text_to_image.py +5 -0
  96. keras_hub/src/samplers/beam_sampler.py +6 -6
  97. keras_hub/src/samplers/sampler.py +8 -6
  98. keras_hub/src/tests/test_case.py +40 -3
  99. keras_hub/src/tokenizers/tokenizer.py +15 -0
  100. keras_hub/src/utils/openvino_utils.py +141 -0
  101. keras_hub/src/utils/preset_utils.py +58 -2
  102. keras_hub/src/utils/tensor_utils.py +23 -1
  103. keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
  104. keras_hub/src/utils/timm/preset_loader.py +8 -4
  105. keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
  106. keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
  107. keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
  108. keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
  109. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  110. keras_hub/src/utils/transformers/export/gemma.py +49 -4
  111. keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
  112. keras_hub/src/utils/transformers/preset_loader.py +9 -0
  113. keras_hub/src/version.py +1 -1
  114. keras_hub/tokenizers/__init__.py +15 -0
  115. {keras_hub-0.22.1.dist-info → keras_hub-0.23.0.dev0.dist-info}/METADATA +1 -1
  116. {keras_hub-0.22.1.dist-info → keras_hub-0.23.0.dev0.dist-info}/RECORD +118 -45
  117. {keras_hub-0.22.1.dist-info → keras_hub-0.23.0.dev0.dist-info}/WHEEL +0 -0
  118. {keras_hub-0.22.1.dist-info → keras_hub-0.23.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/utils/timm/convert_mobilenetv5.py (new file)
@@ -0,0 +1,321 @@
+ import types
+
+ import keras
+ import numpy as np
+
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import (
+     MobileAttention,
+ )
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import (
+     MobileNetV5Backbone,
+ )
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import EdgeResidual
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import (
+     UniversalInvertedResidual,
+ )
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import (
+     convert_arch_def_to_stackwise,
+ )
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d
+
+ backbone_cls = MobileNetV5Backbone
+
+ MODEL_CONFIGS = {
+     "mobilenetv5_300m": {
+         "backbone": convert_arch_def_to_stackwise(
+             [
+                 # Stage 0: 128x128 in
+                 [
+                     "er_r1_k3_s2_e4_c128",
+                     "er_r1_k3_s1_e4_c128",
+                     "er_r1_k3_s1_e4_c128",
+                 ],
+                 # Stage 1: 256x256 in
+                 [
+                     "uir_r1_a3_k5_s2_e6_c256",
+                     "uir_r1_a5_k0_s1_e4_c256",
+                     "uir_r1_a3_k0_s1_e4_c256",
+                     "uir_r1_a5_k0_s1_e4_c256",
+                     "uir_r1_a3_k0_s1_e4_c256",
+                 ],
+                 # Stage 2: 640x640 in
+                 [
+                     "uir_r1_a5_k5_s2_e6_c640",
+                     "uir_r1_a5_k0_s1_e4_c640",
+                     "uir_r1_a5_k0_s1_e4_c640",
+                     "uir_r1_a5_k0_s1_e4_c640",
+                     "uir_r1_a5_k0_s1_e4_c640",
+                     "uir_r1_a5_k0_s1_e4_c640",
+                     "uir_r1_a5_k0_s1_e4_c640",
+                     "uir_r1_a5_k0_s1_e4_c640",
+                     "uir_r1_a0_k0_s1_e1_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                     "mqa_r1_k3_h12_v2_s1_d64_c640",
+                     "uir_r1_a0_k0_s1_e2_c640",
+                 ],
+                 # Stage 3: 1280x1280 in
+                 [
+                     "uir_r1_a5_k5_s2_e6_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                     "mqa_r1_k3_h16_s1_d96_c1280",
+                     "uir_r1_a0_k0_s1_e2_c1280",
+                 ],
+             ]
+         ),
+         "stem_size": 64,
+         "num_features": 2048,
+         "norm_layer": "rms_norm",
+         "act_layer": "gelu",
+         "use_msfa": True,
+         "layer_scale_init_value": 1e-5,
+     },
+ }
+
+
+ def convert_head(task, loader, timm_config):
+     pass
+
+
+ def convert_backbone_config(timm_config):
+     timm_architecture = timm_config["architecture"]
+     if timm_architecture not in MODEL_CONFIGS:
+         raise ValueError(f"Unsupported architecture: {timm_architecture}")
+     config = MODEL_CONFIGS[timm_architecture].copy()
+     backbone_config = config.pop("backbone")
+     backbone_config.update(config)
+     return backbone_config
+
+
+ def convert_weights(backbone, loader, timm_config):
+     def key_exists(key):
+         try:
+             loader.get_tensor(key)
+             return True
+         except Exception:
+             return False
+
+     def _port_weights(layer, timm_key, transpose_dims=None):
+         hf_weight_key = f"{timm_key}.weight"
+         if not key_exists(hf_weight_key):
+             return
+         hook_fn = None
+         if transpose_dims:
+
+             def transpose_hook(x, _):
+                 return np.transpose(x, transpose_dims)
+
+             hook_fn = transpose_hook
+         loader.port_weight(
+             layer.kernel, hf_weight_key=hf_weight_key, hook_fn=hook_fn
+         )
+         if layer.bias is not None:
+             hf_bias_key = f"{timm_key}.bias"
+             if key_exists(hf_bias_key):
+                 loader.port_weight(
+                     layer.bias,
+                     hf_weight_key=hf_bias_key,
+                 )
+
+     def _port_bn(layer, timm_prefix):
+         loader.port_weight(layer.gamma, f"{timm_prefix}.weight")
+         loader.port_weight(layer.beta, f"{timm_prefix}.bias")
+         loader.port_weight(layer.moving_mean, f"{timm_prefix}.running_mean")
+         loader.port_weight(layer.moving_variance, f"{timm_prefix}.running_var")
+
+     def _port_rms_norm(layer, timm_prefix):
+         loader.port_weight(layer.gamma, f"{timm_prefix}.weight")
+
+     def _port_cna(cna_layer: ConvNormAct, timm_conv_prefix, timm_norm_prefix):
+         if isinstance(cna_layer.conv, keras.layers.DepthwiseConv2D):
+             _port_weights(
+                 cna_layer.conv,
+                 timm_conv_prefix,
+                 transpose_dims=(2, 3, 0, 1),
+             )
+         else:
+             _port_weights(
+                 cna_layer.conv,
+                 timm_conv_prefix,
+                 transpose_dims=(2, 3, 1, 0),
+             )
+         if key_exists(f"{timm_norm_prefix}.running_mean"):
+             _port_bn(cna_layer.norm, timm_norm_prefix)
+         else:
+             _port_rms_norm(cna_layer.norm, timm_norm_prefix)
+
+     def _port_attn(attn_layer, attn_prefix):
+         _port_weights(
+             attn_layer.query_layers[-1],
+             f"{attn_prefix}.query.proj",
+             (2, 3, 1, 0),
+         )
+         if len(attn_layer.key_layers) > 1:
+             _port_weights(
+                 attn_layer.key_layers[0],
+                 f"{attn_prefix}.key.down_conv",
+                 (2, 3, 0, 1),
+             )
+             key_norm_layer = attn_layer.key_layers[1]
+             if isinstance(key_norm_layer, RmsNorm2d):
+                 _port_rms_norm(key_norm_layer, f"{attn_prefix}.key.norm")
+             else:
+                 _port_bn(key_norm_layer, f"{attn_prefix}.key.norm")
+         _port_weights(
+             attn_layer.key_layers[-1], f"{attn_prefix}.key.proj", (2, 3, 1, 0)
+         )
+         if len(attn_layer.value_layers) > 1:
+             _port_weights(
+                 attn_layer.value_layers[0],
+                 f"{attn_prefix}.value.down_conv",
+                 (2, 3, 0, 1),
+             )
+             value_norm_layer = attn_layer.value_layers[1]
+             if isinstance(value_norm_layer, RmsNorm2d):
+                 _port_rms_norm(value_norm_layer, f"{attn_prefix}.value.norm")
+             else:
+                 _port_bn(value_norm_layer, f"{attn_prefix}.value.norm")
+         _port_weights(
+             attn_layer.value_layers[-1],
+             f"{attn_prefix}.value.proj",
+             (2, 3, 1, 0),
+         )
+         _port_weights(
+             attn_layer.output_proj_layers[-2],
+             f"{attn_prefix}.output.proj",
+             (2, 3, 1, 0),
+         )
+
+     stem_layer = backbone.get_layer("conv_stem")
+     _port_cna(stem_layer, "conv_stem.conv", "conv_stem.bn")
+     block_layers = [
+         layer
+         for layer in backbone.layers
+         if isinstance(
+             layer, (EdgeResidual, UniversalInvertedResidual, MobileAttention)
+         )
+     ]
+     block_counter = 0
+     for stack_idx in range(len(backbone.stackwise_num_blocks)):
+         for block_idx_in_stage in range(
+             backbone.stackwise_num_blocks[stack_idx]
+         ):
+             block = block_layers[block_counter]
+             timm_prefix = f"blocks.{stack_idx}.{block_idx_in_stage}"
+             if isinstance(block, EdgeResidual):
+                 _port_cna(
+                     block.conv_exp,
+                     f"{timm_prefix}.conv_exp",
+                     f"{timm_prefix}.bn1",
+                 )
+                 _port_cna(
+                     block.conv_pwl,
+                     f"{timm_prefix}.conv_pwl",
+                     f"{timm_prefix}.bn2",
+                 )
+             elif isinstance(block, UniversalInvertedResidual):
+                 if hasattr(block, "dw_start") and not isinstance(
+                     block.dw_start, types.FunctionType
+                 ):
+                     _port_cna(
+                         block.dw_start,
+                         f"{timm_prefix}.dw_start.conv",
+                         f"{timm_prefix}.dw_start.bn",
+                     )
+                 _port_cna(
+                     block.pw_exp,
+                     f"{timm_prefix}.pw_exp.conv",
+                     f"{timm_prefix}.pw_exp.bn",
+                 )
+                 if hasattr(block, "dw_mid") and not isinstance(
+                     block.dw_mid, types.FunctionType
+                 ):
+                     _port_cna(
+                         block.dw_mid,
+                         f"{timm_prefix}.dw_mid.conv",
+                         f"{timm_prefix}.dw_mid.bn",
+                     )
+                 _port_cna(
+                     block.pw_proj,
+                     f"{timm_prefix}.pw_proj.conv",
+                     f"{timm_prefix}.pw_proj.bn",
+                 )
+                 gamma_key = f"{timm_prefix}.layer_scale.gamma"
+                 if key_exists(gamma_key):
+                     loader.port_weight(block.layer_scale.gamma, gamma_key)
+             elif isinstance(block, MobileAttention):
+                 _port_rms_norm(block.norm, f"{timm_prefix}.norm")
+                 gamma_key = f"{timm_prefix}.layer_scale.gamma"
+                 if key_exists(gamma_key):
+                     loader.port_weight(block.layer_scale.gamma, gamma_key)
+                 attn_prefix = f"{timm_prefix}.attn"
+                 _port_attn(block.attn, attn_prefix)
+             block_counter += 1
+     try:
+         msfa_layer = backbone.get_layer("msfa")
+         ffn = msfa_layer.ffn
+         _port_cna(ffn.pw_exp, "msfa.ffn.pw_exp.conv", "msfa.ffn.pw_exp.bn")
+         _port_cna(ffn.pw_proj, "msfa.ffn.pw_proj.conv", "msfa.ffn.pw_proj.bn")
+         _port_rms_norm(msfa_layer.norm, "msfa.norm")
+     except ValueError:
+         pass
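
The `transpose_dims` values threaded through `_port_weights` and `_port_cna` above encode the kernel-layout difference between the two frameworks: timm/PyTorch stores `Conv2d` weights as `(out, in, kH, kW)` and depthwise weights as `(channels, 1, kH, kW)`, while Keras expects `(kH, kW, in, out)` for `Conv2D` and `(kH, kW, channels, multiplier)` for `DepthwiseConv2D`. A minimal sketch (not part of the diff) verifying both permutations:

    import numpy as np

    # Regular conv: (O, I, kH, kW) -> (kH, kW, I, O), i.e. axes (2, 3, 1, 0).
    torch_conv = np.zeros((64, 3, 3, 3))
    assert np.transpose(torch_conv, (2, 3, 1, 0)).shape == (3, 3, 3, 64)

    # Depthwise conv: (C, 1, kH, kW) -> (kH, kW, C, 1), i.e. axes (2, 3, 0, 1).
    torch_dw = np.zeros((128, 1, 3, 3))
    assert np.transpose(torch_dw, (2, 3, 0, 1)).shape == (3, 3, 128, 1)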
keras_hub/src/utils/timm/preset_loader.py
@@ -7,6 +7,7 @@ from keras_hub.src.utils.timm import convert_cspnet
  from keras_hub.src.utils.timm import convert_densenet
  from keras_hub.src.utils.timm import convert_efficientnet
  from keras_hub.src.utils.timm import convert_mobilenet
+ from keras_hub.src.utils.timm import convert_mobilenetv5
  from keras_hub.src.utils.timm import convert_resnet
  from keras_hub.src.utils.timm import convert_vgg
  from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
@@ -22,6 +23,8 @@ class TimmPresetLoader(PresetLoader):
              self.converter = convert_cspnet
          elif architecture.startswith("densenet"):
              self.converter = convert_densenet
+         elif architecture.startswith("mobilenetv5"):
+             self.converter = convert_mobilenetv5
          elif architecture.startswith("mobilenet"):
              self.converter = convert_mobilenet
          elif architecture.startswith("vgg"):
@@ -41,7 +44,8 @@ class TimmPresetLoader(PresetLoader):
          keras_config = self.converter.convert_backbone_config(self.config)
          backbone = cls(**{**keras_config, **kwargs})
          if load_weights:
-             jax_memory_cleanup(backbone)
+             if not self.config["architecture"].startswith("mobilenetv5"):
+                 jax_memory_cleanup(backbone)
              # Use prefix="" to avoid using `get_prefixed_key`.
              with SafetensorLoader(self.preset, prefix="") as loader:
                  self.converter.convert_weights(backbone, loader, self.config)
@@ -54,9 +58,9 @@ class TimmPresetLoader(PresetLoader):
          )
          # Support loading the classification head for classifier models.
          kwargs["num_classes"] = self.config["num_classes"]
-         if (
-             "num_features" in self.config
-             and "mobilenet" in self.config["architecture"]
+         if "num_features" in self.config and (
+             "mobilenet" in self.config["architecture"]
+             or "mobilenetv5" in self.config["architecture"]
          ):
              kwargs["num_features"] = self.config["num_features"]

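Note that the new `mobilenetv5` branch is checked before the generic `mobilenet` one, since `architecture.startswith("mobilenet")` would otherwise shadow it. With the converter registered, a timm MobileNetV5 checkpoint should be loadable through the usual preset path. A hedged usage sketch — the Hugging Face handle below is hypothetical:

    import keras_hub

    # Config and safetensors are resolved by TimmPresetLoader; the weight
    # porting itself is delegated to convert_mobilenetv5.convert_weights.
    backbone = keras_hub.models.Backbone.from_preset(
        "hf://timm/mobilenetv5_300m.gemma3n"  # hypothetical handle
    )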
keras_hub/src/utils/transformers/convert_dinov2.py
@@ -29,6 +29,7 @@ def convert_backbone_config(transformers_config):
          "image_shape": (image_size, image_size, 3),
          "position_embedding_shape": (image_size, image_size),
          "antialias_in_interpolation": antialias_in_interpolation,
+         "apply_layernorm": transformers_config.get("apply_layernorm", False),
      }

keras_hub/src/utils/transformers/convert_qwen3_moe.py (new file)
@@ -0,0 +1,216 @@
+ import numpy as np
+
+ from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone
+ from keras_hub.src.utils.preset_utils import load_json
+
+ backbone_cls = Qwen3MoeBackbone
+
+
+ def convert_backbone_config(transformers_config):
+     return {
+         "vocabulary_size": transformers_config["vocab_size"],
+         "hidden_dim": transformers_config["hidden_size"],
+         "head_dim": transformers_config["head_dim"],
+         "num_layers": transformers_config["num_hidden_layers"],
+         "num_query_heads": transformers_config["num_attention_heads"],
+         "num_key_value_heads": transformers_config["num_key_value_heads"],
+         "intermediate_dim": transformers_config["intermediate_size"],
+         "moe_intermediate_dim": transformers_config["moe_intermediate_size"],
+         "num_experts": transformers_config["num_experts"],
+         "top_k": transformers_config["num_experts_per_tok"],
+         "norm_top_k_prob": transformers_config["norm_topk_prob"],
+         "decoder_sparse_step": transformers_config["decoder_sparse_step"],
+         "layer_norm_epsilon": transformers_config["rms_norm_eps"],
+         "rope_max_wavelength": transformers_config["rope_theta"],
+         "sliding_window_size": transformers_config["sliding_window"],
+         "router_aux_loss_coefficient": transformers_config[
+             "router_aux_loss_coef"
+         ],
+         "tie_word_embeddings": transformers_config.get(
+             "tie_word_embeddings", False
+         ),
+     }
+
+
+ def convert_weights(backbone, loader, transformers_config):
+     loader.port_weight(
+         keras_variable=backbone.get_layer("token_embedding").embeddings,
+         hf_weight_key="model.embed_tokens.weight",
+     )
+     if not backbone.tie_word_embeddings:
+         loader.port_weight(
+             keras_variable=backbone.get_layer(
+                 "token_embedding"
+             ).reverse_embeddings,
+             hf_weight_key="lm_head.weight",
+             # rearrange_pattern="b a -> a b",
+             hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+         )
+
+     def transpose_and_reshape(x, shape):
+         return np.reshape(np.transpose(x), shape)
+
+     for i in range(backbone.num_layers):
+         decoder_layer = backbone.get_layer(f"transformer_layer_{i}")
+
+         # Input layernorm
+         loader.port_weight(
+             keras_variable=decoder_layer._self_attention_layernorm.scale,
+             hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
+         )
+
+         # Attention layers
+
+         ## Query
+         loader.port_weight(
+             keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
+             hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
+             hook_fn=transpose_and_reshape,
+         )
+         loader.port_weight(
+             keras_variable=decoder_layer._self_attention_layer._query_dense_layer_norm.scale,
+             hf_weight_key=f"model.layers.{i}.self_attn.q_norm.weight",
+         )
+         ## Key
+         loader.port_weight(
+             keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
+             hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
+             hook_fn=transpose_and_reshape,
+         )
+         loader.port_weight(
+             keras_variable=decoder_layer._self_attention_layer._key_dense_layer_norm.scale,
+             hf_weight_key=f"model.layers.{i}.self_attn.k_norm.weight",
+         )
+         ## Value
+         loader.port_weight(
+             keras_variable=decoder_layer._self_attention_layer._value_dense.kernel,
+             hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
+             hook_fn=transpose_and_reshape,
+         )
+         ## Output
+         loader.port_weight(
+             keras_variable=decoder_layer._self_attention_layer._output_dense.kernel,
+             hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
+             # rearrange_patterns="c (a b) -> a b c",
+             # rearrange_dims={"a": backbone.num_query_heads},
+             hook_fn=transpose_and_reshape,
+         )
+
+         # MLP layers
+         if (
+             (i not in backbone.mlp_only_layers)
+             and backbone.num_experts > 0
+             and ((i + 1) % backbone.decoder_sparse_step == 0)
+         ):
+             # MoE layers
+             loader.port_weight(
+                 keras_variable=decoder_layer.mlp._sparse_feedforward_gate_dense.kernel,
+                 hf_weight_key=f"model.layers.{i}.mlp.gate.weight",
+                 # rearrange_patterns="b a -> a b",
+                 hook_fn=lambda hf_tensor, _: np.transpose(
+                     hf_tensor, axes=(1, 0)
+                 ),
+             )
+             # Batched experts: gate_up_proj and down_proj
+             gate_up_proj_list = []
+             down_proj_list = []
+             for expert_idx in range(backbone.num_experts):
+                 # Load gate_proj and up_proj for each expert
+                 gate_proj = loader.get_tensor(
+                     f"model.layers.{i}.mlp.experts.{expert_idx}.gate_proj.weight"
+                 )
+                 up_proj = loader.get_tensor(
+                     f"model.layers.{i}.mlp.experts.{expert_idx}.up_proj.weight"
+                 )
+                 # Transpose to (hidden_dim, intermediate_dim)
+                 gate_proj = np.transpose(gate_proj, axes=(1, 0))
+                 up_proj = np.transpose(up_proj, axes=(1, 0))
+                 # Concatenate gate_proj and up_proj along the last dimension
+                 gate_up_proj = np.concatenate([gate_proj, up_proj], axis=-1)
+                 gate_up_proj_list.append(gate_up_proj)
+
+                 # Load down_proj for each expert
+                 down_proj = loader.get_tensor(
+                     f"model.layers.{i}.mlp.experts.{expert_idx}.down_proj.weight"
+                 )
+                 down_proj = np.transpose(
+                     down_proj, axes=(1, 0)
+                 )  # (intermediate_dim, hidden_dim)
+                 down_proj_list.append(down_proj)
+
+             # Stack the lists to create batched weights
+             gate_up_proj_batched = np.stack(
+                 gate_up_proj_list, axis=0
+             )  # (num_experts, hidden_dim, 2 * intermediate_dim)
+             down_proj_batched = np.stack(
+                 down_proj_list, axis=0
+             )  # (num_experts, intermediate_dim, hidden_dim)
+
+             # Assign batched weights to expert_bank
+             decoder_layer.mlp.expert_bank._expert_feedforward_gate_dense.assign(
+                 gate_up_proj_batched
+             )
+             decoder_layer.mlp.expert_bank._expert_feedforward_output_dense.assign(
+                 down_proj_batched
+             )
+         else:
+             loader.port_weight(
+                 keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
+                 hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight",
+                 # rearrange_patterns="b a -> a b",
+                 hook_fn=lambda hf_tensor, _: np.transpose(
+                     hf_tensor, axes=(1, 0)
+                 ),
+             )
+             loader.port_weight(
+                 keras_variable=decoder_layer._feedforward_output_dense.kernel,
+                 hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight",
+                 # rearrange_patterns="b a -> a b",
+                 hook_fn=lambda hf_tensor, _: np.transpose(
+                     hf_tensor, axes=(1, 0)
+                 ),
+             )
+             loader.port_weight(
+                 keras_variable=decoder_layer._feedforward_gate_dense.kernel,
+                 hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight",
+                 # rearrange_patterns="b a -> a b",
+                 hook_fn=lambda hf_tensor, _: np.transpose(
+                     hf_tensor, axes=(1, 0)
+                 ),
+             )
+
+         # Feedforward layernorm
+         loader.port_weight(
+             keras_variable=decoder_layer._feedforward_layernorm.scale,
+             hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
+         )
+
+     # Final normalization layer
+     loader.port_weight(
+         keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
+         hf_weight_key="model.norm.weight",
+     )
+
+     return backbone
+
+
+ def convert_tokenizer(cls, preset, **kwargs):
+     tokenizer_config = load_json(preset, "tokenizer.json")
+     vocab = tokenizer_config["model"]["vocab"]
+     merges = tokenizer_config["model"]["merges"]
+     merges = [" ".join(item) for item in merges]
+
+     # Load all special tokens with the exception of "reserved" ones.
+     special_tokens = set()
+     for token in tokenizer_config["added_tokens"]:
+         if not token["content"].startswith("<|reserved_special_token_"):
+             vocab[token["content"]] = token["id"]
+             special_tokens.add(token["content"])
+
+     kwargs.update(
+         {
+             "unsplittable_tokens": list(special_tokens),
+         }
+     )
+
+     return cls(vocabulary=vocab, merges=merges, **kwargs)
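
The expert loop in `convert_weights` packs the per-expert HF tensors into the two batched kernels used for the whole expert bank. A minimal NumPy sketch (toy sizes, not from the diff) of the shape bookkeeping:

    import numpy as np

    num_experts, hidden_dim, moe_inter = 4, 8, 16
    # HF stores each projection as (out_features, in_features).
    gate_proj = np.zeros((moe_inter, hidden_dim))
    up_proj = np.zeros((moe_inter, hidden_dim))
    down_proj = np.zeros((hidden_dim, moe_inter))

    # Transpose, fuse gate + up along the last axis, then stack across experts.
    gate_up = np.concatenate([gate_proj.T, up_proj.T], axis=-1)
    gate_up_batched = np.stack([gate_up] * num_experts, axis=0)
    down_batched = np.stack([down_proj.T] * num_experts, axis=0)

    assert gate_up_batched.shape == (num_experts, hidden_dim, 2 * moe_inter)
    assert down_batched.shape == (num_experts, moe_inter, hidden_dim)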