keras-hub-nightly 0.24.0.dev202511220420-py3-none-any.whl → 0.26.0.dev202601010440-py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of keras-hub-nightly might be problematic.

Files changed (66)
  1. keras_hub/models/__init__.py +12 -0
  2. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  3. keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
  4. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -3
  5. keras_hub/src/models/albert/albert_backbone.py +1 -3
  6. keras_hub/src/models/bart/bart_backbone.py +1 -3
  7. keras_hub/src/models/bert/bert_backbone.py +1 -3
  8. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  9. keras_hub/src/models/causal_lm.py +23 -1
  10. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  11. keras_hub/src/models/dinov3/dinov3_presets.py +90 -1
  12. keras_hub/src/models/electra/electra_backbone.py +1 -3
  13. keras_hub/src/models/esm/esm_attention.py +11 -4
  14. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  16. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  17. keras_hub/src/models/gemma/gemma_causal_lm.py +16 -0
  18. keras_hub/src/models/gemma3/gemma3_backbone.py +1 -3
  19. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
  20. keras_hub/src/models/gemma3/gemma3_presets.py +12 -0
  21. keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
  22. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  23. keras_hub/src/models/gpt2/gpt2_causal_lm.py +17 -0
  24. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  25. keras_hub/src/models/gpt_oss/__init__.py +5 -0
  26. keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
  27. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +219 -0
  28. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
  29. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
  30. keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
  31. keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
  32. keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
  33. keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
  34. keras_hub/src/models/llama/llama_backbone.py +1 -3
  35. keras_hub/src/models/llama3/llama3_presets.py +1 -1
  36. keras_hub/src/models/masked_lm.py +22 -0
  37. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  38. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  39. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  40. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  41. keras_hub/src/models/parseq/parseq_decoder.py +21 -9
  42. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  43. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  44. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  45. keras_hub/src/models/qwen3/qwen3_presets.py +36 -0
  46. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  47. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  48. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  49. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  50. keras_hub/src/models/smollm3/__init__.py +5 -0
  51. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  52. keras_hub/src/models/smollm3/smollm3_presets.py +16 -0
  53. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
  54. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  55. keras_hub/src/models/t5/t5_backbone.py +1 -3
  56. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  57. keras_hub/src/tests/test_case.py +1 -3
  58. keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
  59. keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
  60. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  61. keras_hub/src/version.py +1 -1
  62. keras_hub/tokenizers/__init__.py +3 -0
  63. {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/METADATA +4 -5
  64. {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/RECORD +66 -53
  65. {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/WHEEL +0 -0
  66. {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/top_level.txt +0 -0
keras_hub/src/models/parseq/parseq_decoder.py
@@ -364,20 +364,32 @@ class PARSeqDecoder(keras.layers.Layer):
         null_context = self.hidden_dim**0.5 * self.token_embedding(
             token_ids[:, :1]
         )
-        if tokens_length > 1:
-            content = self.pos_query_embeddings[:, : tokens_length - 1, :]
-            content = content + self.hidden_dim**0.5 * self.token_embedding(
-                token_ids[:, 1:]
-            )
-            content = ops.concatenate([null_context, content], axis=1)
-        else:
-            content = null_context
+
+        # Build content embeddings. When tokens_length == 1, this produces an
+        # empty tensor (shape: bs, 0, hidden), avoiding the need for a Python
+        # conditional.
+        content_embeddings = self.hidden_dim**0.5 * self.token_embedding(
+            token_ids[:, 1:]
+        )
+        # Use ops.take instead of dynamic slicing for JAX/TF graph compatibility
+        pos_embeds = ops.take(
+            self.pos_query_embeddings,
+            ops.arange(ops.shape(content_embeddings)[1], dtype="int32"),
+            axis=1,
+        )
+        content = ops.concatenate(
+            [null_context, pos_embeds + content_embeddings], axis=1
+        )
 
         content = self.dropout(content)
 
         query = ops.multiply(
             ops.ones((bs, 1, 1), dtype=self.dtype),
-            self.pos_query_embeddings[:, :tokens_length, :],
+            ops.take(
+                self.pos_query_embeddings,
+                ops.arange(tokens_length, dtype="int32"),
+                axis=1,
+            ),
         )
         query = self.dropout(query)
 
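The move from Python-level branching and slicing to ops.take matters because tokens_length may be a traced tensor under jax.jit or tf.function, where a Python if on a tensor, or a tensor-valued slice bound, fails at trace time. A minimal standalone sketch of the pattern, using a hypothetical pos_query_embeddings table of shape (1, max_length, hidden_dim):

import keras
from keras import ops

# Hypothetical positional-query table: (1, max_length=64, hidden_dim=128).
pos_query_embeddings = keras.random.normal((1, 64, 128))

def take_positions(tokens_length):
    # Gathering with an explicit int32 index vector traces cleanly, where
    # `pos_query_embeddings[:, :tokens_length, :]` can fail when
    # tokens_length is a traced value rather than a Python int.
    indices = ops.arange(tokens_length, dtype="int32")
    return ops.take(pos_query_embeddings, indices, axis=1)

print(take_positions(5).shape)  # (1, 5, 128)
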
keras_hub/src/models/phi3/phi3_backbone.py
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.phi3.phi3_decoder import Phi3Decoder
 from keras_hub.src.models.phi3.phi3_layernorm import Phi3LayerNorm
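The same import swap, from keras-hub's internal module to keras.layers, repeats across the backbones below and matches the near-complete removal of keras_hub/src/layers/modeling/reversible_embedding.py (+2 −275): the layer now lives upstream in Keras. A minimal sketch of what the layer provides, assuming the upstream keras.layers.ReversibleEmbedding keeps the same reverse=True call signature as the keras-hub original:

from keras import ops
from keras.layers import ReversibleEmbedding

# One layer serves as both the token embedding and the (tied) output
# projection back onto the vocabulary.
embedding = ReversibleEmbedding(input_dim=1000, output_dim=64)

token_ids = ops.zeros((2, 10), dtype="int32")
hidden = embedding(token_ids)             # (2, 10, 64): embedding lookup
logits = embedding(hidden, reverse=True)  # (2, 10, 1000): project to vocab
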
keras_hub/src/models/qwen/qwen_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.qwen.qwen_decoder import QwenTransformerDecoder
 from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm
keras_hub/src/models/qwen3/qwen3_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.qwen3.qwen3_decoder import Qwen3TransformerDecoder
 from keras_hub.src.models.qwen3.qwen3_layernorm import Qwen3LayerNorm
keras_hub/src/models/qwen3/qwen3_presets.py
@@ -70,4 +70,40 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/qwen-3/keras/qwen3_32b_en/1",
     },
+    "qwen3_embedding_0.6b_en": {
+        "metadata": {
+            "description": (
+                "This text embedding model features a 32k context length and "
+                "offers flexible, user-defined embedding dimensions that can "
+                "range from 32 to 1024."
+            ),
+            "params": 595776512,
+            "path": "qwen3",
+        },
+        "kaggle_handle": "kaggle://keras/qwen-3-embedding/keras/qwen3_embedding_0.6b_en/1",
+    },
+    "qwen3_embedding_4b_en": {
+        "metadata": {
+            "description": (
+                "This text embedding model features a 32k context length and "
+                "offers flexible, user-defined embedding dimensions that can "
+                "range from 32 to 2560."
+            ),
+            "params": 4021774336,
+            "path": "qwen3",
+        },
+        "kaggle_handle": "kaggle://keras/qwen-3-embedding/keras/qwen3_embedding_4b_en/1",
+    },
+    "qwen3_embedding_8b_en": {
+        "metadata": {
+            "description": (
+                "This text embedding model features a 32k context length and "
+                "offers flexible, user-defined embedding dimensions that can "
+                "range from 32 to 4096."
+            ),
+            "params": 8188515328,
+            "path": "qwen3",
+        },
+        "kaggle_handle": "kaggle://keras/qwen-3-embedding/keras/qwen3_embedding_8b_en/1",
+    },
 }
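With these entries registered, the new embedding checkpoints should be loadable by preset name through the usual keras-hub API; a hedged sketch (the appropriate task wrapper for embedding models may differ):

import keras_hub

# Resolves to the kaggle_handle registered above and downloads weights.
backbone = keras_hub.models.Backbone.from_preset("qwen3_embedding_0.6b_en")
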
keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.qwen3_moe.qwen3_moe_decoder import (
     Qwen3MoeTransformerDecoder,
keras_hub/src/models/qwen_moe/qwen_moe_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm
 from keras_hub.src.models.qwen_moe.qwen_moe_decoder import (
keras_hub/src/models/roformer_v2/roformer_v2_backbone.py
@@ -1,10 +1,8 @@
 import keras
 from keras import activations
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.roformer_v2.roformer_v2_attention import RoformerNorm
 from keras_hub.src.models.roformer_v2.roformer_v2_encoder import (
keras_hub/src/models/siglip/siglip_layers.py
@@ -3,10 +3,8 @@ import math
 from keras import initializers
 from keras import layers
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.utils.keras_utils import clone_initializer
 from keras_hub.src.utils.keras_utils import gelu_approximate
 from keras_hub.src.utils.keras_utils import standardize_data_format
keras_hub/src/models/smollm3/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone
+from keras_hub.src.models.smollm3.smollm3_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, SmolLM3Backbone)
keras_hub/src/models/smollm3/smollm3_backbone.py
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
 
keras_hub/src/models/smollm3/smollm3_presets.py
@@ -0,0 +1,16 @@
+"""SmolLM3 model preset configurations."""
+
+backbone_presets = {
+    "smollm3_3b_en": {
+        "metadata": {
+            "description": (
+                "Dense decoder-only model has 3 billion total parameters, "
+                "built on 36 layers and utilizes 16 query and "
+                "4 key/value attention heads."
+            ),
+            "params": 3075100928,
+            "path": "smollm3",
+        },
+        "kaggle_handle": "kaggle://keras/smollm3/keras/smollm3_3b_en/1",
+    },
+}
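Combined with the register_presets call in the new smollm3/__init__.py above, this makes the checkpoint addressable by name. A usage sketch, assuming SmolLM3Backbone is exported under keras_hub.models (the +12-line change to keras_hub/models/__init__.py suggests the new models are):

import keras_hub

# "smollm3_3b_en" resolves via the kaggle_handle declared above.
backbone = keras_hub.models.SmolLM3Backbone.from_preset("smollm3_3b_en")
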
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
@@ -23,7 +23,7 @@ backbone_presets = {
             "params": 3371793763,
             "path": "stable_diffusion_3",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3.5_medium/1",
+        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_medium/1",
     },
     "stable_diffusion_3.5_large": {
         "metadata": {
keras_hub/src/models/stable_diffusion_3/t5_encoder.py
@@ -1,8 +1,6 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
 from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
 
keras_hub/src/models/t5/t5_backbone.py
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
 from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
keras_hub/src/models/t5gemma/t5gemma_backbone.py
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer
keras_hub/src/tests/test_case.py
@@ -9,10 +9,8 @@ import tensorflow as tf
 from absl.testing import parameterized
 from keras import ops
 from keras import tree
+from keras.layers import ReversibleEmbedding
 
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
 from keras_hub.src.tokenizers.tokenizer import Tokenizer
 from keras_hub.src.utils.tensor_utils import is_float_dtype
keras_hub/src/utils/transformers/convert_gemma3.py
@@ -0,0 +1,353 @@
+import numpy as np
+from sentencepiece import SentencePieceProcessor
+
+from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
+from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
+    Gemma3VisionEncoder,
+)
+from keras_hub.src.utils.preset_utils import get_file
+from keras_hub.src.utils.preset_utils import load_json
+
+backbone_cls = Gemma3Backbone
+
+
+def load_image_converter_config(preset, transformers_config):
+    if "vision_config" in transformers_config:
+        preprocessor_config = load_json(preset, "preprocessor_config.json")
+        mean = preprocessor_config["image_mean"]
+        std = preprocessor_config["image_std"]
+        rescale_factor = preprocessor_config["rescale_factor"]
+        offset = [(-m / s) for m, s in zip(mean, std)]
+        scale = [(s * rescale_factor) for s in std]
+        image_size = transformers_config["vision_config"].get("image_size", 224)
+        return {
+            "image_size": (image_size, image_size),
+            "scale": scale,
+            "offset": offset,
+        }
+    else:
+        return None
+
+
+def convert_backbone_config(transformers_config):
+    if transformers_config["model_type"] == "gemma3_text":
+        image_size = None
+        vision_encoder = None
+        transformer_config = transformers_config
+    else:
+        vision_config = transformers_config["vision_config"]
+        image_size = vision_config["image_size"]
+        vision_encoder_config = {
+            "image_size": image_size,
+            "patch_size": vision_config["patch_size"],
+            "num_heads": vision_config["num_attention_heads"],
+            "hidden_dim": vision_config["hidden_size"],
+            "num_layers": vision_config["num_hidden_layers"],
+            "intermediate_dim": vision_config["intermediate_size"],
+            "output_dim": 2560,
+            "pool_size": 4,
+            "layer_norm_epsilon": vision_config.get("layer_norm_eps", 1e-6),
+        }
+        vision_encoder = Gemma3VisionEncoder(**vision_encoder_config)
+        transformer_config = transformers_config["text_config"]
+
+    if "rope_parameters" in transformer_config:
+        rope_global_config = transformer_config.get("rope_parameters", {}).get(
+            "full_attention"
+        )
+    elif "rope_scaling" in transformer_config:
+        rope_global_config = transformer_config["rope_scaling"]
+    else:
+        rope_global_config = {}
+    return {
+        "vocabulary_size": transformer_config.get(
+            "vocab_size", 262144 if vision_encoder is None else 262208
+        ),
+        "image_size": image_size,
+        "num_layers": transformer_config["num_hidden_layers"],
+        "num_query_heads": transformer_config.get("num_attention_heads", 8),
+        "num_key_value_heads": transformer_config.get("num_key_value_heads", 4),
+        "hidden_dim": transformer_config["hidden_size"],
+        "intermediate_dim": transformer_config["intermediate_size"],
+        "head_dim": transformer_config["head_dim"],
+        "use_post_ffw_norm": True,
+        "use_post_attention_norm": True,
+        "attention_logit_softcap": transformer_config.get(
+            "attn_logit_softcap", None
+        ),
+        "final_logit_softcap": transformer_config.get(
+            "final_logit_softcap", None
+        ),
+        "use_sliding_window_attention": True,
+        "query_head_dim_normalize": True,
+        "sliding_window_size": transformer_config["sliding_window"],
+        "local_rope_scaling_factor": 1.0,
+        "global_rope_scaling_factor": (
+            rope_global_config.get("factor", 1.0) if rope_global_config else 1.0
+        ),
+        "layer_norm_epsilon": transformer_config.get("rms_norm_eps", 1e-6),
+        "use_bidirectional_attention": transformer_config.get(
+            "use_bidirectional_attention", False
+        ),
+        "vision_encoder": vision_encoder,
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    if transformers_config["model_type"] == "gemma3_text":
+        prefix = "model"
+    else:
+        prefix = "language_model.model"
+
+    loader.port_weight(
+        keras_variable=backbone.get_layer("token_embedding").embeddings,
+        hf_weight_key=f"{prefix}.embed_tokens.weight",
+    )
+
+    def transpose(x, shape):
+        return np.transpose(x)
+
+    vision_encoder = backbone.vision_encoder
+    if vision_encoder is not None:
+        image_encoder = vision_encoder.get_layer("image_encoder")
+
+        loader.port_weight(
+            keras_variable=image_encoder.vision_embeddings.patch_embedding.kernel,
+            hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.weight",
+            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+        )
+        loader.port_weight(
+            keras_variable=image_encoder.vision_embeddings.patch_embedding.bias,
+            hf_weight_key="vision_tower.vision_model.embeddings.patch_embedding.bias",
+        )
+
+        loader.port_weight(
+            keras_variable=image_encoder.vision_embeddings.position_embedding.embeddings,
+            hf_weight_key="vision_tower.vision_model.embeddings.position_embedding.weight",
+        )
+
+        for i in range(image_encoder.num_layers):
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].layer_norm_1.gamma,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight",
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].layer_norm_1.beta,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias",
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[
+                    i
+                ].attn.query_proj.kernel,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight",
+                hook_fn=transpose,
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].attn.query_proj.bias,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias",
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].attn.key_proj.kernel,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight",
+                hook_fn=transpose,
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].attn.key_proj.bias,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias",
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[
+                    i
+                ].attn.value_proj.kernel,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight",
+                hook_fn=transpose,
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].attn.value_proj.bias,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias",
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].attn.out_proj.kernel,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight",
+                hook_fn=transpose,
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].attn.out_proj.bias,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias",
+            )
+
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].layer_norm_2.gamma,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight",
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].layer_norm_2.beta,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias",
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].mlp_dense_1.kernel,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight",
+                hook_fn=transpose,
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].mlp_dense_1.bias,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias",
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].mlp_dense_2.kernel,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight",
+                hook_fn=transpose,
+            )
+            loader.port_weight(
+                keras_variable=image_encoder.resblocks[i].mlp_dense_2.bias,
+                hf_weight_key=f"vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias",
+            )
+
+        loader.port_weight(
+            keras_variable=image_encoder.encoder_layer_norm.gamma,
+            hf_weight_key="vision_tower.vision_model.post_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=image_encoder.encoder_layer_norm.beta,
+            hf_weight_key="vision_tower.vision_model.post_layernorm.bias",
+        )
+
+        loader.port_weight(
+            keras_variable=vision_encoder.get_layer(
+                "vision_output_encoder"
+            ).vision_soft_embedding_norm.scale,
+            hf_weight_key="multi_modal_projector.mm_soft_emb_norm.weight",
+        )
+
+        loader.port_weight(
+            keras_variable=vision_encoder.get_layer(
+                "vision_output_encoder"
+            ).vision_input_projection.kernel,
+            hf_weight_key="multi_modal_projector.mm_input_projection_weight",
+        )
+
+    for i in range(backbone.num_layers):
+        decoder_layer = backbone.get_layer(f"decoder_block_{i}")
+
+        loader.port_weight(
+            keras_variable=decoder_layer.pre_attention_norm.scale,
+            hf_weight_key=f"{prefix}.layers.{i}.input_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer.post_attention_norm.scale,
+            hf_weight_key=f"{prefix}.layers.{i}.post_attention_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer.pre_ffw_norm.scale,
+            hf_weight_key=f"{prefix}.layers.{i}.pre_feedforward_layernorm.weight",
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer.post_ffw_norm.scale,
+            hf_weight_key=f"{prefix}.layers.{i}.post_feedforward_layernorm.weight",
+        )
+
+        # Attention layers
+
+        ## Query
+        loader.port_weight(
+            keras_variable=decoder_layer.attention.query_dense.kernel,
+            hf_weight_key=f"{prefix}.layers.{i}.self_attn.q_proj.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
+                np.reshape(
+                    hf_tensor,
+                    (keras_shape[0], keras_shape[2], keras_shape[1]),
+                ),
+                axes=(0, 2, 1),
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer.attention.query_norm.scale,
+            hf_weight_key=f"{prefix}.layers.{i}.self_attn.q_norm.weight",
+        )
+        ## Key
+        loader.port_weight(
+            keras_variable=decoder_layer.attention.key_dense.kernel,
+            hf_weight_key=f"{prefix}.layers.{i}.self_attn.k_proj.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
+                np.reshape(
+                    hf_tensor,
+                    (keras_shape[0], keras_shape[2], keras_shape[1]),
+                ),
+                axes=(0, 2, 1),
+            ),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer.attention.key_norm.scale,
+            hf_weight_key=f"{prefix}.layers.{i}.self_attn.k_norm.weight",
+        )
+        ## Value
+        loader.port_weight(
+            keras_variable=decoder_layer.attention.value_dense.kernel,
+            hf_weight_key=f"{prefix}.layers.{i}.self_attn.v_proj.weight",
+            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
+                np.reshape(
+                    hf_tensor,
+                    (keras_shape[0], keras_shape[2], keras_shape[1]),
+                ),
+                axes=(0, 2, 1),
+            ),
+        )
+        ## Output
+        loader.port_weight(
+            keras_variable=decoder_layer.attention.output_dense.kernel,
+            hf_weight_key=f"{prefix}.layers.{i}.self_attn.o_proj.weight",
+            # rearrange_patterns="c (a b) -> a b c",
+            # rearrange_dims={"a": backbone.num_query_heads},
+            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
+                np.reshape(
+                    hf_tensor,
+                    (keras_shape[2], keras_shape[0], keras_shape[1]),
+                ),
+                axes=(1, 2, 0),
+            ),
+        )
+
+        # MLP layers
+        loader.port_weight(
+            keras_variable=decoder_layer.gating_ffw.kernel,
+            hf_weight_key=f"{prefix}.layers.{i}.mlp.gate_proj.weight",
+            # rearrange_patterns="b a -> a b",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer.gating_ffw_2.kernel,
+            hf_weight_key=f"{prefix}.layers.{i}.mlp.up_proj.weight",
+            # rearrange_patterns="b a -> a b",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer.ffw_linear.kernel,
+            hf_weight_key=f"{prefix}.layers.{i}.mlp.down_proj.weight",
+            # rearrange_patterns="b a -> a b",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
+
+    # Final normalization layer
+    loader.port_weight(
+        keras_variable=backbone.get_layer("final_normalization").scale,
+        hf_weight_key=f"{prefix}.norm.weight",
+    )
+
+    return backbone
+
+
+def convert_tokenizer(cls, preset, **kwargs):
+    proto = get_file(preset, "tokenizer.model")
+    sp = SentencePieceProcessor()
+    if isinstance(proto, bytes):
+        sp.LoadFromSerializedProto(proto)
+    else:
+        sp.load(proto)
+
+    has_vision_tokens = (
+        sp.PieceToId("<start_of_image>") != sp.unk_id()
+        and sp.PieceToId("<img>") != sp.unk_id()
+        and sp.PieceToId("<end_of_image>") != sp.unk_id()
+    )
+
+    return cls(proto, has_vision_tokens=has_vision_tokens, **kwargs)
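
The hook_fn lambdas in convert_weights undo the layout difference between Hugging Face Linear weights, stored as (out_features, in_features), and the per-head einsum kernels on the Keras side. A standalone numpy sketch of the q_proj transform with hypothetical dimensions (the Keras query kernel is assumed to be (num_heads, hidden_dim, head_dim), as in the Gemma attention layers):

import numpy as np

hidden_dim, num_heads, head_dim = 16, 4, 8

# HF q_proj.weight: (num_heads * head_dim, hidden_dim), with the heads
# flattened row-major along the output axis.
hf_tensor = np.random.rand(num_heads * head_dim, hidden_dim).astype("float32")

# Target Keras kernel shape, mirroring keras_shape in the hook_fn.
keras_shape = (num_heads, hidden_dim, head_dim)

# Same two steps as the hook_fn: split the output axis into
# (num_heads, head_dim), then swap the (head_dim, hidden_dim) axes.
kernel = np.transpose(
    np.reshape(hf_tensor, (keras_shape[0], keras_shape[2], keras_shape[1])),
    axes=(0, 2, 1),
)
assert kernel.shape == keras_shape  # (4, 16, 8)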