keras-hub-nightly 0.23.0.dev202510080414-py3-none-any.whl → 0.24.0.dev202511080419-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. keras_hub/layers/__init__.py +6 -0
  2. keras_hub/models/__init__.py +36 -0
  3. keras_hub/src/layers/modeling/reversible_embedding.py +6 -0
  4. keras_hub/src/models/causal_lm.py +5 -0
  5. keras_hub/src/models/depth_anything/depth_anything_presets.py +38 -1
  6. keras_hub/src/models/dinov2/dinov2_layers.py +3 -1
  7. keras_hub/src/models/dinov3/__init__.py +5 -0
  8. keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
  9. keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
  10. keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
  11. keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
  12. keras_hub/src/models/gemma/gemma_presets.py +22 -0
  13. keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
  14. keras_hub/src/models/image_to_image.py +5 -0
  15. keras_hub/src/models/inpaint.py +5 -0
  16. keras_hub/src/models/mobilenetv5/__init__.py +9 -0
  17. keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
  18. keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
  19. keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
  20. keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
  21. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
  22. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
  23. keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
  24. keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
  25. keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
  26. keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
  27. keras_hub/src/models/parseq/__init__.py +5 -0
  28. keras_hub/src/models/parseq/parseq_presets.py +15 -0
  29. keras_hub/src/models/qwen3_moe/__init__.py +5 -0
  30. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
  31. keras_hub/src/models/siglip/siglip_presets.py +15 -0
  32. keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
  33. keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
  34. keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
  35. keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
  36. keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
  37. keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
  38. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  39. keras_hub/src/models/text_to_image.py +5 -0
  40. keras_hub/src/utils/preset_utils.py +9 -2
  41. keras_hub/src/utils/tensor_utils.py +3 -1
  42. keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
  43. keras_hub/src/utils/timm/preset_loader.py +8 -4
  44. keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
  45. keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
  46. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  47. keras_hub/src/version.py +1 -1
  48. keras_hub/tokenizers/__init__.py +6 -0
  49. {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/METADATA +1 -1
  50. {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/RECORD +52 -24
  51. {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/WHEEL +0 -0
  52. {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/top_level.txt +0 -0
keras_hub/src/models/smollm3/smollm3_tokenizer.py
@@ -0,0 +1,60 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.SmolLM3Tokenizer",
+        "keras_hub.tokenizers.SmolLMTokenizer",
+        "keras_hub.models.SmolLM3Tokenizer",
+        "keras_hub.models.SmolLMTokenizer",
+    ]
+)
+class SmolLM3Tokenizer(BytePairTokenizer):
+    """Tokenizer for SmolLM3 models.
+
+    This tokenizer implements byte-pair encoding (BPE) for SmolLM3 models.
+    It automatically registers the special tokens used by SmolLM3: the
+    `"<|begin_of_text|>"` (BOS) and `"<|end_of_text|>"` (EOS) markers, plus
+    the `"<think>"`/`"</think>"` reasoning delimiters.
+
+    Args:
+        vocabulary: Dictionary mapping tokens to token IDs, or path to
+            vocabulary file.
+        merges: List of BPE merges, or path to merges file.
+    """
+
+    backbone_cls = SmolLM3Backbone
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        **kwargs,
+    ):
+        # Add EOS token.
+        eos_token = "<|end_of_text|>"
+        self._add_special_token(eos_token, "end_token")
+
+        bos_token = "<|begin_of_text|>"
+        self._add_special_token(bos_token, "bos_token")
+
+        start_think_token = "<think>"
+        self._add_special_token(start_think_token, "start_think_token")
+
+        end_think_token = "</think>"
+        self._add_special_token(end_think_token, "end_think_token")
+
+        self.start_token_id = None
+        self.start_token = None
+        self.pad_token_id = 0
+
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            **kwargs,
+        )
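Note: the snippet below is an illustrative sketch, not part of the diff. It builds the tokenizer from a toy in-memory vocabulary; the special tokens registered in `__init__` must be present in the vocabulary, and a real SmolLM3 checkpoint would supply these assets via a preset.

from keras_hub.src.models.smollm3.smollm3_tokenizer import SmolLM3Tokenizer

# Toy BPE assets, for illustration only. A real vocabulary is far larger;
# all four special tokens must appear in it.
vocab = {
    "<|begin_of_text|>": 0,
    "<|end_of_text|>": 1,
    "<think>": 2,
    "</think>": 3,
    "a": 4,
    "b": 5,
    "ab": 6,
}
merges = ["a b"]

tokenizer = SmolLM3Tokenizer(vocabulary=vocab, merges=merges)
ids = tokenizer("ab")  # the "a b" merge should yield the single token id 6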
keras_hub/src/models/smollm3/smollm3_utils.py
@@ -0,0 +1,56 @@
+from keras import ops
+
+
+def rotate_half(x):
+    x1 = x[..., : ops.shape(x)[-1] // 2]
+    x2 = x[..., ops.shape(x)[-1] // 2 :]
+    return ops.concatenate((-x2, x1), axis=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1):
+    cos = ops.expand_dims(cos, expansion_axis)
+    sin = ops.expand_dims(sin, expansion_axis)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_single(tensor, cos, sin, expansion_axis=1):
+    cos = ops.expand_dims(cos, expansion_axis)
+    sin = ops.expand_dims(sin, expansion_axis)
+    tensor_embed = (tensor * cos) + (rotate_half(tensor) * sin)
+    return tensor_embed
+
+
+def repeat_kv(hidden_states, n_rep):
+    batch, num_key_value_heads, slen, head_dim = ops.shape(hidden_states)
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = ops.expand_dims(hidden_states, axis=2)
+    target_shape = (batch, num_key_value_heads, n_rep, slen, head_dim)
+    hidden_states = ops.broadcast_to(hidden_states, target_shape)
+    return ops.reshape(
+        hidden_states, [batch, num_key_value_heads * n_rep, slen, head_dim]
+    )
+
+
+def rope_init(rope_theta, partial_rotary_factor, head_dim):
+    """Initialize RoPE (Rotary Position Embedding) parameters.
+
+    Args:
+        rope_theta: float. The theta value for RoPE.
+        partial_rotary_factor: float. The factor for partial rotary
+            embedding.
+        head_dim: int. The dimension of each attention head.
+
+    Returns:
+        A tuple of `(inv_freq, attention_scaling)` where `inv_freq` is the
+        inverse frequency tensor and `attention_scaling` is the scaling
+        factor.
+    """
+    base = rope_theta
+    dim = int(head_dim * partial_rotary_factor)
+
+    inv_freq = 1.0 / (
+        ops.power(base, ops.arange(0, dim, 2, dtype="float32") / dim)
+    )
+    attention_scaling = 1.0
+    return inv_freq, attention_scaling
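Note: a shape sketch (assumed, not from the diff) of how these helpers compose. It uses the HF-style convention of concatenating the frequency table with itself to fill `head_dim`; the actual SmolLM3 attention layers may assemble `cos`/`sin` differently.

from keras import ops

head_dim, seq_len = 64, 8
inv_freq, scaling = rope_init(
    rope_theta=10000.0, partial_rotary_factor=1.0, head_dim=head_dim
)  # inv_freq: (head_dim // 2,)

positions = ops.arange(seq_len, dtype="float32")
freqs = ops.outer(positions, inv_freq)          # (seq_len, head_dim // 2)
emb = ops.concatenate((freqs, freqs), axis=-1)  # (seq_len, head_dim)
cos = ops.cos(emb) * scaling
sin = ops.sin(emb) * scaling

q = ops.ones((1, 4, seq_len, head_dim))  # (batch, num_heads, seq, head_dim)
k = ops.ones((1, 2, seq_len, head_dim))  # fewer KV heads (grouped-query)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos[None, ...], sin[None, ...])
k_full = repeat_kv(k_rot, n_rep=2)       # (1, 4, seq_len, head_dim)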
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
@@ -11,7 +11,7 @@ backbone_presets = {
             "params": 2987080931,
             "path": "stable_diffusion_3",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/4",
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/5",
     },
     "stable_diffusion_3.5_medium": {
         "metadata": {
@@ -35,7 +35,7 @@ backbone_presets = {
             "params": 9048410595,
             "path": "stable_diffusion_3",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/2",
+        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/3",
     },
     "stable_diffusion_3.5_large_turbo": {
         "metadata": {
@@ -49,6 +49,6 @@ backbone_presets = {
             "params": 9048410595,
             "path": "stable_diffusion_3",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/2",
+        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/3",
     },
 }
keras_hub/src/models/text_to_image.py
@@ -345,3 +345,8 @@ class TextToImage(Task):
         # Text-to-image.
         outputs = [generate(x) for x in inputs]
         return self._normalize_generate_outputs(outputs, input_is_scalar)
+
+    def _post_quantize(self, mode, **kwargs):
+        super()._post_quantize(mode, **kwargs)
+        # Reset the compiled generate function.
+        self.generate_function = None
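Note: the reset matters because `generate()` compiles and caches a function that closes over the model's variables; quantization replaces those variables, so a stale cache would keep pointing at the old ones. A hedged usage sketch (the preset name is taken from the presets file above; `quantize` is Keras 3's model quantization API):

import keras_hub

task = keras_hub.models.TextToImage.from_preset("stable_diffusion_3_medium")
task.generate("a photo of an astronaut")  # compiles and caches generate_function
task.quantize("int8")                     # _post_quantize clears the cache
task.generate("a photo of an astronaut")  # recompiles against quantized weights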
keras_hub/src/utils/preset_utils.py
@@ -502,10 +502,17 @@ def jax_memory_cleanup(layer):
     # For jax, delete all previous allocated memory to avoid temporarily
     # duplicating variable allocations. torch and tensorflow have stateful
     # variable types and do not need this fix.
+    # Skip deletion for sharded arrays to avoid breaking references in
+    # distributed setups.
     if keras.config.backend() == "jax":
         for weight in layer.weights:
-            if getattr(weight, "_value", None) is not None:
-                weight._value.delete()
+            if weight._value is not None:
+                # Do not delete sharded arrays, as they may be referenced in
+                # JAX's distributed computation graph and deletion can cause
+                # errors.
+                sharding = getattr(weight._value, "sharding", None)
+                if sharding is None:
+                    weight._value.delete()


 def set_dtype_in_config(config, dtype=None):
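Note: an illustrative look (assumes a JAX backend) at the attribute the new code tests. A `jax.Array` exposes `.sharding`, and `delete()` frees its underlying buffer, which is exactly what must not happen to an array that other participants in a distributed computation still reference:

import jax.numpy as jnp

x = jnp.ones((4, 4))
print(getattr(x, "sharding", None))  # e.g. SingleDeviceSharding(device=...)
x.delete()  # frees the buffer; any later use of x raises an error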
keras_hub/src/utils/tensor_utils.py
@@ -12,9 +12,11 @@ from packaging import version

 try:
     import tensorflow as tf
-    import tensorflow_text as tf_text
 except ImportError:
     tf = None
+try:
+    import tensorflow_text as tf_text
+except ImportError:
     tf_text = None

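Note: splitting the guard in two means an environment that has `tensorflow` installed but lacks `tensorflow_text` (a common situation, since `tensorflow_text` wheels are not published for every platform) keeps `tf` usable instead of discarding both imports when the second one fails.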
keras_hub/src/utils/timm/convert_mobilenetv5.py
@@ -0,0 +1,321 @@
+import types
+
+import keras
+import numpy as np
+
+from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import (
+    MobileAttention,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import (
+    MobileNetV5Backbone,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import EdgeResidual
+from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import (
+    UniversalInvertedResidual,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import (
+    convert_arch_def_to_stackwise,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct
+from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d
+
+backbone_cls = MobileNetV5Backbone
+
+MODEL_CONFIGS = {
+    "mobilenetv5_300m": {
+        "backbone": convert_arch_def_to_stackwise(
+            [
+                # Stage 0: 128x128 in
+                [
+                    "er_r1_k3_s2_e4_c128",
+                    "er_r1_k3_s1_e4_c128",
+                    "er_r1_k3_s1_e4_c128",
+                ],
+                # Stage 1: 256x256 in
+                [
+                    "uir_r1_a3_k5_s2_e6_c256",
+                    "uir_r1_a5_k0_s1_e4_c256",
+                    "uir_r1_a3_k0_s1_e4_c256",
+                    "uir_r1_a5_k0_s1_e4_c256",
+                    "uir_r1_a3_k0_s1_e4_c256",
+                ],
+                # Stage 2: 640x640 in
+                [
+                    "uir_r1_a5_k5_s2_e6_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a5_k0_s1_e4_c640",
+                    "uir_r1_a0_k0_s1_e1_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                    "mqa_r1_k3_h12_v2_s1_d64_c640",
+                    "uir_r1_a0_k0_s1_e2_c640",
+                ],
+                # Stage 3: 1280x1280 in
+                [
+                    "uir_r1_a5_k5_s2_e6_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                    "mqa_r1_k3_h16_s1_d96_c1280",
+                    "uir_r1_a0_k0_s1_e2_c1280",
+                ],
+            ]
+        ),
+        "stem_size": 64,
+        "num_features": 2048,
+        "norm_layer": "rms_norm",
+        "act_layer": "gelu",
+        "use_msfa": True,
+        "layer_scale_init_value": 1e-5,
+    },
+}
+
+
+def convert_head(task, loader, timm_config):
+    pass
+
+
+def convert_backbone_config(timm_config):
+    timm_architecture = timm_config["architecture"]
+    if timm_architecture not in MODEL_CONFIGS:
+        raise ValueError(f"Unsupported architecture: {timm_architecture}")
+    config = MODEL_CONFIGS[timm_architecture].copy()
+    backbone_config = config.pop("backbone")
+    backbone_config.update(config)
+    return backbone_config
+
+
+def convert_weights(backbone, loader, timm_config):
+    def key_exists(key):
+        try:
+            loader.get_tensor(key)
+            return True
+        except Exception:
+            return False
+
+    def _port_weights(layer, timm_key, transpose_dims=None):
+        hf_weight_key = f"{timm_key}.weight"
+        if not key_exists(hf_weight_key):
+            return
+        hook_fn = None
+        if transpose_dims:
+
+            def transpose_hook(x, _):
+                return np.transpose(x, transpose_dims)
+
+            hook_fn = transpose_hook
+        loader.port_weight(
+            layer.kernel, hf_weight_key=hf_weight_key, hook_fn=hook_fn
+        )
+        if layer.bias is not None:
+            hf_bias_key = f"{timm_key}.bias"
+            if key_exists(hf_bias_key):
+                loader.port_weight(
+                    layer.bias,
+                    hf_weight_key=hf_bias_key,
+                )
+
+    def _port_bn(layer, timm_prefix):
+        loader.port_weight(layer.gamma, f"{timm_prefix}.weight")
+        loader.port_weight(layer.beta, f"{timm_prefix}.bias")
+        loader.port_weight(layer.moving_mean, f"{timm_prefix}.running_mean")
+        loader.port_weight(layer.moving_variance, f"{timm_prefix}.running_var")
+
+    def _port_rms_norm(layer, timm_prefix):
+        loader.port_weight(layer.gamma, f"{timm_prefix}.weight")
+
+    def _port_cna(cna_layer: ConvNormAct, timm_conv_prefix, timm_norm_prefix):
+        if isinstance(cna_layer.conv, keras.layers.DepthwiseConv2D):
+            _port_weights(
+                cna_layer.conv,
+                timm_conv_prefix,
+                transpose_dims=(2, 3, 0, 1),
+            )
+        else:
+            _port_weights(
+                cna_layer.conv,
+                timm_conv_prefix,
+                transpose_dims=(2, 3, 1, 0),
+            )
+        if key_exists(f"{timm_norm_prefix}.running_mean"):
+            _port_bn(cna_layer.norm, timm_norm_prefix)
+        else:
+            _port_rms_norm(cna_layer.norm, timm_norm_prefix)
+
+    def _port_attn(attn_layer, attn_prefix):
+        _port_weights(
+            attn_layer.query_layers[-1],
+            f"{attn_prefix}.query.proj",
+            (2, 3, 1, 0),
+        )
+        if len(attn_layer.key_layers) > 1:
+            _port_weights(
+                attn_layer.key_layers[0],
+                f"{attn_prefix}.key.down_conv",
+                (2, 3, 0, 1),
+            )
+            key_norm_layer = attn_layer.key_layers[1]
+            if isinstance(key_norm_layer, RmsNorm2d):
+                _port_rms_norm(key_norm_layer, f"{attn_prefix}.key.norm")
+            else:
+                _port_bn(key_norm_layer, f"{attn_prefix}.key.norm")
+        _port_weights(
+            attn_layer.key_layers[-1], f"{attn_prefix}.key.proj", (2, 3, 1, 0)
+        )
+        if len(attn_layer.value_layers) > 1:
+            _port_weights(
+                attn_layer.value_layers[0],
+                f"{attn_prefix}.value.down_conv",
+                (2, 3, 0, 1),
+            )
+            value_norm_layer = attn_layer.value_layers[1]
+            if isinstance(value_norm_layer, RmsNorm2d):
+                _port_rms_norm(value_norm_layer, f"{attn_prefix}.value.norm")
+            else:
+                _port_bn(value_norm_layer, f"{attn_prefix}.value.norm")
+        _port_weights(
+            attn_layer.value_layers[-1],
+            f"{attn_prefix}.value.proj",
+            (2, 3, 1, 0),
+        )
+        _port_weights(
+            attn_layer.output_proj_layers[-2],
+            f"{attn_prefix}.output.proj",
+            (2, 3, 1, 0),
+        )
+
+    stem_layer = backbone.get_layer("conv_stem")
+    _port_cna(stem_layer, "conv_stem.conv", "conv_stem.bn")
+    block_layers = [
+        layer
+        for layer in backbone.layers
+        if isinstance(
+            layer, (EdgeResidual, UniversalInvertedResidual, MobileAttention)
+        )
+    ]
+    block_counter = 0
+    for stack_idx in range(len(backbone.stackwise_num_blocks)):
+        for block_idx_in_stage in range(
+            backbone.stackwise_num_blocks[stack_idx]
+        ):
+            block = block_layers[block_counter]
+            timm_prefix = f"blocks.{stack_idx}.{block_idx_in_stage}"
+            if isinstance(block, EdgeResidual):
+                _port_cna(
+                    block.conv_exp,
+                    f"{timm_prefix}.conv_exp",
+                    f"{timm_prefix}.bn1",
+                )
+                _port_cna(
+                    block.conv_pwl,
+                    f"{timm_prefix}.conv_pwl",
+                    f"{timm_prefix}.bn2",
+                )
+            elif isinstance(block, UniversalInvertedResidual):
+                if hasattr(block, "dw_start") and not isinstance(
+                    block.dw_start, types.FunctionType
+                ):
+                    _port_cna(
+                        block.dw_start,
+                        f"{timm_prefix}.dw_start.conv",
+                        f"{timm_prefix}.dw_start.bn",
+                    )
+                _port_cna(
+                    block.pw_exp,
+                    f"{timm_prefix}.pw_exp.conv",
+                    f"{timm_prefix}.pw_exp.bn",
+                )
+                if hasattr(block, "dw_mid") and not isinstance(
+                    block.dw_mid, types.FunctionType
+                ):
+                    _port_cna(
+                        block.dw_mid,
+                        f"{timm_prefix}.dw_mid.conv",
+                        f"{timm_prefix}.dw_mid.bn",
+                    )
+                _port_cna(
+                    block.pw_proj,
+                    f"{timm_prefix}.pw_proj.conv",
+                    f"{timm_prefix}.pw_proj.bn",
+                )
+                gamma_key = f"{timm_prefix}.layer_scale.gamma"
+                if key_exists(gamma_key):
+                    loader.port_weight(block.layer_scale.gamma, gamma_key)
+            elif isinstance(block, MobileAttention):
+                _port_rms_norm(block.norm, f"{timm_prefix}.norm")
+                gamma_key = f"{timm_prefix}.layer_scale.gamma"
+                if key_exists(gamma_key):
+                    loader.port_weight(block.layer_scale.gamma, gamma_key)
+                attn_prefix = f"{timm_prefix}.attn"
+                _port_attn(block.attn, attn_prefix)
+            block_counter += 1
+    try:
+        msfa_layer = backbone.get_layer("msfa")
+        ffn = msfa_layer.ffn
+        _port_cna(ffn.pw_exp, "msfa.ffn.pw_exp.conv", "msfa.ffn.pw_exp.bn")
+        _port_cna(ffn.pw_proj, "msfa.ffn.pw_proj.conv", "msfa.ffn.pw_proj.bn")
+        _port_rms_norm(msfa_layer.norm, "msfa.norm")
+    except ValueError:
+        pass
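Note: a rough decoding aid for the block strings above, assuming they follow timm's arch-def token convention (`er` = EdgeResidual, `uir` = UniversalInvertedResidual, `mqa` = MobileAttention; `r` repeats, `k` kernel size, `s` stride, `e` expansion ratio, `c` output channels, `h` attention heads, `d` head dim). Our reading of letters like `a` (start depthwise kernel) and `v` (value stride) is not confirmed by this diff; the authoritative parser is `convert_arch_def_to_stackwise` in `mobilenetv5_builder.py`. This sketch only splits the tokens:

def describe_block(block_str):
    # Split an arch-def string into its block type and option letters.
    kind, *opts = block_str.split("_")
    names = {
        "er": "EdgeResidual",
        "uir": "UniversalInvertedResidual",
        "mqa": "MobileAttention",
    }
    return names.get(kind, kind), {opt[0]: opt[1:] for opt in opts}

print(describe_block("mqa_r1_k3_h12_v2_s1_d64_c640"))
# ('MobileAttention', {'r': '1', 'k': '3', 'h': '12', 'v': '2', 's': '1',
#  'd': '64', 'c': '640'})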
keras_hub/src/utils/timm/preset_loader.py
@@ -7,6 +7,7 @@ from keras_hub.src.utils.timm import convert_cspnet
 from keras_hub.src.utils.timm import convert_densenet
 from keras_hub.src.utils.timm import convert_efficientnet
 from keras_hub.src.utils.timm import convert_mobilenet
+from keras_hub.src.utils.timm import convert_mobilenetv5
 from keras_hub.src.utils.timm import convert_resnet
 from keras_hub.src.utils.timm import convert_vgg
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
@@ -22,6 +23,8 @@ class TimmPresetLoader(PresetLoader):
             self.converter = convert_cspnet
         elif architecture.startswith("densenet"):
             self.converter = convert_densenet
+        elif architecture.startswith("mobilenetv5"):
+            self.converter = convert_mobilenetv5
         elif architecture.startswith("mobilenet"):
             self.converter = convert_mobilenet
         elif architecture.startswith("vgg"):
@@ -41,7 +44,8 @@ class TimmPresetLoader(PresetLoader):
         keras_config = self.converter.convert_backbone_config(self.config)
         backbone = cls(**{**keras_config, **kwargs})
         if load_weights:
-            jax_memory_cleanup(backbone)
+            if not self.config["architecture"].startswith("mobilenetv5"):
+                jax_memory_cleanup(backbone)
             # Use prefix="" to avoid using `get_prefixed_key`.
             with SafetensorLoader(self.preset, prefix="") as loader:
                 self.converter.convert_weights(backbone, loader, self.config)
@@ -54,9 +58,9 @@ class TimmPresetLoader(PresetLoader):
         )
         # Support loading the classification head for classifier models.
         kwargs["num_classes"] = self.config["num_classes"]
-        if (
-            "num_features" in self.config
-            and "mobilenet" in self.config["architecture"]
+        if "num_features" in self.config and (
+            "mobilenet" in self.config["architecture"]
+            or "mobilenetv5" in self.config["architecture"]
         ):
             kwargs["num_features"] = self.config["num_features"]

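Note: with the converter registered, a timm MobileNetV5 checkpoint can be pulled straight from the Hugging Face Hub; the loader is selected because the config's "architecture" field starts with "mobilenetv5". The handle below is illustrative; substitute a real timm repository id:

import keras_hub

backbone = keras_hub.models.Backbone.from_preset(
    "hf://timm/mobilenetv5_300m.gemma3n"  # illustrative handle
)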
keras_hub/src/utils/transformers/convert_dinov3.py
@@ -0,0 +1,106 @@
+import numpy as np
+
+from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone
+
+backbone_cls = DINOV3Backbone
+
+
+def convert_backbone_config(transformers_config):
+    image_size = transformers_config["image_size"]
+    return {
+        "patch_size": transformers_config["patch_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "num_heads": transformers_config["num_attention_heads"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "layer_scale_init_value": transformers_config["layerscale_value"],
+        "num_register_tokens": transformers_config["num_register_tokens"],
+        "use_mask_token": True,
+        "hidden_activation": transformers_config["hidden_act"],
+        "use_gated_mlp": transformers_config["use_gated_mlp"],
+        "use_query_bias": transformers_config["query_bias"],
+        "use_key_bias": transformers_config["key_bias"],
+        "use_value_bias": transformers_config["value_bias"],
+        "use_proj_bias": transformers_config["proj_bias"],
+        "use_mlp_bias": transformers_config["mlp_bias"],
+        "attention_dropout": transformers_config["attention_dropout"],
+        "drop_path_rate": transformers_config["drop_path_rate"],
+        "layer_norm_eps": transformers_config["layer_norm_eps"],
+        "image_shape": (image_size, image_size, 3),
+        "rope_theta": transformers_config["rope_theta"],
+        "apply_layernorm": False,
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    if not isinstance(backbone, DINOV3Backbone):
+        raise ValueError(
+            "The provided backbone must be an instance of DINOV3Backbone. "
+            f"Received: {type(backbone)}"
+        )
+
+    def port_ln(keras_variable, weight_key):
+        loader.port_weight(keras_variable.gamma, f"{weight_key}.weight")
+        loader.port_weight(keras_variable.beta, f"{weight_key}.bias")
+
+    def port_dense(keras_variable, weight_key):
+        loader.port_weight(
+            keras_variable.kernel,
+            f"{weight_key}.weight",
+            hook_fn=lambda x, _: x.T,
+        )
+        if keras_variable.bias is not None:
+            loader.port_weight(keras_variable.bias, f"{weight_key}.bias")
+
+    # Embedding.
+    loader.port_weight(
+        keras_variable=backbone.embeddings.cls_token,
+        hf_weight_key="embeddings.cls_token",
+    )
+    if backbone.use_mask_token:
+        loader.port_weight(
+            keras_variable=backbone.embeddings.mask_token,
+            hf_weight_key="embeddings.mask_token",
+        )
+    if backbone.num_register_tokens > 0:
+        loader.port_weight(
+            keras_variable=backbone.embeddings.register_tokens,
+            hf_weight_key="embeddings.register_tokens",
+        )
+    loader.port_weight(
+        keras_variable=backbone.embeddings.patch_embeddings.projection.kernel,
+        hf_weight_key="embeddings.patch_embeddings.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+    loader.port_weight(
+        keras_variable=backbone.embeddings.patch_embeddings.projection.bias,
+        hf_weight_key="embeddings.patch_embeddings.bias",
+    )
+
+    # Encoder.
+    for i, layer in enumerate(backbone.encoder.layers):
+        prefix = f"layer.{i}"
+        port_ln(layer.norm1, f"{prefix}.norm1")
+        port_dense(layer.attention.query_dense, f"{prefix}.attention.q_proj")
+        port_dense(layer.attention.key_dense, f"{prefix}.attention.k_proj")
+        port_dense(layer.attention.value_dense, f"{prefix}.attention.v_proj")
+        port_dense(layer.attention.output_dense, f"{prefix}.attention.o_proj")
+
+        loader.port_weight(
+            keras_variable=layer.layer_scale1.lambda1,
+            hf_weight_key=f"{prefix}.layer_scale1.lambda1",
+        )
+        port_ln(layer.norm2, f"{prefix}.norm2")
+        if backbone.use_gated_mlp:
+            port_dense(layer.mlp.gate_proj, f"{prefix}.mlp.gate_proj")
+            port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj")
+            port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj")
+        else:
+            port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj")
+            port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj")
+        loader.port_weight(
+            keras_variable=layer.layer_scale2.lambda1,
+            hf_weight_key=f"{prefix}.layer_scale2.lambda1",
+        )
+
+    port_ln(backbone.layernorm, "norm")
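Note: with this converter and the transformers preset-loader hookup (file 46 above), a DINOv3 checkpoint can be converted on the fly from a Hugging Face handle. The repository id below is illustrative; substitute a real DINOv3 checkpoint:

import numpy as np

import keras_hub

backbone = keras_hub.models.Backbone.from_preset(
    "hf://facebook/dinov3-vits16-pretrain-lvd1689m"  # illustrative handle
)
features = backbone(np.ones((1, 224, 224, 3), dtype="float32"))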