keras-hub-nightly 0.23.0.dev202510080414__py3-none-any.whl → 0.24.0.dev202511080419__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +6 -0
- keras_hub/models/__init__.py +36 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +6 -0
- keras_hub/src/models/causal_lm.py +5 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +38 -1
- keras_hub/src/models/dinov2/dinov2_layers.py +3 -1
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_presets.py +22 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/utils/preset_utils.py +9 -2
- keras_hub/src/utils/tensor_utils.py +3 -1
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/RECORD +52 -24
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from keras_hub.src.api_export import keras_hub_export
|
|
2
|
+
from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone
|
|
3
|
+
from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@keras_hub_export(
    [
        "keras_hub.tokenizers.SmolLM3Tokenizer",
        "keras_hub.tokenizers.SmolLMTokenizer",
        "keras_hub.models.SmolLM3Tokenizer",
        "keras_hub.models.SmolLMTokenizer",
    ]
)
class SmolLM3Tokenizer(BytePairTokenizer):
    """Tokenizer for SmolLM3 models.

    This tokenizer implements byte-pair encoding (BPE) for SmolLM3 models.
    It registers the model's fixed special tokens: `"<|end_of_text|>"` as
    the end-of-sequence token, `"<|begin_of_text|>"` as the
    beginning-of-sequence token, and the `"<think>"` / `"</think>"` pair
    that delimits reasoning segments. These tokens are not configurable.

    Args:
        vocabulary: Dictionary mapping tokens to token IDs, or path to
            vocabulary file. Defaults to `None`.
        merges: List of BPE merges, or path to merges file. Defaults to
            `None`.
    """

    backbone_cls = SmolLM3Backbone

    def __init__(
        self,
        vocabulary=None,
        merges=None,
        **kwargs,
    ):
        # Register the model's fixed special tokens before calling the
        # parent constructor so it can resolve their ids from the
        # vocabulary. `_add_special_token` exposes each token as an
        # attribute named by its second argument.
        self._add_special_token("<|end_of_text|>", "end_token")
        self._add_special_token("<|begin_of_text|>", "bos_token")
        self._add_special_token("<think>", "start_think_token")
        self._add_special_token("</think>", "end_think_token")

        # SmolLM3 defines no dedicated start token; padding uses id 0.
        self.start_token_id = None
        self.start_token = None
        self.pad_token_id = 0

        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            **kwargs,
        )
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def rotate_half(x):
    """Rotate the halves of the last axis: `[x1, x2] -> [-x2, x1]`.

    Used by the rotary position embedding helpers below.
    """
    half = ops.shape(x)[-1] // 2
    lower = x[..., :half]
    upper = x[..., half:]
    return ops.concatenate((-upper, lower), axis=-1)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1):
    """Apply rotary position embeddings to query and key tensors.

    `cos` and `sin` are broadcast against `q`/`k` by inserting a new axis
    at `expansion_axis` (default 1, i.e. the heads axis).
    """
    cos_b = ops.expand_dims(cos, expansion_axis)
    sin_b = ops.expand_dims(sin, expansion_axis)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def apply_rotary_pos_single(tensor, cos, sin, expansion_axis=1):
    """Apply rotary position embeddings to a single tensor.

    Same computation as `apply_rotary_pos_emb`, for one tensor instead of
    a query/key pair.
    """
    cos_b = ops.expand_dims(cos, expansion_axis)
    sin_b = ops.expand_dims(sin, expansion_axis)
    return tensor * cos_b + rotate_half(tensor) * sin_b
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def repeat_kv(hidden_states, n_rep):
    """Repeat each key/value head `n_rep` times along the heads axis.

    Expands `(batch, num_kv_heads, seq, head_dim)` to
    `(batch, num_kv_heads * n_rep, seq, head_dim)` so grouped-query
    attention can match the number of query heads. Returns the input
    unchanged when `n_rep == 1`.
    """
    batch, num_kv_heads, seq_len, head_dim = ops.shape(hidden_states)
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis, broadcast along it, then fold it back into
    # the heads axis.
    expanded = ops.expand_dims(hidden_states, axis=2)
    expanded = ops.broadcast_to(
        expanded, (batch, num_kv_heads, n_rep, seq_len, head_dim)
    )
    return ops.reshape(
        expanded, [batch, num_kv_heads * n_rep, seq_len, head_dim]
    )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def rope_init(rope_theta, partial_rotary_factor, head_dim):
    """Initialize RoPE (Rotary Position Embedding) parameters.

    Args:
        rope_theta: float. The theta value for RoPE.
        partial_rotary_factor: float. The factor for partial rotary
            embedding.
        head_dim: int. The dimension of each attention head.

    Returns:
        A tuple of `(inv_freq, attention_scaling)` where `inv_freq` is the
        inverse frequency tensor and `attention_scaling` is the scaling
        factor (always 1.0 here).
    """
    # Only the first `rotary_dim` channels of each head receive rotary
    # embeddings when partial_rotary_factor < 1.
    rotary_dim = int(head_dim * partial_rotary_factor)
    exponents = ops.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim
    inv_freq = 1.0 / ops.power(rope_theta, exponents)
    attention_scaling = 1.0
    return inv_freq, attention_scaling
