sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/arcee.py ADDED
@@ -0,0 +1,532 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Arcee Foundational Model (AFM) compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
+
+
+class ArceeMLP(nn.Module):
+    """
+    MLP block for the Arcee model, using a ReLU-squared activation function.
+    This differs from the Llama SwiGLU activation.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        # Arcee uses a single up-projection, not a merged gate/up projection.
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "relu2":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Arcee model in SGLang only supports 'relu2'."
+            )
+        # The activation function is relu(x)^2
+        self.act_fn = get_act_fn("relu2")
+
+    def forward(self, x, forward_batch=None):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ArceeAttention(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(self.partial_rotary_factor * self.head_dim)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class ArceeDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        self.self_attn = ArceeAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            bias=attention_bias,
+        )
+        self.mlp = ArceeMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class ArceeModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: ArceeDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="model.layers",
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
+
+class ArceeForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        # Note: gate_proj is removed compared to Llama
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        # Note: gate_proj and up_proj are removed as they are not stacked in ArceeMLP
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+    }
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
+        # Arcee does not tie word embeddings
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # Parameters that are stacked in a single tensor in this model
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        self.capture_aux_hidden_states = False
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return ArceeModel(config, quant_config=quant_config, prefix=prefix)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                continue
+
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            is_stacked = False
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                is_stacked = True
+                break
+
+            if not is_stacked:
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in model.")
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = [ArceeForCausalLM]
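Note: since this release adds the Arcee (AFM) architecture above, here is a minimal, hedged sketch of exercising it through SGLang's offline engine. The checkpoint id is a placeholder assumption, not something this diff specifies; any HuggingFace repo whose config maps to ArceeForCausalLM should be routed to the new sglang/srt/models/arcee.py.

# Minimal sketch (not part of the diff): loading an Arcee/AFM checkpoint with the
# SGLang offline engine. The model id below is a hypothetical placeholder;
# substitute a real ArceeForCausalLM checkpoint.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="arcee-ai/AFM-4.5B")  # hypothetical checkpoint id
    outputs = llm.generate(
        ["The capital of France is"],
        {"temperature": 0.7, "max_new_tokens": 32},
    )
    print(outputs[0]["text"])
    llm.shutdown()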
sglang/srt/models/deepseek_v2.py CHANGED
@@ -59,7 +59,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import (
     DeepEPMoE,
     get_moe_impl_class,
-    use_flashinfer_trtllm_moe,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
@@ -252,8 +252,7 @@ class MoEGate(nn.Module):
         # NOTE: For some unknown reason, router_gemm seems degrade accept length.
         if (
             _is_cuda
-            and not self.is_nextn
-            and hidden_states.shape[0] < 4
+            and hidden_states.shape[0] <= 16
             and hidden_states.shape[1] == 7168
             and self.weight.shape[0] == 256
             and _device_sm >= 90
@@ -317,6 +316,7 @@ class DeepseekV2MoE(nn.Module):
                 correction_bias=self.gate.e_score_correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
             )
-            if not use_flashinfer_trtllm_moe
+            if not should_use_flashinfer_trtllm_moe()
             else None
         )

@@ -325,6 +324,7 @@
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -351,11 +351,10 @@
                     renormalize=config.norm_topk_prob,
                     use_grouped_topk=True,
                     num_expert_group=config.n_group,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
                     topk_group=config.topk_group,
                     correction_bias=self.gate.e_score_correction_bias,
                 )
-                if use_flashinfer_trtllm_moe
+                if should_use_flashinfer_trtllm_moe()
                 else {}
             ),
         )
@@ -1258,6 +1257,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.current_attention_backend == "fa3"
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
+            or self.current_attention_backend == "trtllm_mla"
         ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -2112,6 +2112,7 @@ class DeepseekV2ForCausalLM(nn.Module):

         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
sglang/srt/models/glm4_moe.py CHANGED
@@ -52,7 +52,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import (
     DeepEPMoE,
     get_moe_impl_class,
-    use_flashinfer_trtllm_moe,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -426,7 +426,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 correction_bias=self.gate.e_score_correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
             )
-            if not use_flashinfer_trtllm_moe
+            if not should_use_flashinfer_trtllm_moe()
             else None
         )

@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
+            num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
@@ -464,7 +465,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                     topk_group=config.topk_group,
                     correction_bias=self.gate.e_score_correction_bias,
                 )
-                if use_flashinfer_trtllm_moe
+                if should_use_flashinfer_trtllm_moe()
                 else {}
             ),
         )
@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             global_server_args_dict["enable_deepep_moe"]
             or global_server_args_dict["enable_ep_moe"]
         ):
-            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."

         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
+            self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
                 f"{disable_reason} Shared experts fusion optimization is disabled.",
sglang/srt/models/granitemoe.py CHANGED
@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
+            layer_id=layer_id,
             params_dtype=params_dtype,
             reduce_results=True,
             quant_config=quant_config,
@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            layer_id=layer_id,
             quant_config=quant_config,
             prefix=f"{prefix}.block_sparse_moe",
         )
sglang/srt/models/grok.py CHANGED
@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
     def __init__(
         self,
        config: PretrainedConfig,
+        layer_id: int,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
         self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
+            layer_id=layer_id,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
         )
         self.block_sparse_moe = Grok1MoE(
             config=config,
+            layer_id=layer_id,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
sglang/srt/models/hunyuan.py CHANGED
@@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=False,
+            layer_id=layer_id,
             quant_config=quant_config,
         )