keras-hub-nightly 0.21.0.dev202505050407__py3-none-any.whl → 0.21.0.dev202505060405__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective registries, and is provided for informational purposes only.
Files changed (34):
  1. keras_hub/models/__init__.py +21 -0
  2. keras_hub/src/models/backbone.py +5 -2
  3. keras_hub/src/models/mixtral/mixtral_attention.py +263 -0
  4. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  5. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  6. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  7. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  8. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  9. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  10. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  11. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  12. keras_hub/src/models/qwen_moe/__init__.py +0 -0
  13. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +377 -0
  14. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  15. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  16. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  17. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  18. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  19. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  20. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  21. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  22. keras_hub/src/models/task.py +5 -2
  23. keras_hub/src/utils/keras_utils.py +11 -0
  24. keras_hub/src/utils/preset_utils.py +69 -9
  25. keras_hub/src/utils/tensor_utils.py +27 -1
  26. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  27. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  28. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  29. keras_hub/src/version.py +1 -1
  30. keras_hub/tokenizers/__init__.py +6 -0
  31. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/METADATA +1 -1
  32. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/RECORD +34 -16
  33. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/WHEEL +0 -0
  34. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/top_level.txt +0 -0
keras_hub/src/models/retinanet/retinanet_image_converter.py
@@ -6,16 +6,3 @@ from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
 @keras_hub_export("keras_hub.layers.RetinaNetImageConverter")
 class RetinaNetImageConverter(ImageConverter):
     backbone_cls = RetinaNetBackbone
-
-    def __init__(
-        self,
-        *args,
-        **kwargs,
-    ):
-        # TODO: update presets and remove these old config options. They were
-        # never needed.
-        if "norm_mean" in kwargs:
-            kwargs["offset"] = [-x for x in kwargs.pop("norm_mean")]
-        if "norm_std" in kwargs:
-            kwargs["scale"] = [1.0 / x for x in kwargs.pop("norm_std")]
-        super().__init__(*args, **kwargs)
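For callers that relied on the removed `norm_mean`/`norm_std` options, the same preprocessing can be expressed with the converter's standard `scale`/`offset` arguments. A minimal sketch of the equivalence, using ImageNet-style mean/std values chosen purely for illustration:

    from keras_hub.layers import RetinaNetImageConverter

    norm_mean = [0.485, 0.456, 0.406]  # illustrative values only
    norm_std = [0.229, 0.224, 0.225]   # illustrative values only
    # Mirrors what the removed __init__ computed: offset = -mean, scale = 1 / std.
    converter = RetinaNetImageConverter(
        offset=[-m for m in norm_mean],
        scale=[1.0 / s for s in norm_std],
    )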
keras_hub/src/models/retinanet/retinanet_presets.py
@@ -11,7 +11,7 @@ backbone_presets = {
             "params": 34121239,
             "path": "retinanet",
         },
-        "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/3",
+        "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/4",
     },
     "retinanet_resnet50_fpn_v2_coco": {
         "metadata": {
@@ -22,6 +22,6 @@ backbone_presets = {
             "params": 31558592,
             "path": "retinanet",
         },
-        "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_v2_coco/2",
+        "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_v2_coco/3",
     },
 }
keras_hub/src/models/task.py
@@ -236,14 +236,17 @@ class Task(PipelineModel):
             objects_to_skip=backbone_layer_ids,
         )
 
-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save task to a preset directory.
 
         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_task(self)
+        saver.save_task(self, max_shard_size=max_shard_size)
 
     @property
     def layers(self):
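A minimal usage sketch of the new argument; the preset name, output directory, and threshold below are illustrative and not taken from this release:

    import keras_hub

    # Hypothetical preset name, shown only to demonstrate the new parameter.
    causal_lm = keras_hub.models.CausalLM.from_preset("some_large_preset")
    # If the backbone exceeds 5 GB, its weights are sharded and a
    # model.weights.json index is written; per the docstring,
    # max_shard_size=None disables sharding.
    causal_lm.save_to_preset("./my_preset", max_shard_size=5)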
keras_hub/src/utils/keras_utils.py
@@ -1,3 +1,4 @@
+import inspect
 import sys
 
 import keras
@@ -147,3 +148,13 @@ def get_gpu_names():
         ]
     else:
         return [""]
+
+
+def sharded_weights_available():
+    """Whether sharded weights serialization is available.
+
+    Returns:
+        `True` if sharded weights are available, `False` otherwise.
+    """
+    save_weights_signature = inspect.signature(keras.saving.save_weights)
+    return "max_shard_size" in save_weights_signature.parameters
keras_hub/src/utils/preset_utils.py
@@ -10,6 +10,8 @@ from absl import logging
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.utils.keras_utils import print_msg
+from keras_hub.src.utils.keras_utils import sharded_weights_available
+from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
 
 try:
     import kagglehub
@@ -48,6 +50,7 @@ METADATA_FILE = "metadata.json"
 # Weight file names.
 MODEL_WEIGHTS_FILE = "model.weights.h5"
 TASK_WEIGHTS_FILE = "task.weights.h5"
+SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
 
 # HuggingFace filenames.
 README_FILE = "README.md"
@@ -647,7 +650,7 @@ class KerasPresetLoader(PresetLoader):
         backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
+            self._load_backbone_weights(backbone)
         return backbone
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +700,7 @@ class KerasPresetLoader(PresetLoader):
             task.load_task_weights(task_weights)
         else:
             jax_memory_cleanup(task.backbone)
-            backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
-            task.backbone.load_weights(backbone_weights)
+            self._load_backbone_weights(task.backbone)
         return task
 
     def load_preprocessor(
@@ -726,18 +728,64 @@ class KerasPresetLoader(PresetLoader):
             config["config"] = {**config["config"], **kwargs}
         return keras.saving.deserialize_keras_object(config)
 
+    def _get_sharded_filenames(self, config_path):
+        with open(config_path, encoding="utf-8") as config_file:
+            config = json.load(config_file)
+            weight_map = config["weight_map"]
+            return sorted(set(weight_map.values()))
+
+    def _load_backbone_weights(self, backbone):
+        # Detect if the backbone is sharded or not.
+        has_single_file_weights = check_file_exists(
+            self.preset, MODEL_WEIGHTS_FILE
+        )
+        if has_single_file_weights:
+            filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
+        else:
+            if not sharded_weights_available():
+                raise RuntimeError(
+                    "Sharded weights loading is not supported in the current "
+                    f"Keras version {keras.__version__}. "
+                    "Please update to a newer version."
+                )
+            filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
+            sharded_filenames = self._get_sharded_filenames(filepath)
+            for sharded_filename in sharded_filenames:
+                # Download the sharded weights.
+                _ = get_file(self.preset, sharded_filename)
+        backbone.load_weights(filepath)
+
 
 class KerasPresetSaver:
     def __init__(self, preset_dir):
         os.makedirs(preset_dir, exist_ok=True)
         self.preset_dir = preset_dir
 
-    def save_backbone(self, backbone):
+    def save_backbone(self, backbone, max_shard_size=10):
         self._save_serialized_object(backbone, config_file=CONFIG_FILE)
-        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
-        backbone.save_weights(backbone_weight_path)
         self._save_metadata(backbone)
 
+        # Save the weights.
+        backbone_size_in_bytes = self._get_variables_size_in_bytes(
+            backbone.variables
+        )
+        backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
+        # If the size of the backbone is larger than `max_shard_size`, save
+        # sharded weights.
+        if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
+            backbone_sharded_weights_config_path = os.path.join(
+                self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
+            )
+            backbone.save_weights(
+                backbone_sharded_weights_config_path,
+                max_shard_size=max_shard_size,
+            )
+        else:
+            backbone_weight_path = os.path.join(
+                self.preset_dir, MODEL_WEIGHTS_FILE
+            )
+            backbone.save_weights(backbone_weight_path)
+
     def save_tokenizer(self, tokenizer):
         config_file = TOKENIZER_CONFIG_FILE
         if hasattr(tokenizer, "config_file"):
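The loader only consumes the `weight_map` entry of `model.weights.json`. A hedged sketch of what such an index might contain and how `_get_sharded_filenames` reduces it to a download list; the variable paths and shard file names below are invented for illustration, the real index is written by `keras.saving.save_weights`:

    # Illustrative index; only the weight_map values (shard file names) are used.
    index = {
        "weight_map": {
            "layers/transformer_layer_0/vars/0": "model_00000.weights.h5",
            "layers/transformer_layer_1/vars/0": "model_00001.weights.h5",
        }
    }
    # Same reduction as _get_sharded_filenames: unique shard files, sorted.
    print(sorted(set(index["weight_map"].values())))
    # ['model_00000.weights.h5', 'model_00001.weights.h5']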
@@ -755,7 +803,7 @@ class KerasPresetSaver:
     def save_image_converter(self, converter):
         self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)
 
-    def save_task(self, task):
+    def save_task(self, task, max_shard_size=10):
         # Save task specific config and weights.
         self._save_serialized_object(task, TASK_CONFIG_FILE)
         if task.has_task_weights():
@@ -763,10 +811,12 @@ class KerasPresetSaver:
             task.save_task_weights(task_weight_path)
         # Save backbone.
         if hasattr(task.backbone, "save_to_preset"):
-            task.backbone.save_to_preset(self.preset_dir)
+            task.backbone.save_to_preset(
+                self.preset_dir, max_shard_size=max_shard_size
+            )
         else:
             # Allow saving a `keras.Model` that is not a backbone subclass.
-            self.save_backbone(task.backbone)
+            self.save_backbone(task.backbone, max_shard_size=max_shard_size)
         # Save preprocessor.
         if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
             task.preprocessor.save_to_preset(self.preset_dir)
@@ -823,3 +873,13 @@ class KerasPresetSaver:
         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
         with open(metadata_path, "w") as metadata_file:
             metadata_file.write(json.dumps(metadata, indent=4))
+
+    def _get_variables_size_in_bytes(self, variables):
+        unique_variables = {}
+        for v in variables:
+            if id(v) not in unique_variables:
+                unique_variables[id(v)] = (v.shape, v.dtype)
+        total_memory_size = 0
+        for shape, dtype in unique_variables.values():
+            total_memory_size += get_tensor_size_in_bits(shape, dtype)
+        return total_memory_size / 8
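A back-of-the-envelope check of when the sharded path triggers, using the same arithmetic as `_get_variables_size_in_bytes` and `save_backbone`, with a made-up 7B-parameter bfloat16 backbone:

    num_params = 7_000_000_000           # hypothetical parameter count
    size_in_bits = num_params * 16       # bfloat16 -> 16 bits per parameter
    size_in_gb = (size_in_bits / 8) / (1024**3)
    print(round(size_in_gb, 2))          # ~13.04
    print(size_in_gb > 10)               # True -> save_backbone writes sharded weights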
keras_hub/src/utils/tensor_utils.py
@@ -1,6 +1,8 @@
 import contextlib
 import functools
 import inspect
+import math
+import re
 import threading
 
 import keras
@@ -305,6 +307,29 @@ def is_string_dtype(dtype):
     return "string" in keras.backend.standardize_dtype(dtype)
 
 
+def get_dtype_size_in_bits(dtype):
+    """Get the size of a given dtype in bits."""
+    dtype = keras.backend.standardize_dtype(dtype)
+    # If dtype is bool, return 1 immediately.
+    if dtype == "bool":
+        return 1
+    # Else, we extract the bit size from the string.
+    return int(re.sub(r"bfloat|float|uint|int", "", dtype))
+
+
+def get_tensor_size_in_bits(shape, dtype):
+    """Calculate the size given dtype and shape in bits.
+
+    Args:
+        dtype: The dtype of the tensor.
+        shape: List of iterables representing the shape of the tensor.
+
+    Returns:
+        The size of the tensor in bytes.
+    """
+    return math.prod(shape) * get_dtype_size_in_bits(dtype)
+
+
 def any_equal(inputs, values, padding_mask):
     """Return a mask that is True anywhere `inputs` has a value in `values`.
 
@@ -320,7 +345,8 @@ def any_equal(inputs, values, padding_mask):
     Returns:
         A tensor with `inputs` shape where each position is True if it contains
         a value from any `values`. Padding mask will be applied before
-        returning."""
+        returning.
+    """
     output = ops.equal(inputs, values[0])
     for value in values[1:]:
         value_equality = ops.equal(inputs, value)
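A quick sanity check of the two new helpers, relying only on the definitions shown above:

    from keras_hub.src.utils.tensor_utils import (
        get_dtype_size_in_bits,
        get_tensor_size_in_bits,
    )

    print(get_dtype_size_in_bits("bfloat16"))          # 16
    print(get_dtype_size_in_bits("bool"))              # 1
    print(get_tensor_size_in_bits((2, 3), "float32"))  # 6 * 32 = 192 bits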
keras_hub/src/utils/transformers/convert_mixtral.py (new file)
@@ -0,0 +1,139 @@
+import numpy as np
+
+from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone
+from keras_hub.src.utils.preset_utils import get_file
+
+backbone_cls = MixtralBackbone
+
+
+def convert_backbone_config(transformers_config):
+    return {
+        "vocabulary_size": transformers_config["vocab_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "num_query_heads": transformers_config["num_attention_heads"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "num_key_value_heads": transformers_config["num_key_value_heads"],
+        "num_experts": transformers_config["num_local_experts"],
+        "top_k": transformers_config["num_experts_per_tok"],
+        "rope_max_wavelength": transformers_config["rope_theta"],
+        "layer_norm_epsilon": transformers_config["rms_norm_eps"],
+        "sliding_window": transformers_config["sliding_window"],
+        "output_router_logits": transformers_config["output_router_logits"],
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    # Embeddings
+    loader.port_weight(
+        keras_variable=backbone.get_layer("token_embedding").embeddings,
+        hf_weight_key="model.embed_tokens.weight",
+    )
+    loader.port_weight(
+        keras_variable=backbone.get_layer("token_embedding").reverse_embeddings,
+        hf_weight_key="lm_head.weight",
+        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+    )
+
+    def transpose_and_reshape(x, shape):
+        return np.reshape(np.transpose(x), shape)
+
+    for i in range(backbone.num_layers):
+        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")
+
+        # Input layernorm
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layernorm.scale,
+            hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
+        )
+
+        # Attention layers
+        ## Query
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Key
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Value
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._value_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Output
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._output_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+
+        # MoE layers
+        # Router gate
+        loader.port_weight(
+            keras_variable=decoder_layer._sparse_moe_block._sparse_feedforward_gate_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.block_sparse_moe.gate.weight",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
+
+        # Batched experts: w1 (gate), w3 (intermediate), and w2 (output) weights
+        gate_weights_list = []
+        intermediate_weights_list = []
+        output_weights_list = []
+        for expert_idx in range(backbone.num_experts):
+            # Load w1 (gate dense) for each expert
+            w1 = loader.get_tensor(
+                f"model.layers.{i}.block_sparse_moe.experts.{expert_idx}.w1.weight"
+            )
+            w1_transposed = np.transpose(w1, axes=(1, 0))
+            gate_weights_list.append(w1_transposed)
+
+            w3 = loader.get_tensor(
+                f"model.layers.{i}.block_sparse_moe.experts.{expert_idx}.w3.weight"
+            )
+            w3_transposed = np.transpose(w3, axes=(1, 0))
+            intermediate_weights_list.append(w3_transposed)
+
+            w2 = loader.get_tensor(
+                f"model.layers.{i}.block_sparse_moe.experts.{expert_idx}.w2.weight"
+            )
+            w2_transposed = np.transpose(w2, axes=(1, 0))
+            output_weights_list.append(w2_transposed)
+
+        gate_batched = np.stack(gate_weights_list, axis=0)
+        intermediate_batched = np.stack(intermediate_weights_list, axis=0)
+        output_batched = np.stack(output_weights_list, axis=0)
+
+        # Assign batched weights to expert_bank
+        decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_gate_dense.assign(
+            gate_batched
+        )
+        decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_intermediate_dense.assign(
+            intermediate_batched
+        )
+        decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_output_dense.assign(
+            output_batched
+        )
+
+        # Feedforward layernorm
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_layernorm.scale,
+            hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
+        )
+
+    # Final normalization layer
+    loader.port_weight(
+        keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
+        hf_weight_key="model.norm.weight",
+    )
+
+    return backbone
+
+
+def convert_tokenizer(cls, preset, **kwargs):
+    return cls(get_file(preset, "tokenizer.model"), **kwargs)
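The per-layer expert loop above folds each expert's HuggingFace weight matrices into one batched tensor per projection. A small numpy sketch of the shape bookkeeping, with toy dimensions (hidden_dim=4, intermediate_dim=8, num_experts=2 are made up for illustration):

    import numpy as np

    hidden_dim, intermediate_dim, num_experts = 4, 8, 2
    # Each HF expert weight (e.g. w1) is stored as (out_features, in_features).
    hf_w1 = [np.zeros((intermediate_dim, hidden_dim)) for _ in range(num_experts)]
    # Transpose each to (in, out), then stack so the expert bank receives a single
    # (num_experts, hidden_dim, intermediate_dim) tensor.
    gate_batched = np.stack([np.transpose(w) for w in hf_w1], axis=0)
    print(gate_batched.shape)  # (2, 4, 8)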
keras_hub/src/utils/transformers/convert_qwen_moe.py (new file)
@@ -0,0 +1,253 @@
+import numpy as np
+
+from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone
+from keras_hub.src.utils.preset_utils import load_json
+
+backbone_cls = QwenMoeBackbone
+
+
+def convert_backbone_config(transformers_config):
+    return {
+        "vocabulary_size": transformers_config["vocab_size"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "num_query_heads": transformers_config["num_attention_heads"],
+        "num_key_value_heads": transformers_config["num_key_value_heads"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "moe_intermediate_dim": transformers_config["moe_intermediate_size"],
+        "shared_expert_intermediate_dim": transformers_config[
+            "shared_expert_intermediate_size"
+        ],
+        "num_experts": transformers_config["num_experts"],
+        "top_k": transformers_config["num_experts_per_tok"],
+        "norm_top_k_prob": transformers_config["norm_topk_prob"],
+        "decoder_sparse_step": transformers_config["decoder_sparse_step"],
+        "layer_norm_epsilon": transformers_config["rms_norm_eps"],
+        "rope_max_wavelength": transformers_config["rope_theta"],
+        "use_sliding_window": transformers_config["use_sliding_window"],
+        "sliding_window_size": transformers_config["sliding_window"],
+        "output_router_logits": transformers_config["output_router_logits"],
+        "router_aux_loss_coefficient": transformers_config[
+            "router_aux_loss_coef"
+        ],
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    loader.port_weight(
+        keras_variable=backbone.get_layer("token_embedding").embeddings,
+        hf_weight_key="model.embed_tokens.weight",
+    )
+    if not backbone.tie_word_embeddings:
+        loader.port_weight(
+            keras_variable=backbone.get_layer(
+                "token_embedding"
+            ).reverse_embeddings,
+            hf_weight_key="lm_head.weight",
+            # rearrange_pattern="b a -> a b",
+            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+        )
+
+    def transpose_and_reshape(x, shape):
+        return np.reshape(np.transpose(x), shape)
+
+    for i in range(backbone.num_layers):
+        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")
+
+        # Input layernorm
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layernorm.scale,
+            hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
+        )
+
+        # Attention layers
+
+        ## Query
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._query_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._query_dense.bias,
+            hf_weight_key=f"model.layers.{i}.self_attn.q_proj.bias",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Key
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._key_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._key_dense.bias,
+            hf_weight_key=f"model.layers.{i}.self_attn.k_proj.bias",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Value
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._value_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
+            hook_fn=transpose_and_reshape,
+        )
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._value_dense.bias,
+            hf_weight_key=f"model.layers.{i}.self_attn.v_proj.bias",
+            hook_fn=transpose_and_reshape,
+        )
+        ## Output
+        loader.port_weight(
+            keras_variable=decoder_layer._self_attention_layer._output_dense.kernel,
+            hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
+            # rearrange_patterns="c (a b) -> a b c",
+            # rearrange_dims={"a": backbone.num_query_heads},
+            hook_fn=transpose_and_reshape,
+        )
+
+        # MLP layers
+        if (
+            (i not in backbone.mlp_only_layers)
+            and backbone.num_experts > 0
+            and ((i + 1) % backbone.decoder_sparse_step == 0)
+        ):
+            # MoE layers
+            loader.port_weight(
+                keras_variable=decoder_layer.mlp._sparse_feedforward_gate_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.gate.weight",
+                # rearrange_patterns="b a -> a b",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            # Batched experts: gate_up_proj and down_proj
+            gate_up_proj_list = []
+            down_proj_list = []
+            for expert_idx in range(backbone.num_experts):
+                # Load gate_proj and up_proj for each expert
+                gate_proj = loader.get_tensor(
+                    f"model.layers.{i}.mlp.experts.{expert_idx}.gate_proj.weight"
+                )
+                up_proj = loader.get_tensor(
+                    f"model.layers.{i}.mlp.experts.{expert_idx}.up_proj.weight"
+                )
+                # Transpose to (hidden_dim, intermediate_dim)
+                gate_proj = np.transpose(gate_proj, axes=(1, 0))
+                up_proj = np.transpose(up_proj, axes=(1, 0))
+                # Concatenate gate_proj and up_proj along the last dimension
+                gate_up_proj = np.concatenate([gate_proj, up_proj], axis=-1)
+                gate_up_proj_list.append(gate_up_proj)
+
+                # Load down_proj for each expert
+                down_proj = loader.get_tensor(
+                    f"model.layers.{i}.mlp.experts.{expert_idx}.down_proj.weight"
+                )
+                down_proj = np.transpose(
+                    down_proj, axes=(1, 0)
+                )  # (intermediate_dim, hidden_dim)
+                down_proj_list.append(down_proj)
+
+            # Stack the lists to create batched weights
+            gate_up_proj_batched = np.stack(
+                gate_up_proj_list, axis=0
+            )  # (num_experts, hidden_dim, 2 * intermediate_dim)
+            down_proj_batched = np.stack(
+                down_proj_list, axis=0
+            )  # (num_experts, intermediate_dim, hidden_dim)
+
+            # Assign batched weights to expert_bank
+            decoder_layer.mlp.expert_bank._expert_feedforward_gate_dense.assign(
+                gate_up_proj_batched
+            )
+            decoder_layer.mlp.expert_bank._expert_feedforward_output_dense.assign(
+                down_proj_batched
+            )
+
+            loader.port_weight(
+                keras_variable=decoder_layer.mlp.shared_expert_dense._feedforward_intermediate_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.shared_expert.up_proj.weight",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            loader.port_weight(
+                keras_variable=decoder_layer.mlp.shared_expert_dense._feedforward_output_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.shared_expert.down_proj.weight",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            loader.port_weight(
+                keras_variable=decoder_layer.mlp.shared_expert_dense._feedforward_gate_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.shared_expert.gate_proj.weight",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+
+            loader.port_weight(
+                keras_variable=decoder_layer.mlp.shared_expert_gate_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.shared_expert_gate.weight",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+        else:
+            loader.port_weight(
+                keras_variable=decoder_layer._feedforward_intermediate_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight",
+                # rearrange_patterns="b a -> a b",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            loader.port_weight(
+                keras_variable=decoder_layer._feedforward_output_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight",
+                # rearrange_patterns="b a -> a b",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            loader.port_weight(
+                keras_variable=decoder_layer._feedforward_gate_dense.kernel,
+                hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight",
+                # rearrange_patterns="b a -> a b",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+
+        # Feedforward layernorm
+        loader.port_weight(
+            keras_variable=decoder_layer._feedforward_layernorm.scale,
+            hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
+        )
+
+    # Final normalization layer
+    loader.port_weight(
+        keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
+        hf_weight_key="model.norm.weight",
+    )
+
+    return backbone
+
+
+def convert_tokenizer(cls, preset, **kwargs):
+    tokenizer_config = load_json(preset, "tokenizer.json")
+    vocab = tokenizer_config["model"]["vocab"]
+    merges = tokenizer_config["model"]["merges"]
+
+    # Load all special tokens with the exception of "reserved" ones.
+    special_tokens = set()
+    for token in tokenizer_config["added_tokens"]:
+        if not token["content"].startswith("<|reserved_special_token_"):
+            vocab[token["content"]] = token["id"]
+            special_tokens.add(token["content"])
+
+    kwargs.update(
+        {
+            "unsplittable_tokens": list(special_tokens),
+        }
+    )
+
+    return cls(vocabulary=vocab, merges=merges, **kwargs)
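The tokenizer conversion keeps every added token except the `<|reserved_special_token_...` placeholders and passes the survivors as `unsplittable_tokens`. A standalone sketch of that filter with made-up entries (real values come from the checkpoint's tokenizer.json):

    added_tokens = [
        {"content": "<|im_start|>", "id": 151644},                  # example entry
        {"content": "<|reserved_special_token_0|>", "id": 151650},  # example entry
    ]
    vocab, special_tokens = {}, set()
    for token in added_tokens:
        if not token["content"].startswith("<|reserved_special_token_"):
            vocab[token["content"]] = token["id"]
            special_tokens.add(token["content"])
    print(special_tokens)  # {'<|im_start|>'} -- reserved placeholders are skipped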
keras_hub/src/utils/transformers/preset_loader.py
@@ -11,8 +11,10 @@ from keras_hub.src.utils.transformers import convert_gemma
 from keras_hub.src.utils.transformers import convert_gpt2
 from keras_hub.src.utils.transformers import convert_llama3
 from keras_hub.src.utils.transformers import convert_mistral
+from keras_hub.src.utils.transformers import convert_mixtral
 from keras_hub.src.utils.transformers import convert_pali_gemma
 from keras_hub.src.utils.transformers import convert_qwen
+from keras_hub.src.utils.transformers import convert_qwen_moe
 from keras_hub.src.utils.transformers import convert_vit
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
 
@@ -44,6 +46,10 @@ class TransformersPresetLoader(PresetLoader):
             self.converter = convert_vit
         elif model_type == "qwen2":
            self.converter = convert_qwen
+        elif model_type == "mixtral":
+            self.converter = convert_mixtral
+        elif model_type == "qwen2_moe":
+            self.converter = convert_qwen_moe
         else:
             raise ValueError(
                 "KerasHub has no converter for huggingface/transformers models "