keras-hub-nightly 0.22.0.dev202508170419-py3-none-any.whl → 0.24.0.dev202511090424-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of keras-hub-nightly might be problematic.
- keras_hub/layers/__init__.py +15 -0
- keras_hub/models/__init__.py +93 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +28 -16
- keras_hub/src/models/causal_lm.py +37 -0
- keras_hub/src/models/causal_lm_preprocessor.py +14 -0
- keras_hub/src/models/clip/clip_presets.py +8 -8
- keras_hub/src/models/d_fine/__init__.py +5 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
- keras_hub/src/models/depth_anything/__init__.py +9 -0
- keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
- keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
- keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
- keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
- keras_hub/src/models/depth_anything/interpolate.py +62 -0
- keras_hub/src/models/depth_estimator.py +239 -0
- keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
- keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_backbone.py +0 -1
- keras_hub/src/models/gemma/gemma_presets.py +30 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/samplers/beam_sampler.py +6 -6
- keras_hub/src/samplers/sampler.py +8 -6
- keras_hub/src/tests/test_case.py +40 -3
- keras_hub/src/tokenizers/tokenizer.py +15 -0
- keras_hub/src/utils/openvino_utils.py +141 -0
- keras_hub/src/utils/preset_utils.py +58 -2
- keras_hub/src/utils/tensor_utils.py +26 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/utils/transformers/export/gemma.py +49 -4
- keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +15 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
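Most of the growth in this release comes from newly added model families (D-FINE, DepthAnything, DINOv3, MobileNetV5, PARSeq, Qwen3-MoE, SmolLM3, T5Gemma), each registered with its own presets. Any of them should load through the standard preset entry point; a minimal sketch, with a placeholder preset id (the real ids live in the `*_presets.py` files listed above):

```python
import numpy as np
import keras_hub

# "dinov3_vit_small" is a placeholder, not a confirmed preset id; check
# keras_hub/src/models/dinov3/dinov3_presets.py for the registered names.
backbone = keras_hub.models.Backbone.from_preset("dinov3_vit_small")
features = backbone.predict(np.zeros((1, 224, 224, 3), dtype="float32"))
```

Diffs for a subset of the changed files follow.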
keras_hub/src/utils/timm/convert_mobilenetv5.py (new file)

@@ -0,0 +1,321 @@
+import types
+
+import keras
+import numpy as np
+
+from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import (
+    MobileAttention,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import (
+    MobileNetV5Backbone,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import EdgeResidual
+from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import (
+    UniversalInvertedResidual,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import (
+    convert_arch_def_to_stackwise,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct
+from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d
+
+backbone_cls = MobileNetV5Backbone
+
+MODEL_CONFIGS = {
+    "mobilenetv5_300m": {
+        "backbone": convert_arch_def_to_stackwise(
+            [
+                # Stage 0: 128x128 in
+                [
+                    "er_r1_k3_s2_e4_c128",
+                    "er_r1_k3_s1_e4_c128",
+                    "er_r1_k3_s1_e4_c128",
+                ],
+                # Stage 1: 256x256 in
+                [
+                    "uir_r1_a3_k5_s2_e6_c256",
+                    "uir_r1_a5_k0_s1_e4_c256",
+                    "uir_r1_a3_k0_s1_e4_c256",
+                    "uir_r1_a5_k0_s1_e4_c256",
+                    "uir_r1_a3_k0_s1_e4_c256",
+                ],
+                # Stage 2: 640x640 in
+                [
+                    "uir_r1_a5_k5_s2_e6_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a0_k0_s1_e1_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                ],
+                # Stage 3: 1280x1280 in
+                [
+                    "uir_r1_a5_k5_s2_e6_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                ],
+            ]
+        ),
+        "stem_size": 64,
+        "num_features": 2048,
+        "norm_layer": "rms_norm",
+        "act_layer": "gelu",
+        "use_msfa": True,
+        "layer_scale_init_value": 1e-5,
+    },
+}
+
+
+def convert_head(task, loader, timm_config):
+    pass
+
+
+def convert_backbone_config(timm_config):
+    timm_architecture = timm_config["architecture"]
+    if timm_architecture not in MODEL_CONFIGS:
+        raise ValueError(f"Unsupported architecture: {timm_architecture}")
+    config = MODEL_CONFIGS[timm_architecture].copy()
+    backbone_config = config.pop("backbone")
+    backbone_config.update(config)
+    return backbone_config
+
+
+def convert_weights(backbone, loader, timm_config):
+    def key_exists(key):
+        try:
+            loader.get_tensor(key)
+            return True
+        except Exception:
+            return False
+
+    def _port_weights(layer, timm_key, transpose_dims=None):
+        hf_weight_key = f"{timm_key}.weight"
+        if not key_exists(hf_weight_key):
+            return
+        hook_fn = None
+        if transpose_dims:
+
+            def transpose_hook(x, _):
+                return np.transpose(x, transpose_dims)
+
+            hook_fn = transpose_hook
+        loader.port_weight(
+            layer.kernel, hf_weight_key=hf_weight_key, hook_fn=hook_fn
+        )
+        if layer.bias is not None:
+            hf_bias_key = f"{timm_key}.bias"
+            if key_exists(hf_bias_key):
+                loader.port_weight(
+                    layer.bias,
+                    hf_weight_key=hf_bias_key,
+                )
+
+    def _port_bn(layer, timm_prefix):
+        loader.port_weight(layer.gamma, f"{timm_prefix}.weight")
+        loader.port_weight(layer.beta, f"{timm_prefix}.bias")
+        loader.port_weight(layer.moving_mean, f"{timm_prefix}.running_mean")
+        loader.port_weight(layer.moving_variance, f"{timm_prefix}.running_var")
+
+    def _port_rms_norm(layer, timm_prefix):
+        loader.port_weight(layer.gamma, f"{timm_prefix}.weight")
+
+    def _port_cna(cna_layer: ConvNormAct, timm_conv_prefix, timm_norm_prefix):
+        if isinstance(cna_layer.conv, keras.layers.DepthwiseConv2D):
+            _port_weights(
+                cna_layer.conv,
+                timm_conv_prefix,
+                transpose_dims=(2, 3, 0, 1),
+            )
+        else:
+            _port_weights(
+                cna_layer.conv,
+                timm_conv_prefix,
+                transpose_dims=(2, 3, 1, 0),
+            )
+        if key_exists(f"{timm_norm_prefix}.running_mean"):
+            _port_bn(cna_layer.norm, timm_norm_prefix)
+        else:
+            _port_rms_norm(cna_layer.norm, timm_norm_prefix)
+
+    def _port_attn(attn_layer, attn_prefix):
+        _port_weights(
+            attn_layer.query_layers[-1],
+            f"{attn_prefix}.query.proj",
+            (2, 3, 1, 0),
+        )
+        if len(attn_layer.key_layers) > 1:
+            _port_weights(
+                attn_layer.key_layers[0],
+                f"{attn_prefix}.key.down_conv",
+                (2, 3, 0, 1),
+            )
+            key_norm_layer = attn_layer.key_layers[1]
+            if isinstance(key_norm_layer, RmsNorm2d):
+                _port_rms_norm(key_norm_layer, f"{attn_prefix}.key.norm")
+            else:
+                _port_bn(key_norm_layer, f"{attn_prefix}.key.norm")
+        _port_weights(
+            attn_layer.key_layers[-1], f"{attn_prefix}.key.proj", (2, 3, 1, 0)
+        )
+        if len(attn_layer.value_layers) > 1:
+            _port_weights(
+                attn_layer.value_layers[0],
+                f"{attn_prefix}.value.down_conv",
+                (2, 3, 0, 1),
+            )
+            value_norm_layer = attn_layer.value_layers[1]
+            if isinstance(value_norm_layer, RmsNorm2d):
+                _port_rms_norm(value_norm_layer, f"{attn_prefix}.value.norm")
+            else:
+                _port_bn(value_norm_layer, f"{attn_prefix}.value.norm")
+        _port_weights(
+            attn_layer.value_layers[-1],
+            f"{attn_prefix}.value.proj",
+            (2, 3, 1, 0),
+        )
+        _port_weights(
+            attn_layer.output_proj_layers[-2],
+            f"{attn_prefix}.output.proj",
+            (2, 3, 1, 0),
+        )
+
+    stem_layer = backbone.get_layer("conv_stem")
+    _port_cna(stem_layer, "conv_stem.conv", "conv_stem.bn")
+    block_layers = [
+        layer
+        for layer in backbone.layers
+        if isinstance(
+            layer, (EdgeResidual, UniversalInvertedResidual, MobileAttention)
+        )
+    ]
+    block_counter = 0
+    for stack_idx in range(len(backbone.stackwise_num_blocks)):
+        for block_idx_in_stage in range(
+            backbone.stackwise_num_blocks[stack_idx]
+        ):
+            block = block_layers[block_counter]
+            timm_prefix = f"blocks.{stack_idx}.{block_idx_in_stage}"
+            if isinstance(block, EdgeResidual):
+                _port_cna(
+                    block.conv_exp,
+                    f"{timm_prefix}.conv_exp",
+                    f"{timm_prefix}.bn1",
+                )
+                _port_cna(
+                    block.conv_pwl,
+                    f"{timm_prefix}.conv_pwl",
+                    f"{timm_prefix}.bn2",
+                )
+            elif isinstance(block, UniversalInvertedResidual):
+                if hasattr(block, "dw_start") and not isinstance(
+                    block.dw_start, types.FunctionType
+                ):
+                    _port_cna(
+                        block.dw_start,
+                        f"{timm_prefix}.dw_start.conv",
+                        f"{timm_prefix}.dw_start.bn",
+                    )
+                _port_cna(
+                    block.pw_exp,
+                    f"{timm_prefix}.pw_exp.conv",
+                    f"{timm_prefix}.pw_exp.bn",
+                )
+                if hasattr(block, "dw_mid") and not isinstance(
+                    block.dw_mid, types.FunctionType
+                ):
+                    _port_cna(
+                        block.dw_mid,
+                        f"{timm_prefix}.dw_mid.conv",
+                        f"{timm_prefix}.dw_mid.bn",
+                    )
+                _port_cna(
+                    block.pw_proj,
+                    f"{timm_prefix}.pw_proj.conv",
+                    f"{timm_prefix}.pw_proj.bn",
+                )
+                gamma_key = f"{timm_prefix}.layer_scale.gamma"
+                if key_exists(gamma_key):
+                    loader.port_weight(block.layer_scale.gamma, gamma_key)
+            elif isinstance(block, MobileAttention):
+                _port_rms_norm(block.norm, f"{timm_prefix}.norm")
+                gamma_key = f"{timm_prefix}.layer_scale.gamma"
+                if key_exists(gamma_key):
+                    loader.port_weight(block.layer_scale.gamma, gamma_key)
+                attn_prefix = f"{timm_prefix}.attn"
+                _port_attn(block.attn, attn_prefix)
+            block_counter += 1
+    try:
+        msfa_layer = backbone.get_layer("msfa")
+        ffn = msfa_layer.ffn
+        _port_cna(ffn.pw_exp, "msfa.ffn.pw_exp.conv", "msfa.ffn.pw_exp.bn")
+        _port_cna(ffn.pw_proj, "msfa.ffn.pw_proj.conv", "msfa.ffn.pw_proj.bn")
+        _port_rms_norm(msfa_layer.norm, "msfa.norm")
+    except ValueError:
+        pass
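The block definitions in MODEL_CONFIGS above use timm-style arch-def strings: the leading tag picks the block type (er = EdgeResidual, uir = UniversalInvertedResidual, mqa = mobile multi-query attention), and the letter-number fields that follow encode repeats (r), depthwise kernel sizes (a, k), stride (s), expansion ratio (e), attention heads (h), head dim (d), kv stride (v), and output channels (c). A minimal parsing sketch, assuming this timm convention holds; the helper below is illustrative, not part of keras-hub:

```python
import re

def parse_block_def(token):
    """Split a timm-style arch-def token such as "er_r1_k3_s2_e4_c128"."""
    block_type, *options = token.split("_")
    fields = {}
    for option in options:
        match = re.fullmatch(r"([a-z]+)(\d+)", option)
        if match:  # skip any non letter-number field
            fields[match.group(1)] = int(match.group(2))
    return block_type, fields

print(parse_block_def("mqa_r1_k3_h12_v2_s1_d64_c640"))
# -> ('mqa', {'r': 1, 'k': 3, 'h': 12, 'v': 2, 's': 1, 'd': 64, 'c': 640})
```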
keras_hub/src/utils/timm/preset_loader.py

@@ -7,6 +7,7 @@ from keras_hub.src.utils.timm import convert_cspnet
 from keras_hub.src.utils.timm import convert_densenet
 from keras_hub.src.utils.timm import convert_efficientnet
 from keras_hub.src.utils.timm import convert_mobilenet
+from keras_hub.src.utils.timm import convert_mobilenetv5
 from keras_hub.src.utils.timm import convert_resnet
 from keras_hub.src.utils.timm import convert_vgg
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader

@@ -22,6 +23,8 @@ class TimmPresetLoader(PresetLoader):
             self.converter = convert_cspnet
         elif architecture.startswith("densenet"):
             self.converter = convert_densenet
+        elif architecture.startswith("mobilenetv5"):
+            self.converter = convert_mobilenetv5
         elif architecture.startswith("mobilenet"):
             self.converter = convert_mobilenet
         elif architecture.startswith("vgg"):

@@ -41,7 +44,8 @@ class TimmPresetLoader(PresetLoader):
         keras_config = self.converter.convert_backbone_config(self.config)
         backbone = cls(**{**keras_config, **kwargs})
         if load_weights:
-            jax_memory_cleanup(backbone)
+            if not self.config["architecture"].startswith("mobilenetv5"):
+                jax_memory_cleanup(backbone)
             # Use prefix="" to avoid using `get_prefixed_key`.
             with SafetensorLoader(self.preset, prefix="") as loader:
                 self.converter.convert_weights(backbone, loader, self.config)

@@ -54,9 +58,9 @@ class TimmPresetLoader(PresetLoader):
         )
         # Support loading the classification head for classifier models.
         kwargs["num_classes"] = self.config["num_classes"]
-        if (
-            "num_features" in self.config
-            and "mobilenet" in self.config["architecture"]
+        if "num_features" in self.config and (
+            "mobilenet" in self.config["architecture"]
+            or "mobilenetv5" in self.config["architecture"]
         ):
             kwargs["num_features"] = self.config["num_features"]
 
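Note the ordering of the new dispatch branch above: every "mobilenetv5_*" architecture name also matches the plain "mobilenet" prefix, so the v5 check has to run first or those checkpoints would fall through to the old MobileNet converter:

```python
architecture = "mobilenetv5_300m"
assert architecture.startswith("mobilenet")    # the generic branch matches too
assert architecture.startswith("mobilenetv5")  # so this test must come first
```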
keras_hub/src/utils/transformers/convert_dinov2.py

@@ -29,6 +29,7 @@ def convert_backbone_config(transformers_config):
         "image_shape": (image_size, image_size, 3),
         "position_embedding_shape": (image_size, image_size),
         "antialias_in_interpolation": antialias_in_interpolation,
+        "apply_layernorm": transformers_config.get("apply_layernorm", False),
     }
 
 
keras_hub/src/utils/transformers/convert_dinov3.py (new file)

@@ -0,0 +1,106 @@
+import numpy as np
+
+from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone
+
+backbone_cls = DINOV3Backbone
+
+
+def convert_backbone_config(transformers_config):
+    image_size = transformers_config["image_size"]
+    return {
+        "patch_size": transformers_config["patch_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "num_heads": transformers_config["num_attention_heads"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "layer_scale_init_value": transformers_config["layerscale_value"],
+        "num_register_tokens": transformers_config["num_register_tokens"],
+        "use_mask_token": True,
+        "hidden_activation": transformers_config["hidden_act"],
+        "use_gated_mlp": transformers_config["use_gated_mlp"],
+        "use_query_bias": transformers_config["query_bias"],
+        "use_key_bias": transformers_config["key_bias"],
+        "use_value_bias": transformers_config["value_bias"],
+        "use_proj_bias": transformers_config["proj_bias"],
+        "use_mlp_bias": transformers_config["mlp_bias"],
+        "attention_dropout": transformers_config["attention_dropout"],
+        "drop_path_rate": transformers_config["drop_path_rate"],
+        "layer_norm_eps": transformers_config["layer_norm_eps"],
+        "image_shape": (image_size, image_size, 3),
+        "rope_theta": transformers_config["rope_theta"],
+        "apply_layernorm": False,
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    if not isinstance(backbone, DINOV3Backbone):
+        raise ValueError(
+            "The provided backbone must be an instance of DINOV3Backbone. "
+            f"Received: {type(backbone)}"
+        )
+
+    def port_ln(keras_variable, weight_key):
+        loader.port_weight(keras_variable.gamma, f"{weight_key}.weight")
+        loader.port_weight(keras_variable.beta, f"{weight_key}.bias")
+
+    def port_dense(keras_variable, weight_key):
+        loader.port_weight(
+            keras_variable.kernel,
+            f"{weight_key}.weight",
+            hook_fn=lambda x, _: x.T,
+        )
+        if keras_variable.bias is not None:
+            loader.port_weight(keras_variable.bias, f"{weight_key}.bias")
+
+    # Embedding.
+    loader.port_weight(
+        keras_variable=backbone.embeddings.cls_token,
+        hf_weight_key="embeddings.cls_token",
+    )
+    if backbone.use_mask_token:
+        loader.port_weight(
+            keras_variable=backbone.embeddings.mask_token,
+            hf_weight_key="embeddings.mask_token",
+        )
+    if backbone.num_register_tokens > 0:
+        loader.port_weight(
+            keras_variable=backbone.embeddings.register_tokens,
+            hf_weight_key="embeddings.register_tokens",
+        )
+    loader.port_weight(
+        keras_variable=backbone.embeddings.patch_embeddings.projection.kernel,
+        hf_weight_key="embeddings.patch_embeddings.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+    loader.port_weight(
+        keras_variable=backbone.embeddings.patch_embeddings.projection.bias,
+        hf_weight_key="embeddings.patch_embeddings.bias",
+    )
+
+    # Encoder.
+    for i, layer in enumerate(backbone.encoder.layers):
+        prefix = f"layer.{i}"
+        port_ln(layer.norm1, f"{prefix}.norm1")
+        port_dense(layer.attention.query_dense, f"{prefix}.attention.q_proj")
+        port_dense(layer.attention.key_dense, f"{prefix}.attention.k_proj")
+        port_dense(layer.attention.value_dense, f"{prefix}.attention.v_proj")
+        port_dense(layer.attention.output_dense, f"{prefix}.attention.o_proj")
+
+        loader.port_weight(
+            keras_variable=layer.layer_scale1.lambda1,
+            hf_weight_key=f"{prefix}.layer_scale1.lambda1",
+        )
+        port_ln(layer.norm2, f"{prefix}.norm2")
+        if backbone.use_gated_mlp:
+            port_dense(layer.mlp.gate_proj, f"{prefix}.mlp.gate_proj")
+            port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj")
+            port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj")
+        else:
+            port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj")
+            port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj")
+        loader.port_weight(
+            keras_variable=layer.layer_scale2.lambda1,
+            hf_weight_key=f"{prefix}.layer_scale2.lambda1",
+        )
+
+    port_ln(backbone.layernorm, "norm")
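Both converters in this release rely on the same layout fixes when porting safetensors weights into Keras variables: a PyTorch Linear weight is stored as (out_features, in_features) while a Keras Dense kernel is (in_features, out_features), and a PyTorch Conv2d weight is (out, in, h, w) while a Keras Conv2D kernel is (h, w, in, out). A small NumPy sketch of the hooks used above (shapes are illustrative):

```python
import numpy as np

linear_weight = np.zeros((1024, 768))     # torch Linear: (out, in)
dense_kernel = linear_weight.T            # Keras Dense: (in, out)
assert dense_kernel.shape == (768, 1024)

conv_weight = np.zeros((768, 3, 16, 16))  # torch Conv2d: (out, in, h, w)
conv_kernel = np.transpose(conv_weight, (2, 3, 1, 0))
assert conv_kernel.shape == (16, 16, 3, 768)  # Keras Conv2D: (h, w, in, out)

# Depthwise convs use (2, 3, 0, 1) instead: (ch, 1, h, w) -> (h, w, ch, 1).
dw_weight = np.zeros((128, 1, 3, 3))
assert np.transpose(dw_weight, (2, 3, 0, 1)).shape == (3, 3, 128, 1)
```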