sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama4.py
@@ -1,3 +1,6 @@
+import json as json_lib
+import logging
+import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
 
@@ -16,8 +19,17 @@ from sglang.srt.managers.mm_utils import (
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, is_cpu
+
+_is_cpu = is_cpu()
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import add_prefix
 
+logger = logging.getLogger(__name__)
+
 
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
@@ -35,31 +47,98 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config
 
-        self.vision_model = Llama4VisionModel(config.vision_config)
-        self.multi_modal_projector = Llama4MultiModalProjector(config)
+        # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
+        self.has_vision = self._has_vision_weights(config)
+        if not self.has_vision:
+            logger.warning(
+                "No vision weights found in checkpoint. Model will run in text-only mode. "
+                "Multimodal capabilities (image processing) will be unavailable."
+            )
+
+        if self.has_vision:
+            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.multi_modal_projector = Llama4MultiModalProjector(config)
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None
 
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM
 
         self.language_model = Llama4ForCausalLM(
-            config.text_config,
+            config.text_config if hasattr(config, "text_config") else config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
 
-        self.logits_processor = LogitsProcessor(config.text_config)
+        self.logits_processor = LogitsProcessor(
+            config.text_config if hasattr(config, "text_config") else config
+        )
 
-    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
+    def _has_vision_weights(self, config) -> bool:
+        """Check if the model has vision components by examining the checkpoint."""
+        model_path = getattr(config, "_name_or_path", None)
+        if not model_path:
+            return False
+
+        # Check if this is a local path first
+        if os.path.isdir(model_path):
+            index_file = os.path.join(model_path, "model.safetensors.index.json")
+            if os.path.exists(index_file):
+                return self._check_vision_weights_in_index(index_file)
+
+        # For HuggingFace models, we need to check the actual checkpoint
+        # The config might say it's multimodal, but the checkpoint might be text-only
+        try:
+            # Try to access the HuggingFace cache directory
+            from huggingface_hub import try_to_load_from_cache
+
+            # Check if index file exists in cache
+            index_file_path = try_to_load_from_cache(
+                repo_id=model_path,
+                filename="model.safetensors.index.json",
+                cache_dir=None,
+            )
+
+            if index_file_path and os.path.exists(index_file_path):
+                return self._check_vision_weights_in_index(index_file_path)
+
+        except Exception:
+            # If we can't access the cache, fall back to config-based detection
+            pass
+
+        # Fallback, assume text-only
+        return False
+
+    def _check_vision_weights_in_index(self, index_file: str) -> bool:
+        """Check if the model.safetensors.index.json contains vision weights."""
+        try:
+            with open(index_file, "r") as f:
+                index_data = json_lib.load(f)
+
+            vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
+            weight_names = index_data.get("weight_map", {}).keys()
+
+            return any(
+                pattern in weight_name
+                for weight_name in weight_names
+                for pattern in vision_patterns
+            )
+        except (OSError, json_lib.JSONDecodeError, KeyError):
+            return False
 
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(
         self,
         items: List[MultimodalDataItem],
     ) -> torch.Tensor:
+        # For text-only models, return None or raise an error
+        if not self.has_vision or self.vision_model is None:
+            raise ValueError("Vision model not available for text-only checkpoint")
+
         pixel_values = (
            torch.concat([item.pixel_values for item in items])
            .to(next(self.vision_model.parameters()).device)
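Note: the _has_vision_weights / _check_vision_weights_in_index logic added above reduces to scanning the checkpoint's model.safetensors.index.json for vision-related weight names. A minimal standalone sketch of that check, assuming a local checkpoint directory (hypothetical has_vision_weights helper, not part of sglang, and without the HuggingFace-cache fallback shown in the diff):

    import json
    import os

    def has_vision_weights(checkpoint_dir: str) -> bool:
        """Return True if the sharded-checkpoint index lists any vision weights."""
        index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
        if not os.path.exists(index_file):
            # No index file: treat the checkpoint as text-only, as the loader above does.
            return False
        with open(index_file) as f:
            weight_map = json.load(f).get("weight_map", {})
        vision_patterns = ("vision_model", "vision_tower", "multi_modal_projector")
        return any(p in name for name in weight_map for p in vision_patterns)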
@@ -80,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:
 
+        # For text-only models, pass None for image_data_embedding_func
+        image_embedding_func = self.get_image_feature if self.has_vision else None
+
         hs = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            image_data_embedding_func=image_embedding_func,
             positions=positions,
         )
 
@@ -110,18 +192,21 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         # rotary embeds should be sliced
         if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.language_model.config.num_key_value_heads
-            )
+            if _is_cpu:
+                dim = self.language_model.config.original_total_num_kv_heads
+            else:
+                dim = self.language_model.config.num_key_value_heads
+            loaded_weight = permute(loaded_weight, dim)
         elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.language_model.config.num_attention_heads
-            )
+            if _is_cpu:
+                dim = self.language_model.config.original_num_attention_heads
+            else:
+                dim = self.language_model.config.num_attention_heads
+            loaded_weight = permute(loaded_weight, dim)
 
         return name, loaded_weight
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
-
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -134,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module):
         ]
 
         params_dict = dict(self.named_parameters())
+        num_experts = (
+            self.config.text_config.num_local_experts
+            if hasattr(self.config, "text_config")
+            else self.config.num_local_experts
+        )
 
-        num_experts = self.config.text_config.num_local_experts
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -147,81 +233,308 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         for name, loaded_weight in weights:
-            if not "vision" in name:
+            if self._should_skip_weight(name):
+                continue
+
+            name = self._transform_weight_name(name)
+
+            if "vision" not in name:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )
 
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-
-                if "vision" in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
+            if self._handle_scale_remapping(name, params_dict):
+                continue
+
+            if self._handle_stacked_params(
+                name, loaded_weight, stacked_params_mapping, params_dict
+            ):
+                continue
+
+            if self._handle_expert_weights(
+                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+            ):
+                continue
+
+            self._handle_default_weight(name, loaded_weight, params_dict)
+
+    def _should_skip_weight(self, name: str) -> bool:
+        """Check if we should skip loading this weight."""
+        return "vision" in name and not self.has_vision
+
+    def _transform_weight_name(self, name: str) -> str:
+        """Transform weight name by adding language_model prefix if needed."""
+        if (
+            not name.startswith("language_model.")
+            and "vision" not in name
+            and "multi_modal_projector" not in name
+        ):
+            return f"language_model.{name}"
+        return name
+
+    def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
+        """Handle scale parameter remapping. Returns True if handled."""
+        if "scale" in name and "expert" not in name:
+            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+            return remapped_name is None
+        return False
+
+    def _handle_stacked_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        stacked_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle stacked parameter loading. Returns True if handled."""
+        for param_name, weight_name, shard_id in stacked_params_mapping:
+            if weight_name in name and "vision" not in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(param, loaded_weight, shard_id)
+                return True
+        return False
+
+    def _handle_expert_weights(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle expert weight loading for MoE (Mixture of Experts) layers.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: Mapping of parameter names to expert configurations
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts in the MoE layer
+
+        Returns:
+            bool: True if the parameter was handled (is an expert parameter), False otherwise
+        """
+        if ".experts" not in name:
+            return False
+
+        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
+            return self._handle_other_expert_params(
+                name, loaded_weight, expert_params_mapping, params_dict
+            )
+
+        if "scale" in name:
+            return self._handle_expert_scale_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+        else:
+            return self._handle_expert_weight_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+
+    def _handle_other_expert_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle expert parameters that are not gate_up_proj or down_proj weights.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
+            params_dict: Dictionary of model parameters
+
+        Returns:
+            bool: True if parameter was found and handled, False otherwise
+        """
+        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+            if weight_name in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(
+                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
+                )
+                return True
+        return False
+
+    def _transform_expert_name(
+        self, name: str, is_weight: bool = False
+    ) -> Tuple[str, str, List[str]]:
+        """Transform expert parameter name and get shard information.
+
+        Args:
+            name: The original parameter name
+            is_weight: Whether this is a weight parameter (adds _weight suffix)
+
+        Returns:
+            Tuple of (transformed_name, shard_id, shard_id_list)
+        """
+        suffix = "_weight" if is_weight else ""
+
+        if ".gate_up_proj" in name:
+            transformed_name = name.replace(
+                ".experts.gate_up_proj", f".experts.w13{suffix}"
+            )
+            shard_id = "w13"
+            shard_id_list = ["w1", "w3"]
+        else:  # down_proj
+            transformed_name = name.replace(
+                ".experts.down_proj", f".experts.w2{suffix}"
+            )
+            shard_id = "w2"
+            shard_id_list = ["w2"]
+
+        return transformed_name, shard_id, shard_id_list
+
+    def _handle_expert_scale_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle quantization scale parameters for expert weights.
+
+        Args:
+            name: Parameter name containing scale information
+            loaded_weight: Scale tensor to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for broadcast operations
+
+        Returns:
+            bool: True (always handles scale parameters)
+        """
+        import re
+
+        # Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
+        expert_match = re.search(r"experts\.(\d+)\.", name)
+
+        # Transform name
+        transformed_name, _, _ = self._transform_expert_name(name)
+
+        if transformed_name not in params_dict:
+            return True
+
+        param = params_dict[transformed_name]
+
+        # Handle scale parameters
+        if expert_match:
+            # If we have a specific expert ID, only load for that expert
+            expert_id = int(expert_match.group(1))
+            # For scale parameters, we can directly set the value
+            param.data[expert_id] = loaded_weight
+        else:
+            # No expert ID found - this is a single scale for all experts
+            # Load the same scale for all experts
+            for expert_id in range(num_experts):
+                param.data[expert_id] = loaded_weight
+
+        return True
+
+    def _handle_expert_weight_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).
+
+        Args:
+            name: Parameter name (should contain gate_up_proj or down_proj)
+            loaded_weight: Weight tensor(s) to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for tensor distribution
+
+        Returns:
+            bool: True (always handles weight parameters)
+        """
+        # Transform name and get shard info
+        transformed_name, _, shard_id_list = self._transform_expert_name(
+            name, is_weight=True
+        )
+
+        if ".gate_up_proj" in name:
+            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+        else:  # down_proj
+            loaded_weight_list = [loaded_weight]
+
+        for param_name, weight_chunk, shard_id in zip(
+            [transformed_name] * len(shard_id_list), loaded_weight_list, shard_id_list
+        ):
+            if param_name not in params_dict:
+                continue
+
+            param = params_dict[param_name]
+            weight_loader = param.weight_loader
+
+            # Handle the case where loaded_weight might be a single tensor for all experts
+            if weight_chunk.dim() == 2:
+                # Single tensor case - load for all experts
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk.T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
             else:
-                if ".experts" in name:
-                    # NOTE: llama4 fp8 has different weight format for experts
-                    if (
-                        "experts.gate_up_proj" not in name
-                        and "experts.down_proj" not in name
-                    ):
-                        for mapping in expert_params_mapping:
-                            param_name, weight_name, expert_id, shard_id = mapping
-                            if weight_name not in name:
-                                continue
-                            name = name.replace(weight_name, param_name)
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            weight_loader(
-                                param,
-                                loaded_weight,
-                                name,
-                                shard_id=shard_id,
-                                expert_id=expert_id,
-                            )
-                            break
-                    else:
-                        if ".gate_up_proj" in name:
-                            name_list = [
-                                name.replace(
-                                    ".experts.gate_up_proj", ".experts.w13_weight"
-                                )
-                            ] * 2
-                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                            shard_id_list = ["w1", "w3"]
-                        else:
-                            name_list = [
-                                name.replace(".experts.down_proj", ".experts.w2_weight")
-                            ]
-                            shard_id_list = ["w2"]
-                            loaded_weight_list = [loaded_weight]
-                        for name, loaded_weight, shard_id in zip(
-                            name_list, loaded_weight_list, shard_id_list
-                        ):
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            for expert_id in range(num_experts):
-                                weight_loader(
-                                    param,
-                                    loaded_weight[expert_id].T,
-                                    name,
-                                    shard_id=shard_id,
-                                    expert_id=expert_id,
-                                )
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                # Multiple experts case - load each expert's weights
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk[expert_id].T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
-                    weight_loader(param, loaded_weight)
+
+        return True
+
+    def _handle_default_weight(
+        self, name: str, loaded_weight: torch.Tensor, params_dict: dict
+    ):
+        """Handle default weight loading."""
+        # Skip loading extra bias for GPTQ models
+        if name.endswith(".bias") and name not in params_dict:
+            return
+
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, loaded_weight)
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
+            self.language_model.set_eagle3_layers_to_capture(layer_ids)
+
+    def get_embed_and_head(self):
+        # For EAGLE3, we delegate to the language model which should have this method
+        # If the language model doesn't have lm_head (like EAGLE3), we return None for head
+        embed = self.language_model.get_embed()
+        if hasattr(self.language_model, "get_embed_and_head"):
+            return self.language_model.get_embed_and_head()
+        elif hasattr(self.language_model, "lm_head"):
+            return embed, self.language_model.lm_head.weight
+        else:
+            # For EAGLE3, head might not be needed
+            return embed, None
+
+    def set_embed_and_head(self, embed, head):
+        if hasattr(self.language_model, "set_embed_and_head"):
+            return self.language_model.set_embed_and_head(embed, head)
+        else:
+            # For EAGLE3, only set embed
+            return self.language_model.set_embed(embed)
+
+    def get_embed(self):
+        return self.language_model.get_embed()
+
+    def set_embed(self, embed):
+        return self.language_model.set_embed(embed)
 
 
 EntryClass = Llama4ForConditionalGeneration
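Note: the expert-weight path above maps fused checkpoint tensors onto the FusedMoE parameter layout: experts.gate_up_proj loads as the w1/w3 shards of w13_weight, and experts.down_proj loads as the w2 shard of w2_weight. A standalone sketch of that name mapping (hypothetical transform_expert_name helper mirroring _transform_expert_name; the example tensor name is illustrative only):

    from typing import List, Tuple

    def transform_expert_name(name: str) -> Tuple[str, List[str]]:
        """Map a fused expert tensor name to its stacked parameter and shard ids."""
        if ".experts.gate_up_proj" in name:
            # gate_proj and up_proj are stored fused; they load as the w1 and w3 shards.
            return name.replace(".experts.gate_up_proj", ".experts.w13_weight"), ["w1", "w3"]
        if ".experts.down_proj" in name:
            return name.replace(".experts.down_proj", ".experts.w2_weight"), ["w2"]
        return name, []

    # Example (illustrative tensor name):
    print(transform_expert_name("language_model.model.layers.0.feed_forward.experts.gate_up_proj"))
    # ('language_model.model.layers.0.feed_forward.experts.w13_weight', ['w1', 'w3'])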
sglang/srt/models/phi4mm.py
@@ -446,9 +446,7 @@ class Phi4MMForCausalLM(nn.Module):
         return hidden_states
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def should_apply_lora(self, module_name: str) -> bool:
sglang/srt/models/pixtral.py
@@ -268,15 +268,14 @@ class PixtralHFVisionModel(nn.Module):
 
     DEFAULT_IMAGE_TOKEN_ID = 10
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, mm_inputs)
 
     def __init__(
         self,
         config: PixtralVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         *,
-        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
         num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
     ) -> None:
@@ -314,11 +313,8 @@ class PixtralHFVisionModel(nn.Module):
         )
 
         # Initialize patch position embedding
-        self.image_token_id = image_token_id
         self.patch_positional_embedding = PixtralRotaryEmbedding(config)
-        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
-            [self.image_token_id]
-        )
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens()
 
     @property
     def dtype(self):
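Note: in mllama4, phi4mm, and pixtral above, MultiModalityDataPaddingPatternMultimodalTokens is now constructed without an explicit image-token list; the placeholder tokens are evidently resolved from the MultimodalInputs passed to pad_input_tokens. A simplified, hypothetical illustration of what this kind of padding does (not the sglang API):

    from typing import List, Set

    def pad_multimodal_tokens(
        input_ids: List[int], placeholder_ids: Set[int], pad_value: int
    ) -> List[int]:
        """Replace every multimodal placeholder token with a per-item pad value."""
        return [pad_value if tok in placeholder_ids else tok for tok in input_ids]

    # Example with made-up token ids; the padded positions mark where image
    # embeddings are scattered later in the forward pass.
    print(pad_multimodal_tokens([1, 32000, 32000, 42], {32000}, pad_value=0))
    # [1, 0, 0, 42]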