sglang 0.4.9.post6__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +3 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +10 -2
  11. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  12. sglang/srt/eplb/expert_distribution.py +5 -0
  13. sglang/srt/eplb/expert_location.py +17 -6
  14. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  15. sglang/srt/eplb/expert_location_updater.py +2 -0
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/step3_detector.py +436 -0
  18. sglang/srt/hf_transformers_utils.py +2 -0
  19. sglang/srt/jinja_template_utils.py +4 -1
  20. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +20 -640
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  24. sglang/srt/layers/quantization/fp8.py +0 -18
  25. sglang/srt/layers/quantization/unquant.py +0 -8
  26. sglang/srt/layers/quantization/w4afp8.py +1 -0
  27. sglang/srt/managers/cache_controller.py +143 -45
  28. sglang/srt/managers/data_parallel_controller.py +2 -0
  29. sglang/srt/managers/io_struct.py +0 -2
  30. sglang/srt/managers/scheduler.py +89 -671
  31. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  32. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  33. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  34. sglang/srt/managers/template_manager.py +62 -19
  35. sglang/srt/managers/tokenizer_manager.py +123 -74
  36. sglang/srt/managers/tp_worker.py +4 -0
  37. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  38. sglang/srt/mem_cache/hicache_storage.py +45 -11
  39. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  40. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  41. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  42. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  43. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  44. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  45. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  46. sglang/srt/model_executor/model_runner.py +5 -0
  47. sglang/srt/models/arcee.py +532 -0
  48. sglang/srt/models/deepseek_v2.py +2 -0
  49. sglang/srt/models/glm4_moe.py +3 -1
  50. sglang/srt/models/granitemoe.py +3 -0
  51. sglang/srt/models/grok.py +3 -0
  52. sglang/srt/models/hunyuan.py +1 -0
  53. sglang/srt/models/llama4.py +3 -0
  54. sglang/srt/models/mixtral.py +3 -0
  55. sglang/srt/models/olmoe.py +3 -0
  56. sglang/srt/models/phimoe.py +1 -0
  57. sglang/srt/models/step3_vl.py +994 -0
  58. sglang/srt/multimodal/processors/base_processor.py +15 -16
  59. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  60. sglang/srt/reasoning_parser.py +2 -1
  61. sglang/srt/server_args.py +10 -13
  62. sglang/srt/speculative/eagle_worker.py +2 -0
  63. sglang/utils.py +0 -11
  64. sglang/version.py +1 -1
  65. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/METADATA +3 -4
  66. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/RECORD +69 -56
  67. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  68. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
sglang/srt/models/arcee.py ADDED
@@ -0,0 +1,532 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Arcee Foundational Model (AFM) compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
+
+
+class ArceeMLP(nn.Module):
+    """
+    MLP block for the Arcee model, using a ReLU-squared activation function.
+    This differs from the Llama SwiGLU activation.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        # Arcee uses a single up-projection, not a merged gate/up projection.
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "relu2":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Arcee model in SGLang only supports 'relu2'."
+            )
+        # The activation function is relu(x)^2
+        self.act_fn = get_act_fn("relu2")
+
+    def forward(self, x, forward_batch=None):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ArceeAttention(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(self.partial_rotary_factor * self.head_dim)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class ArceeDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        self.self_attn = ArceeAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            bias=attention_bias,
+        )
+        self.mlp = ArceeMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class ArceeModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: ArceeDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="model.layers",
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
+
+class ArceeForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        # Note: gate_proj is removed compared to Llama
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        # Note: gate_proj and up_proj are removed as they are not stacked in ArceeMLP
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+    }
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
+        # Arcee does not tie word embeddings
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # Parameters that are stacked in a single tensor in this model
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        self.capture_aux_hidden_states = False
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return ArceeModel(config, quant_config=quant_config, prefix=prefix)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                continue
+
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            is_stacked = False
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                is_stacked = True
+                break
+
+            if not is_stacked:
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in model.")
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = [ArceeForCausalLM]
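
Note on the new model file above: ArceeMLP intentionally departs from Llama's SwiGLU block. There is a single up-projection (no merged gate/up pair) and the activation is relu(x) squared. The following stand-alone sketch reproduces just that computation in plain PyTorch, ignoring tensor parallelism and quantization; the class and dimension names are illustrative and not part of the package.

import torch
from torch import nn


class ReluSquaredMLP(nn.Module):
    """Toy equivalent of ArceeMLP: up_proj -> relu(x)**2 -> down_proj."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # Single up-projection, unlike Llama's paired gate/up projections.
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.up_proj(x)
        x = torch.relu(x) ** 2  # the "relu2" activation
        return self.down_proj(x)


# Quick shape check on random input.
mlp = ReluSquaredMLP(hidden_size=64, intermediate_size=256)
print(mlp(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
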
sglang/srt/models/deepseek_v2.py CHANGED
@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
sglang/srt/models/glm4_moe.py CHANGED
@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
         ):
-            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
sglang/srt/models/granitemoe.py CHANGED
@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
+            layer_id=layer_id,
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=f"{prefix}.block_sparse_moe",
         )
sglang/srt/models/grok.py CHANGED
@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
         )
         self.block_sparse_moe = Grok1MoE(
             config=config,
+            layer_id=layer_id,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
sglang/srt/models/hunyuan.py CHANGED
@@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=False,
+            layer_id=layer_id,
             quant_config=quant_config,
         )
 
sglang/srt/models/llama4.py CHANGED
@@ -87,6 +87,7 @@ class Llama4MoE(nn.Module):
     def __init__(
         self,
         config: Llama4TextConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -114,6 +115,7 @@ class Llama4MoE(nn.Module):
             num_experts=config.num_local_experts,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size_moe,
+            layer_id=layer_id,
             reduce_results=False,
             quant_config=quant_config,
             apply_router_weight_on_input=True,
@@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module):
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
+                layer_id=layer_id,
                 quant_config=quant_config,
                 prefix=add_prefix("feed_forward", prefix),
             )
sglang/srt/models/mixtral.py CHANGED
@@ -69,6 +69,7 @@ class MixtralMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -97,6 +98,7 @@ class MixtralMoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("block_sparse_moe", prefix),
         )
sglang/srt/models/olmoe.py CHANGED
@@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        layer_id: int = 0,
         prefix: str = "",
     ):
         super().__init__()
@@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module):
             reduce_results=True,
             quant_config=quant_config,
             tp_size=tp_size,
+            layer_id=layer_id,
             prefix=add_prefix("experts", prefix),
         )
 
@@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
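
The remaining hunks (GraniteMoe, Grok-1, HunYuan, Llama 4, Mixtral, OLMoE) all apply one pattern: a layer_id argument is threaded from each decoder layer through the model-specific MoE block into the shared MoE implementation, presumably to match the updated fused MoE layer constructor in this release (fused_moe_triton/layer.py also changed in this diff). A minimal sketch of the pattern; the stub classes below are illustrative, not sglang's API:

from torch import nn


class FusedMoEStub(nn.Module):
    """Stand-in for a fused MoE layer whose constructor now accepts a layer_id."""

    def __init__(self, num_experts: int, top_k: int, layer_id: int) -> None:
        super().__init__()
        self.layer_id = layer_id  # available for per-layer bookkeeping


class MoEBlockStub(nn.Module):
    """Model-specific MoE wrapper: simply forwards layer_id downward."""

    def __init__(self, num_experts: int, top_k: int, layer_id: int) -> None:
        super().__init__()
        self.experts = FusedMoEStub(num_experts, top_k, layer_id=layer_id)


class DecoderLayerStub(nn.Module):
    def __init__(self, layer_id: int) -> None:
        super().__init__()
        # The decoder layer is where layer_id originates in every hunk above.
        self.block_sparse_moe = MoEBlockStub(num_experts=8, top_k=2, layer_id=layer_id)


layers = [DecoderLayerStub(layer_id=i) for i in range(4)]
print([layer.block_sparse_moe.experts.layer_id for layer in layers])  # [0, 1, 2, 3]
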