keras-hub-nightly 0.23.0.dev202510240418__py3-none-any.whl → 0.24.0.dev202512090431__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +3 -0
- keras_hub/src/models/causal_lm.py +22 -0
- keras_hub/src/models/dinov2/dinov2_layers.py +3 -1
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +93 -0
- keras_hub/src/models/esm/esm_attention.py +11 -4
- keras_hub/src/models/gemma/gemma_causal_lm.py +16 -0
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +17 -0
- keras_hub/src/models/masked_lm.py +22 -0
- keras_hub/src/models/qwen3/qwen3_presets.py +36 -0
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/__init__.py +5 -0
- keras_hub/src/models/smollm3/smollm3_presets.py +16 -0
- keras_hub/src/utils/tensor_utils.py +3 -1
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.23.0.dev202510240418.dist-info → keras_hub_nightly-0.24.0.dev202512090431.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202510240418.dist-info → keras_hub_nightly-0.24.0.dev202512090431.dist-info}/RECORD +29 -20
- {keras_hub_nightly-0.23.0.dev202510240418.dist-info → keras_hub_nightly-0.24.0.dev202512090431.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202510240418.dist-info → keras_hub_nightly-0.24.0.dev202512090431.dist-info}/top_level.txt +0 -0
keras_hub/src/models/dinov3/dinov3_presets.py
@@ -0,0 +1,93 @@
+"""DINOV3 model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "dinov3_vit_small_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (small-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 21_600_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_small_lvd1689m/1",
+    },
+    "dinov3_vit_small_plus_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (small-plus-sized model) trained on "
+                "LVD-1689M using DINOv3."
+            ),
+            "params": 29_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_small_plus_lvd1689m/1",
+    },
+    "dinov3_vit_base_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (base-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 86_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_base_lvd1689m/1",
+    },
+    "dinov3_vit_large_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (large-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 300_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_large_lvd1689m/1",
+    },
+    "dinov3_vit_huge_plus_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (huge-plus-sized model) trained on "
+                "LVD-1689M using DINOv3."
+            ),
+            "params": 840_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_huge_plus_lvd1689m/1",
+    },
+    "dinov3_vit_7b_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (7B-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 6_700_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_7b_lvd1689m/1",
+    },
+    "dinov3_vit_large_sat493m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (large-sized model) trained on SAT-493M "
+                "using DINOv3."
+            ),
+            "params": 300_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_large_sat493m/1",
+    },
+    "dinov3_vit_7b_sat493m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (7B-sized model) trained on SAT-493M "
+                "using DINOv3."
+            ),
+            "params": 6_700_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_7b_sat493m/1",
+    },
+}
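
As a usage sketch (not part of the diff), any preset registered above can be instantiated through the standard KerasHub loader, which resolves the preset name to its `kaggle_handle`:

import keras_hub

# Downloads weights from the kaggle_handle registered for this preset name.
backbone = keras_hub.models.Backbone.from_preset("dinov3_vit_small_lvd1689m")
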
keras_hub/src/models/esm/esm_attention.py
@@ -14,7 +14,8 @@ class ESMRotaryEmbedding(RotaryEmbedding):
         inv_freq = self.scaling_factor / (
             self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
         )
-
+        # Use ops.shape for dynamic shape compatibility with TFLite
+        t = ops.arange(ops.shape(x)[position], dtype=x.dtype)
         freqs = ops.outer(t, inv_freq)
         emb = ops.concatenate((freqs, freqs), axis=-1)
 
@@ -32,11 +33,17 @@ class ESMRotaryEmbedding(RotaryEmbedding):
 
     def rotate_half(self, x):
         x1, x2 = ops.split(x, 2, -1)
-
+        # Avoid `ops.concatenate` to prevent XLA compilation issues on JAX
+        # backend. Use stack + reshape approach from base RotaryEmbedding.
+        half_rot_x = ops.stack((-x2, x1), axis=-2)
+        half_rot_x = ops.reshape(half_rot_x, ops.shape(x))
+        return half_rot_x
 
     def apply_rotary_pos_emb(self, x, cos, sin):
-
-
+        # Use ops.shape for dynamic shape compatibility with TFLite
+        seq_len = ops.shape(x)[1]
+        cos = cos[:, :seq_len, :, :]
+        sin = sin[:, :seq_len, :, :]
 
         return (x * cos) + (self.rotate_half(x) * sin)
 
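
For intuition, the stack-and-reshape form is arithmetically identical to the concatenation it replaces: stacking `(-x2, x1)` on a new second-to-last axis and reshaping back to `x`'s shape lays the two halves out consecutively. A small self-contained check using only `keras.ops` (array values chosen arbitrarily):

import numpy as np
from keras import ops

x = np.arange(8, dtype="float32").reshape(1, 8)
x1, x2 = ops.split(x, 2, -1)
# concatenate-based rotate_half vs. the stack + reshape variant above.
via_concat = ops.concatenate((-x2, x1), axis=-1)
via_stack = ops.reshape(ops.stack((-x2, x1), axis=-2), ops.shape(x))
np.testing.assert_allclose(
    ops.convert_to_numpy(via_concat), ops.convert_to_numpy(via_stack)
)
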
keras_hub/src/models/gemma/gemma_causal_lm.py
@@ -431,3 +431,19 @@ class GemmaCausalLM(CausalLM):
         )
         per_token_loss = per_token_loss_fn(target_ids, logits)
         return per_token_loss
+
+    def get_quantization_layer_structure(self, mode):
+        if mode != "gptq":
+            return None
+
+        # Wrap embedding + scaling
+        backbone = self.backbone
+        inputs = keras.Input(shape=(None,), dtype="int32")
+        x = backbone.token_embedding(inputs)
+        x = x * ops.cast(ops.sqrt(backbone.hidden_dim), x.dtype)
+        pre_processor = keras.Model(inputs=inputs, outputs=x)
+
+        return {
+            "pre_block_layers": [pre_processor],
+            "sequential_blocks": backbone.transformer_layers,
+        }
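
A brief sketch of how this structure can be inspected (preset name assumed to be available; the returned dict is presumably consumed by the GPTQ quantization path, which runs the embedding-plus-scaling wrapper once and then quantizes the transformer blocks sequentially):

import keras_hub

causal_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
structure = causal_lm.get_quantization_layer_structure("gptq")
# The wrapper model embeds and scales token ids before the blocks.
print(structure["pre_block_layers"], len(structure["sequential_blocks"]))
# Any mode other than "gptq" is unsupported here and returns None.
assert causal_lm.get_quantization_layer_structure("int8") is None
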
keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py
@@ -283,9 +283,14 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
         # is `None`.
         self.text_only_model = self.image_converter is None
 
-
-
-
+        if self.text_only_model:
+            self.image_placeholder = None
+            self.start_of_image_token = None
+            self.end_of_image_token = None
+        else:
+            self.image_placeholder = self.tokenizer.image_placeholder
+            self.start_of_image_token = self.tokenizer.start_of_image_token
+            self.end_of_image_token = self.tokenizer.end_of_image_token
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
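
A sketch of the resulting behavior for a text-only configuration (preset name assumed; with no image converter the vision-token attributes stay `None`):

import keras_hub

tokenizer = keras_hub.tokenizers.Gemma3Tokenizer.from_preset("gemma3_instruct_1b")
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor(
    tokenizer=tokenizer,
    image_converter=None,
)
assert preprocessor.text_only_model
assert preprocessor.image_placeholder is None
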
keras_hub/src/models/gemma3/gemma3_presets.py
@@ -181,4 +181,43 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_270m/4",
     },
+    "medgemma_instruct_4b": {
+        "metadata": {
+            "description": (
+                "A 4 billion parameter model based on Gemma 3. "
+                "This model is trained for performance on medical text "
+                "and image comprehension and is optimized for medical "
+                "applications that involve a text generation component."
+            ),
+            "params": 4300079472,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_4b/1",
+    },
+    "medgemma_instruct_27b": {
+        "metadata": {
+            "description": (
+                "A 27 billion parameter model based on Gemma 3. "
+                "This model is trained for performance on medical text "
+                "and image comprehension and is optimized for medical "
+                "applications that involve a text generation component."
+            ),
+            "params": 27432406640,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_27b/1",
+    },
+    "medgemma_instruct_27b_text": {
+        "metadata": {
+            "description": (
+                "A 27 billion parameter text-only model based on Gemma 3. "
+                "This model is trained for performance on medical text "
+                "comprehension and is optimized for medical applications "
+                "that involve a text generation component."
+            ),
+            "params": 27009002240,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_27b_text/1",
+    },
 }
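
As a usage sketch, the MedGemma presets load through the existing Gemma3 task classes once the Kaggle handles above are live:

import keras_hub

causal_lm = keras_hub.models.Gemma3CausalLM.from_preset("medgemma_instruct_4b")
print(causal_lm.generate("Summarize the contraindications for MRI.", max_length=64))
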
keras_hub/src/models/gemma3/gemma3_tokenizer.py
@@ -77,20 +77,32 @@ class Gemma3Tokenizer(SentencePieceTokenizer):
 
     backbone_cls = Gemma3Backbone
 
-    def __init__(self, proto, **kwargs):
+    def __init__(self, proto, has_vision_tokens=True, **kwargs):
         # Add special tokens.
 
+        self.has_vision_tokens = has_vision_tokens
         # The usual tokens.
         self._add_special_token("<bos>", "start_token")
         self._add_special_token("<eos>", "end_token")
         self._add_special_token("<pad>", "pad_token")
 
-
-
-
-
-
-
-
+        if has_vision_tokens:
+            # Image placeholder token.
+            self._add_special_token("<img>", "image_placeholder")
+            # Some tokens which are used in the preprocessor.
+            # We need to keep them
+            # here so that the preprocessor works with tf.data.
+            self._add_special_token("<start_of_image>", "start_of_image_token")
+            self._add_special_token("<end_of_image>", "end_of_image_token")
+        else:
+            # For text-only, skip assigning token IDs or set to -1
+            self.start_of_image_token_id = -1
+            self.image_placeholder_token_id = -1
+            self.end_of_image_token_id = -1
 
         super().__init__(proto=proto, **kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"has_vision_tokens": self.has_vision_tokens})
+        return config
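
A sketch of the new flag (the proto path is a placeholder; with `has_vision_tokens=False` the image token ids default to `-1`, and `get_config()` round-trips the flag):

import keras_hub

tokenizer = keras_hub.tokenizers.Gemma3Tokenizer(
    proto="gemma3_vocabulary.spm",  # hypothetical SentencePiece proto path
    has_vision_tokens=False,
)
assert tokenizer.start_of_image_token_id == -1
assert tokenizer.get_config()["has_vision_tokens"] is False
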
keras_hub/src/models/gpt2/gpt2_causal_lm.py
@@ -420,3 +420,20 @@ class GPT2CausalLM(CausalLM):
         )
         per_token_loss = per_token_loss_fn(target_ids, logits)
         return per_token_loss
+
+    def get_quantization_layer_structure(self, mode):
+        if mode != "gptq":
+            return None
+
+        backbone = self.backbone
+        token_ids = keras.Input(shape=(None,), dtype="int32")
+        tokens = backbone.token_embedding(token_ids)
+        positions = backbone.position_embedding(tokens)
+        x = backbone.embeddings_add((tokens, positions))
+        x = backbone.embeddings_dropout(x)
+        pre_processor = keras.Model(inputs=token_ids, outputs=x)
+
+        return {
+            "pre_block_layers": [pre_processor],
+            "sequential_blocks": backbone.transformer_layers,
+        }
keras_hub/src/models/masked_lm.py
@@ -84,3 +84,25 @@ class MaskedLM(Task):
             weighted_metrics=weighted_metrics,
             **kwargs,
         )
+
+    def get_quantization_layer_structure(self, mode):
+        if mode != "gptq":
+            return None
+
+        backbone = self.backbone
+        # Check for standard backbone structure.
+        if not hasattr(backbone, "transformer_layers"):
+            return None
+
+        # Check for embedding.
+        embedding = getattr(backbone, "token_embedding", None)
+        if embedding is None:
+            embedding = getattr(backbone, "embedding", None)
+
+        if embedding is None:
+            return None
+
+        return {
+            "pre_block_layers": [embedding],
+            "sequential_blocks": backbone.transformer_layers,
+        }
keras_hub/src/models/qwen3/qwen3_presets.py
@@ -70,4 +70,40 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/qwen-3/keras/qwen3_32b_en/1",
     },
+    "qwen3_embedding_0.6b_en": {
+        "metadata": {
+            "description": (
+                "This text embedding model features a 32k context length and "
+                "offers flexible, user-defined embedding dimensions that can "
+                "range from 32 to 1024."
+            ),
+            "params": 595776512,
+            "path": "qwen3",
+        },
+        "kaggle_handle": "kaggle://keras/qwen-3-embedding/keras/qwen3_embedding_0.6b_en/1",
+    },
+    "qwen3_embedding_4b_en": {
+        "metadata": {
+            "description": (
+                "This text embedding model features a 32k context length and "
+                "offers flexible, user-defined embedding dimensions that can "
+                "range from 32 to 2560."
+            ),
+            "params": 4021774336,
+            "path": "qwen3",
+        },
+        "kaggle_handle": "kaggle://keras/qwen-3-embedding/keras/qwen3_embedding_4b_en/1",
+    },
+    "qwen3_embedding_8b_en": {
+        "metadata": {
+            "description": (
+                "This text embedding model features a 32k context length and "
+                "offers flexible, user-defined embedding dimensions that can "
+                "range from 32 to 4096."
+            ),
+            "params": 8188515328,
+            "path": "qwen3",
+        },
+        "kaggle_handle": "kaggle://keras/qwen-3-embedding/keras/qwen3_embedding_8b_en/1",
+    },
 }
keras_hub/src/models/siglip/siglip_presets.py
@@ -321,4 +321,19 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/siglip/keras/siglip2_so400m_patch16_512/1",
     },
+    "medsiglip_900m_448": {
+        "metadata": {
+            "description": (
+                "A 900 million parameter variant of SigLIP trained to encode "
+                "medical images and text into a common embedding space. "
+                "MedSigLIP contains a vision encoder and a text encoder, and "
+                "supports 448x448 image resolution with up to 64 text tokens."
+            ),
+            "params": 878301426,
+            "official_name": "SigLIP2",
+            "path": "siglip",
+            "model_card": "https://huggingface.co/google/medsiglip-448#medsiglip-model-card",
+        },
+        "kaggle_handle": "kaggle://keras/medsiglip/keras/medsiglip_900m_448/1",
+    },
 }
keras_hub/src/models/smollm3/smollm3_presets.py
@@ -0,0 +1,16 @@
+"""SmolLM3 model preset configurations."""
+
+backbone_presets = {
+    "smollm3_3b_en": {
+        "metadata": {
+            "description": (
+                "A dense decoder-only model with 3 billion total parameters, "
+                "built on 36 layers and using 16 query and "
+                "4 key/value attention heads."
+            ),
+            "params": 3075100928,
+            "path": "smollm3",
+        },
+        "kaggle_handle": "kaggle://keras/smollm3/keras/smollm3_3b_en/1",
+    },
+}
keras_hub/src/utils/transformers/convert_dinov3.py
@@ -0,0 +1,106 @@
+import numpy as np
+
+from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone
+
+backbone_cls = DINOV3Backbone
+
+
+def convert_backbone_config(transformers_config):
+    image_size = transformers_config["image_size"]
+    return {
+        "patch_size": transformers_config["patch_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "num_heads": transformers_config["num_attention_heads"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "layer_scale_init_value": transformers_config["layerscale_value"],
+        "num_register_tokens": transformers_config["num_register_tokens"],
+        "use_mask_token": True,
+        "hidden_activation": transformers_config["hidden_act"],
+        "use_gated_mlp": transformers_config["use_gated_mlp"],
+        "use_query_bias": transformers_config["query_bias"],
+        "use_key_bias": transformers_config["key_bias"],
+        "use_value_bias": transformers_config["value_bias"],
+        "use_proj_bias": transformers_config["proj_bias"],
+        "use_mlp_bias": transformers_config["mlp_bias"],
+        "attention_dropout": transformers_config["attention_dropout"],
+        "drop_path_rate": transformers_config["drop_path_rate"],
+        "layer_norm_eps": transformers_config["layer_norm_eps"],
+        "image_shape": (image_size, image_size, 3),
+        "rope_theta": transformers_config["rope_theta"],
+        "apply_layernorm": False,
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    if not isinstance(backbone, DINOV3Backbone):
+        raise ValueError(
+            "The provided backbone must be an instance of DINOV3Backbone. "
+            f"Received: {type(backbone)}"
+        )
+
+    def port_ln(keras_variable, weight_key):
+        loader.port_weight(keras_variable.gamma, f"{weight_key}.weight")
+        loader.port_weight(keras_variable.beta, f"{weight_key}.bias")
+
+    def port_dense(keras_variable, weight_key):
+        loader.port_weight(
+            keras_variable.kernel,
+            f"{weight_key}.weight",
+            hook_fn=lambda x, _: x.T,
+        )
+        if keras_variable.bias is not None:
+            loader.port_weight(keras_variable.bias, f"{weight_key}.bias")
+
+    # Embedding.
+    loader.port_weight(
+        keras_variable=backbone.embeddings.cls_token,
+        hf_weight_key="embeddings.cls_token",
+    )
+    if backbone.use_mask_token:
+        loader.port_weight(
+            keras_variable=backbone.embeddings.mask_token,
+            hf_weight_key="embeddings.mask_token",
+        )
+    if backbone.num_register_tokens > 0:
+        loader.port_weight(
+            keras_variable=backbone.embeddings.register_tokens,
+            hf_weight_key="embeddings.register_tokens",
+        )
+    loader.port_weight(
+        keras_variable=backbone.embeddings.patch_embeddings.projection.kernel,
+        hf_weight_key="embeddings.patch_embeddings.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+    loader.port_weight(
+        keras_variable=backbone.embeddings.patch_embeddings.projection.bias,
+        hf_weight_key="embeddings.patch_embeddings.bias",
+    )
+
+    # Encoder.
+    for i, layer in enumerate(backbone.encoder.layers):
+        prefix = f"layer.{i}"
+        port_ln(layer.norm1, f"{prefix}.norm1")
+        port_dense(layer.attention.query_dense, f"{prefix}.attention.q_proj")
+        port_dense(layer.attention.key_dense, f"{prefix}.attention.k_proj")
+        port_dense(layer.attention.value_dense, f"{prefix}.attention.v_proj")
+        port_dense(layer.attention.output_dense, f"{prefix}.attention.o_proj")
+
+        loader.port_weight(
+            keras_variable=layer.layer_scale1.lambda1,
+            hf_weight_key=f"{prefix}.layer_scale1.lambda1",
+        )
+        port_ln(layer.norm2, f"{prefix}.norm2")
+        if backbone.use_gated_mlp:
+            port_dense(layer.mlp.gate_proj, f"{prefix}.mlp.gate_proj")
+            port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj")
+            port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj")
+        else:
+            port_dense(layer.mlp.up_proj, f"{prefix}.mlp.up_proj")
+            port_dense(layer.mlp.down_proj, f"{prefix}.mlp.down_proj")
+        loader.port_weight(
+            keras_variable=layer.layer_scale2.lambda1,
+            hf_weight_key=f"{prefix}.layer_scale2.lambda1",
+        )
+
+    port_ln(backbone.layernorm, "norm")