|
|
@@ -11,7 +11,7 @@ backbone_presets = {
|
|
|
11
11
|
"params": 2987080931,
|
|
12
12
|
"path": "stable_diffusion_3",
|
|
13
13
|
},
|
|
14
|
-
"kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/
|
|
14
|
+
"kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/5",
|
|
15
15
|
},
|
|
16
16
|
"stable_diffusion_3.5_medium": {
|
|
17
17
|
"metadata": {
|
|
@@ -35,7 +35,7 @@ backbone_presets = {
|
|
|
35
35
|
"params": 9048410595,
|
|
36
36
|
"path": "stable_diffusion_3",
|
|
37
37
|
},
|
|
38
|
-
"kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/
|
|
38
|
+
"kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/3",
|
|
39
39
|
},
|
|
40
40
|
"stable_diffusion_3.5_large_turbo": {
|
|
41
41
|
"metadata": {
|
|
@@ -49,6 +49,6 @@ backbone_presets = {
|
|
|
49
49
|
"params": 9048410595,
|
|
50
50
|
"path": "stable_diffusion_3",
|
|
51
51
|
},
|
|
52
|
-
"kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/
|
|
52
|
+
"kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/3",
|
|
53
53
|
},
|
|
54
54
|
}
|
|
@@ -345,3 +345,8 @@ class TextToImage(Task):
|
|
|
345
345
|
# Text-to-image.
|
|
346
346
|
outputs = [generate(x) for x in inputs]
|
|
347
347
|
return self._normalize_generate_outputs(outputs, input_is_scalar)
|
|
348
|
+
|
|
349
|
+
    def _post_quantize(self, mode, **kwargs):
        """Run the parent post-quantization hook, then invalidate the
        cached generate function.

        Quantization replaces the model's weights/ops, so any previously
        compiled generate function would trace the stale graph; clearing
        it forces recompilation on the next `generate` call.
        """
        super()._post_quantize(mode, **kwargs)
        # Reset the compiled generate function.
        self.generate_function = None
|
|
@@ -502,10 +502,17 @@ def jax_memory_cleanup(layer):
|
|
|
502
502
|
# For jax, delete all previous allocated memory to avoid temporarily
|
|
503
503
|
# duplicating variable allocations. torch and tensorflow have stateful
|
|
504
504
|
# variable types and do not need this fix.
|
|
505
|
+
# Skip deletion for sharded arrays to avoid breaking references in
|
|
506
|
+
# distributed setups.
|
|
505
507
|
if keras.config.backend() == "jax":
|
|
506
508
|
for weight in layer.weights:
|
|
507
|
-
if
|
|
508
|
-
|
|
509
|
+
if weight._value is not None:
|
|
510
|
+
# Do not delete sharded arrays, as they may be referenced in
|
|
511
|
+
# JAX's distributed computation graph and deletion can cause
|
|
512
|
+
# errors.
|
|
513
|
+
sharding = getattr(weight._value, "sharding", None)
|
|
514
|
+
if sharding is None:
|
|
515
|
+
weight._value.delete()
|
|
509
516
|
|
|
510
517
|
|
|
511
518
|
def set_dtype_in_config(config, dtype=None):
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
import types
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import (
|
|
7
|
+
MobileAttention,
|
|
8
|
+
)
|
|
9
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import (
|
|
10
|
+
MobileNetV5Backbone,
|
|
11
|
+
)
|
|
12
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import EdgeResidual
|
|
13
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import (
|
|
14
|
+
UniversalInvertedResidual,
|
|
15
|
+
)
|
|
16
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import (
|
|
17
|
+
convert_arch_def_to_stackwise,
|
|
18
|
+
)
|
|
19
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct
|
|
20
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d
|
|
21
|
+
|
|
22
|
+
backbone_cls = MobileNetV5Backbone
|
|
23
|
+
|
|
24
|
+
# Backbone configurations keyed by timm architecture name. Each entry
# pairs a stackwise block spec (timm-style arch-def strings decoded by
# `convert_arch_def_to_stackwise`) with top-level backbone settings.
# Block string prefixes: "er" = EdgeResidual, "uir" =
# UniversalInvertedResidual, "mqa" = multi-query attention block.
MODEL_CONFIGS = {
    "mobilenetv5_300m": {
        "backbone": convert_arch_def_to_stackwise(
            [
                # Stage 0: 128x128 in
                [
                    "er_r1_k3_s2_e4_c128",
                    "er_r1_k3_s1_e4_c128",
                    "er_r1_k3_s1_e4_c128",
                ],
                # Stage 1: 256x256 in
                [
                    "uir_r1_a3_k5_s2_e6_c256",
                    "uir_r1_a5_k0_s1_e4_c256",
                    "uir_r1_a3_k0_s1_e4_c256",
                    "uir_r1_a5_k0_s1_e4_c256",
                    "uir_r1_a3_k0_s1_e4_c256",
                ],
                # Stage 2: 640x640 in
                [
                    "uir_r1_a5_k5_s2_e6_c640",
                    "uir_r1_a5_k0_s1_e4_c640",
                    "uir_r1_a5_k0_s1_e4_c640",
                    "uir_r1_a5_k0_s1_e4_c640",
                    "uir_r1_a5_k0_s1_e4_c640",
                    "uir_r1_a5_k0_s1_e4_c640",
                    "uir_r1_a5_k0_s1_e4_c640",
                    "uir_r1_a5_k0_s1_e4_c640",
                    "uir_r1_a0_k0_s1_e1_c640",
                    # Alternating attention (mqa) and FFN-style (uir)
                    # blocks for the remainder of the stage.
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                    "mqa_r1_k3_h12_v2_s1_d64_c640",
                    "uir_r1_a0_k0_s1_e2_c640",
                ],
                # Stage 3: 1280x1280 in
                [
                    "uir_r1_a5_k5_s2_e6_c1280",
                    # Alternating attention (mqa) and FFN-style (uir)
                    # blocks for the remainder of the stage.
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                    "mqa_r1_k3_h16_s1_d96_c1280",
                    "uir_r1_a0_k0_s1_e2_c1280",
                ],
            ]
        ),
        "stem_size": 64,
        "num_features": 2048,
        "norm_layer": "rms_norm",
        "act_layer": "gelu",
        "use_msfa": True,
        "layer_scale_init_value": 1e-5,
    },
}
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def convert_head(task, loader, timm_config):
    """Intentional no-op: this converter ports no classifier-head weights.

    Kept so the timm preset loader can call `convert_head` uniformly
    across converters. Presumably the MobileNetV5 head is initialized
    fresh instead of being ported — confirm against the preset loader.
    """
    pass
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def convert_backbone_config(timm_config):
    """Translate a timm config dict into MobileNetV5 backbone kwargs.

    Args:
        timm_config: dict. The timm model config; must contain an
            `"architecture"` key matching an entry in `MODEL_CONFIGS`.

    Returns:
        A flat kwargs dict merging the stackwise block spec with the
        top-level backbone settings.

    Raises:
        ValueError: If the architecture is not in `MODEL_CONFIGS`.
    """
    timm_architecture = timm_config["architecture"]
    if timm_architecture not in MODEL_CONFIGS:
        raise ValueError(f"Unsupported architecture: {timm_architecture}")
    config = MODEL_CONFIGS[timm_architecture].copy()
    # `dict.copy()` is shallow, so the popped "backbone" dict is the same
    # object stored inside MODEL_CONFIGS. Copy it before updating, or the
    # update would mutate the shared module-level constant and pollute
    # every subsequent call.
    backbone_config = dict(config.pop("backbone"))
    backbone_config.update(config)
    return backbone_config
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def convert_weights(backbone, loader, timm_config):
    """Port timm MobileNetV5 checkpoint weights into a Keras backbone.

    Walks the backbone's stem, per-stage blocks, and (optionally) the MSFA
    head, mapping each Keras variable to its timm checkpoint key via
    `loader.port_weight`. Missing keys are skipped silently so partial
    checkpoints load without error.
    """

    def key_exists(key):
        # Probe the checkpoint for a key; `get_tensor` raising means the
        # key is absent (or unreadable) and the caller should skip it.
        try:
            loader.get_tensor(key)
            return True
        except Exception:
            return False

    def _port_weights(layer, timm_key, transpose_dims=None):
        # Port a conv layer's kernel (and bias, when both the layer and
        # the checkpoint have one). `transpose_dims` reorders the source
        # tensor axes into Keras kernel layout.
        hf_weight_key = f"{timm_key}.weight"
        if not key_exists(hf_weight_key):
            return
        hook_fn = None
        if transpose_dims:

            def transpose_hook(x, _):
                return np.transpose(x, transpose_dims)

            hook_fn = transpose_hook
        loader.port_weight(
            layer.kernel, hf_weight_key=hf_weight_key, hook_fn=hook_fn
        )
        if layer.bias is not None:
            hf_bias_key = f"{timm_key}.bias"
            if key_exists(hf_bias_key):
                loader.port_weight(
                    layer.bias,
                    hf_weight_key=hf_bias_key,
                )

    def _port_bn(layer, timm_prefix):
        # BatchNorm: scale, offset, and running statistics.
        loader.port_weight(layer.gamma, f"{timm_prefix}.weight")
        loader.port_weight(layer.beta, f"{timm_prefix}.bias")
        loader.port_weight(layer.moving_mean, f"{timm_prefix}.running_mean")
        loader.port_weight(layer.moving_variance, f"{timm_prefix}.running_var")

    def _port_rms_norm(layer, timm_prefix):
        # RMSNorm carries only a scale parameter.
        loader.port_weight(layer.gamma, f"{timm_prefix}.weight")

    def _port_cna(cna_layer: ConvNormAct, timm_conv_prefix, timm_norm_prefix):
        # Conv + norm (+ activation) unit. Depthwise kernels use axis
        # order (2, 3, 0, 1), regular convs (2, 3, 1, 0) — presumably
        # mapping timm's (O, I, kH, kW) storage into Keras kernel layout;
        # confirm against the Keras conv kernel shape docs.
        if isinstance(cna_layer.conv, keras.layers.DepthwiseConv2D):
            _port_weights(
                cna_layer.conv,
                timm_conv_prefix,
                transpose_dims=(2, 3, 0, 1),
            )
        else:
            _port_weights(
                cna_layer.conv,
                timm_conv_prefix,
                transpose_dims=(2, 3, 1, 0),
            )
        # The checkpoint distinguishes BN from RMSNorm by the presence of
        # running statistics.
        if key_exists(f"{timm_norm_prefix}.running_mean"):
            _port_bn(cna_layer.norm, timm_norm_prefix)
        else:
            _port_rms_norm(cna_layer.norm, timm_norm_prefix)

    def _port_attn(attn_layer, attn_prefix):
        # Multi-query attention: query/key/value projections plus an
        # output projection. Key/value paths may include an optional
        # downsampling conv and norm (when more than one layer is
        # present).
        _port_weights(
            attn_layer.query_layers[-1],
            f"{attn_prefix}.query.proj",
            (2, 3, 1, 0),
        )
        if len(attn_layer.key_layers) > 1:
            _port_weights(
                attn_layer.key_layers[0],
                f"{attn_prefix}.key.down_conv",
                (2, 3, 0, 1),
            )
            key_norm_layer = attn_layer.key_layers[1]
            if isinstance(key_norm_layer, RmsNorm2d):
                _port_rms_norm(key_norm_layer, f"{attn_prefix}.key.norm")
            else:
                _port_bn(key_norm_layer, f"{attn_prefix}.key.norm")
        _port_weights(
            attn_layer.key_layers[-1], f"{attn_prefix}.key.proj", (2, 3, 1, 0)
        )
        if len(attn_layer.value_layers) > 1:
            _port_weights(
                attn_layer.value_layers[0],
                f"{attn_prefix}.value.down_conv",
                (2, 3, 0, 1),
            )
            value_norm_layer = attn_layer.value_layers[1]
            if isinstance(value_norm_layer, RmsNorm2d):
                _port_rms_norm(value_norm_layer, f"{attn_prefix}.value.norm")
            else:
                _port_bn(value_norm_layer, f"{attn_prefix}.value.norm")
        _port_weights(
            attn_layer.value_layers[-1],
            f"{attn_prefix}.value.proj",
            (2, 3, 1, 0),
        )
        _port_weights(
            attn_layer.output_proj_layers[-2],
            f"{attn_prefix}.output.proj",
            (2, 3, 1, 0),
        )

    # Stem conv + norm.
    stem_layer = backbone.get_layer("conv_stem")
    _port_cna(stem_layer, "conv_stem.conv", "conv_stem.bn")
    # Collect block layers in model order; their sequence must line up
    # with the stackwise spec iterated below.
    block_layers = [
        layer
        for layer in backbone.layers
        if isinstance(
            layer, (EdgeResidual, UniversalInvertedResidual, MobileAttention)
        )
    ]
    block_counter = 0
    for stack_idx in range(len(backbone.stackwise_num_blocks)):
        for block_idx_in_stage in range(
            backbone.stackwise_num_blocks[stack_idx]
        ):
            block = block_layers[block_counter]
            timm_prefix = f"blocks.{stack_idx}.{block_idx_in_stage}"
            if isinstance(block, EdgeResidual):
                _port_cna(
                    block.conv_exp,
                    f"{timm_prefix}.conv_exp",
                    f"{timm_prefix}.bn1",
                )
                _port_cna(
                    block.conv_pwl,
                    f"{timm_prefix}.conv_pwl",
                    f"{timm_prefix}.bn2",
                )
            elif isinstance(block, UniversalInvertedResidual):
                # dw_start/dw_mid may be a plain function (identity) when
                # the block config omits the depthwise conv; only port
                # real layers.
                if hasattr(block, "dw_start") and not isinstance(
                    block.dw_start, types.FunctionType
                ):
                    _port_cna(
                        block.dw_start,
                        f"{timm_prefix}.dw_start.conv",
                        f"{timm_prefix}.dw_start.bn",
                    )
                _port_cna(
                    block.pw_exp,
                    f"{timm_prefix}.pw_exp.conv",
                    f"{timm_prefix}.pw_exp.bn",
                )
                if hasattr(block, "dw_mid") and not isinstance(
                    block.dw_mid, types.FunctionType
                ):
                    _port_cna(
                        block.dw_mid,
                        f"{timm_prefix}.dw_mid.conv",
                        f"{timm_prefix}.dw_mid.bn",
                    )
                _port_cna(
                    block.pw_proj,
                    f"{timm_prefix}.pw_proj.conv",
                    f"{timm_prefix}.pw_proj.bn",
                )
                # Layer scale is optional in the checkpoint.
                gamma_key = f"{timm_prefix}.layer_scale.gamma"
                if key_exists(gamma_key):
                    loader.port_weight(block.layer_scale.gamma, gamma_key)
            elif isinstance(block, MobileAttention):
                _port_rms_norm(block.norm, f"{timm_prefix}.norm")
                gamma_key = f"{timm_prefix}.layer_scale.gamma"
                if key_exists(gamma_key):
                    loader.port_weight(block.layer_scale.gamma, gamma_key)
                attn_prefix = f"{timm_prefix}.attn"
                _port_attn(block.attn, attn_prefix)
            block_counter += 1
    # MSFA head is optional ("use_msfa"); `get_layer` raises ValueError
    # when the backbone was built without it.
    try:
        msfa_layer = backbone.get_layer("msfa")
        ffn = msfa_layer.ffn
        _port_cna(ffn.pw_exp, "msfa.ffn.pw_exp.conv", "msfa.ffn.pw_exp.bn")
        _port_cna(ffn.pw_proj, "msfa.ffn.pw_proj.conv", "msfa.ffn.pw_proj.bn")
        _port_rms_norm(msfa_layer.norm, "msfa.norm")
    except ValueError:
        pass
|
|
@@ -7,6 +7,7 @@ from keras_hub.src.utils.timm import convert_cspnet
|
|
|
7
7
|
from keras_hub.src.utils.timm import convert_densenet
|
|
8
8
|
from keras_hub.src.utils.timm import convert_efficientnet
|
|
9
9
|
from keras_hub.src.utils.timm import convert_mobilenet
|
|
10
|
+
from keras_hub.src.utils.timm import convert_mobilenetv5
|
|
10
11
|
from keras_hub.src.utils.timm import convert_resnet
|
|
11
12
|
from keras_hub.src.utils.timm import convert_vgg
|
|
12
13
|
from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
|
|
@@ -22,6 +23,8 @@ class TimmPresetLoader(PresetLoader):
|
|
|
22
23
|
self.converter = convert_cspnet
|
|
23
24
|
elif architecture.startswith("densenet"):
|
|
24
25
|
self.converter = convert_densenet
|
|
26
|
+
elif architecture.startswith("mobilenetv5"):
|
|
27
|
+
self.converter = convert_mobilenetv5
|
|
25
28
|
elif architecture.startswith("mobilenet"):
|
|
26
29
|
self.converter = convert_mobilenet
|
|
27
30
|
elif architecture.startswith("vgg"):
|
|
@@ -41,7 +44,8 @@ class TimmPresetLoader(PresetLoader):
|
|
|
41
44
|
keras_config = self.converter.convert_backbone_config(self.config)
|
|
42
45
|
backbone = cls(**{**keras_config, **kwargs})
|
|
43
46
|
if load_weights:
|
|
44
|
-
|
|
47
|
+
if not self.config["architecture"].startswith("mobilenetv5"):
|
|
48
|
+
jax_memory_cleanup(backbone)
|
|
45
49
|
# Use prefix="" to avoid using `get_prefixed_key`.
|
|
46
50
|
with SafetensorLoader(self.preset, prefix="") as loader:
|
|
47
51
|
self.converter.convert_weights(backbone, loader, self.config)
|
|
@@ -54,9 +58,9 @@ class TimmPresetLoader(PresetLoader):
|
|
|
54
58
|
)
|
|
55
59
|
# Support loading the classification head for classifier models.
|
|
56
60
|
kwargs["num_classes"] = self.config["num_classes"]
|
|
57
|
-
if (
|
|
58
|
-
"
|
|
59
|
-
|
|
61
|
+
if "num_features" in self.config and (
|
|
62
|
+
"mobilenet" in self.config["architecture"]
|
|
63
|
+
or "mobilenetv5" in self.config["architecture"]
|
|
60
64
|
):
|
|
61
65
|
kwargs["num_features"] = self.config["num_features"]
|
|
62
66
|
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone
|
|
4
|
+
|
|
5
|
+
backbone_cls = DINOV3Backbone
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def convert_backbone_config(transformers_config):
    """Map a Hugging Face DINOv3 config dict to DINOV3Backbone kwargs.

    The HF config stores a single square `image_size`; it is expanded to
    an `(H, W, 3)` image shape. `use_mask_token` is always enabled and
    `apply_layernorm` always disabled for converted checkpoints.
    """
    cfg = transformers_config
    side = cfg["image_size"]
    backbone_kwargs = {
        "patch_size": cfg["patch_size"],
        "num_layers": cfg["num_hidden_layers"],
        "hidden_dim": cfg["hidden_size"],
        "num_heads": cfg["num_attention_heads"],
        "intermediate_dim": cfg["intermediate_size"],
        "layer_scale_init_value": cfg["layerscale_value"],
        "num_register_tokens": cfg["num_register_tokens"],
        "use_mask_token": True,
        "hidden_activation": cfg["hidden_act"],
        "use_gated_mlp": cfg["use_gated_mlp"],
        "use_query_bias": cfg["query_bias"],
        "use_key_bias": cfg["key_bias"],
        "use_value_bias": cfg["value_bias"],
        "use_proj_bias": cfg["proj_bias"],
        "use_mlp_bias": cfg["mlp_bias"],
        "attention_dropout": cfg["attention_dropout"],
        "drop_path_rate": cfg["drop_path_rate"],
        "layer_norm_eps": cfg["layer_norm_eps"],
        "image_shape": (side, side, 3),
        "rope_theta": cfg["rope_theta"],
        "apply_layernorm": False,
    }
    return backbone_kwargs
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def convert_weights(backbone, loader, transformers_config):
    """Port Hugging Face DINOv3 checkpoint weights into a Keras backbone.

    Maps embedding tokens, patch projection, each encoder layer's
    attention/MLP/layer-scale parameters, and the final layer norm onto
    the corresponding HF checkpoint keys via `loader.port_weight`.

    Raises:
        ValueError: If `backbone` is not a `DINOV3Backbone`.
    """
    if not isinstance(backbone, DINOV3Backbone):
        raise ValueError(
            "The provided backbone must be an instance of DINOV3Backbone. "
            f"Received: {type(backbone)}"
        )

    def port_ln(keras_variable, weight_key):
        # LayerNorm: scale (gamma) and offset (beta).
        loader.port_weight(keras_variable.gamma, f"{weight_key}.weight")
        loader.port_weight(keras_variable.beta, f"{weight_key}.bias")

    def port_dense(keras_variable, weight_key):
        # Dense kernels are transposed: HF stores (out, in), Keras
        # expects (in, out). Bias is ported only when the layer has one.
        loader.port_weight(
            keras_variable.kernel,
            f"{weight_key}.weight",
            hook_fn=lambda x, _: x.T,
        )
        if keras_variable.bias is not None:
            loader.port_weight(keras_variable.bias, f"{weight_key}.bias")

    # Embedding.
    loader.port_weight(
        keras_variable=backbone.embeddings.cls_token,
        hf_weight_key="embeddings.cls_token",
    )
    if backbone.use_mask_token:
        loader.port_weight(
            keras_variable=backbone.embeddings.mask_token,
            hf_weight_key="embeddings.mask_token",
        )
    if backbone.num_register_tokens > 0:
        loader.port_weight(
            keras_variable=backbone.embeddings.register_tokens,
            hf_weight_key="embeddings.register_tokens",
        )
    # Patch embedding is a conv; reorder (O, I, kH, kW) -> (kH, kW, I, O).
    loader.port_weight(
        keras_variable=backbone.embeddings.patch_embeddings.projection.kernel,
        hf_weight_key="embeddings.patch_embeddings.weight",
        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
    )
    loader.port_weight(
        keras_variable=backbone.embeddings.patch_embeddings.projection.bias,
        hf_weight_key="embeddings.patch_embeddings.bias",
    )

    # Encoder.
    for i, layer in enumerate(backbone.encoder.layers):
        prefix = f"layer.{i}"
        port_ln(layer.norm1, f"{prefix}.norm1")
        port_dense(layer.attention.query_dense, f"{prefix}.attention.q_proj")
        port_dense(layer.attention.key_dense, f"{prefix}.attention.k_proj")
        port_dense(layer.attention.value_dense, f"{prefix}.attention.v_proj")
        port_dense(layer.attention.output_dense, f"{prefix}.attention.o_proj")

        loader.port_weight(
            keras_variable=layer.layer_scale1.lambda1,
            hf_weight_key=f"{prefix}.layer_scale1.lambda1",
        )
        port_ln(layer.norm2, f"{prefix}.norm2")
        # Gated MLP adds a gate projection alongside up/down.
        if backbone.use_gated_mlp:
            port_dense(layer.mlp.gate_proj, f"{prefix}.mlp.gate_proj")
            port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj")
            port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj")
        else:
            port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj")
            port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj")
        loader.port_weight(
            keras_variable=layer.layer_scale2.lambda1,
            hf_weight_key=f"{prefix}.layer_scale2.lambda1",
        )

    # Final norm.
    port_ln(backbone.layernorm, "norm")
|