sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/srt/configs/model_config.py +1 -0
  2. sglang/srt/conversation.py +1 -0
  3. sglang/srt/custom_op.py +7 -1
  4. sglang/srt/disaggregation/base/conn.py +2 -0
  5. sglang/srt/disaggregation/decode.py +1 -1
  6. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  8. sglang/srt/disaggregation/nixl/conn.py +94 -46
  9. sglang/srt/disaggregation/prefill.py +3 -2
  10. sglang/srt/disaggregation/utils.py +12 -11
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/openai/protocol.py +47 -4
  13. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  14. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  15. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  16. sglang/srt/layers/activation.py +7 -0
  17. sglang/srt/layers/attention/flashattention_backend.py +24 -14
  18. sglang/srt/layers/layernorm.py +15 -0
  19. sglang/srt/layers/linear.py +18 -1
  20. sglang/srt/layers/logits_processor.py +12 -3
  21. sglang/srt/layers/moe/ep_moe/layer.py +79 -12
  22. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  23. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
  25. sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
  26. sglang/srt/layers/moe/topk.py +26 -0
  27. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  28. sglang/srt/layers/rotary_embedding.py +103 -11
  29. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  30. sglang/srt/managers/expert_distribution.py +21 -0
  31. sglang/srt/managers/io_struct.py +10 -2
  32. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  33. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  34. sglang/srt/managers/schedule_batch.py +9 -1
  35. sglang/srt/managers/scheduler.py +42 -6
  36. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  37. sglang/srt/model_executor/model_runner.py +5 -2
  38. sglang/srt/model_loader/loader.py +45 -10
  39. sglang/srt/model_loader/weight_utils.py +89 -0
  40. sglang/srt/models/deepseek_nextn.py +7 -4
  41. sglang/srt/models/deepseek_v2.py +147 -4
  42. sglang/srt/models/gemma3n_audio.py +949 -0
  43. sglang/srt/models/gemma3n_causal.py +1009 -0
  44. sglang/srt/models/gemma3n_mm.py +511 -0
  45. sglang/srt/models/hunyuan.py +771 -0
  46. sglang/srt/server_args.py +16 -2
  47. sglang/srt/two_batch_overlap.py +4 -1
  48. sglang/srt/utils.py +71 -0
  49. sglang/version.py +1 -1
  50. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
  51. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
  52. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  54. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_chat.py

@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
     ErrorResponse,
     FunctionResponse,
     LogProbs,
+    MessageProcessingResult,
     ToolCall,
     TopLogprob,
 )
@@ -62,120 +63,81 @@ class OpenAIServingChat(OpenAIServingBase):
         is_multimodal = self.tokenizer_manager.model_config.is_multimodal

         # Process messages and apply chat template
-        (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        ) = self._process_messages(request, is_multimodal)
+        processed_messages = self._process_messages(request, is_multimodal)

         # Build sampling parameters
         sampling_params = self._build_sampling_params(
-            request, stop, tool_call_constraint
+            request, processed_messages.stop, processed_messages.tool_call_constraint
         )

         # Handle single vs multiple requests
         if is_multimodal:
-            prompt_kwargs = {"text": prompt}
+            prompt_kwargs = {"text": processed_messages.prompt}
         else:
-            if isinstance(prompt_ids, str):
-                prompt_kwargs = {"text": prompt_ids}
+            if isinstance(processed_messages.prompt_ids, str):
+                prompt_kwargs = {"text": processed_messages.prompt_ids}
             else:
-                prompt_kwargs = {"input_ids": prompt_ids}
+                prompt_kwargs = {"input_ids": processed_messages.prompt_ids}

         adapted_request = GenerateReqInput(
             **prompt_kwargs,
-            image_data=image_data,
-            audio_data=audio_data,
+            image_data=processed_messages.image_data,
+            audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
             logprob_start_len=-1,
             top_logprobs_num=request.top_logprobs or 0,
             stream=request.stream,
             return_text_in_logprobs=True,
-            modalities=modalities,
+            modalities=processed_messages.modalities,
             lora_path=request.lora_path,
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
+            rid=request.rid,
         )

         return adapted_request, request

     def _process_messages(
         self, request: ChatCompletionRequest, is_multimodal: bool
-    ) -> tuple[
-        str,
-        Union[str, List[int]],
-        Optional[Any],
-        Optional[Any],
-        List[str],
-        List[str],
-        Optional[Any],
-    ]:
+    ) -> MessageProcessingResult:
         """Process chat messages and apply chat template"""
         tool_call_constraint = None
-        prompt = ""
-        prompt_ids = []

-        if not isinstance(request.messages, str):
-            # Apply chat template and its stop strings
-            tools = None
-            if request.tools and request.tool_choice != "none":
-                request.skip_special_tokens = False
-                if not isinstance(request.tool_choice, str):
-                    tools = [
-                        item.function.model_dump()
-                        for item in request.tools
-                        if item.function.name == request.tool_choice.function.name
-                    ]
-                else:
-                    tools = [item.function.model_dump() for item in request.tools]
+        # Apply chat template and its stop strings
+        tools = None
+        if request.tools and request.tool_choice != "none":
+            request.skip_special_tokens = False
+            if not isinstance(request.tool_choice, str):
+                tools = [
+                    item.function.model_dump()
+                    for item in request.tools
+                    if item.function.name == request.tool_choice.function.name
+                ]
+            else:
+                tools = [item.function.model_dump() for item in request.tools]

-                tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
-                parser = FunctionCallParser(request.tools, tool_call_parser)
-                tool_call_constraint = parser.get_structure_constraint(
-                    request.tool_choice
-                )
+            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+            parser = FunctionCallParser(request.tools, tool_call_parser)
+            tool_call_constraint = parser.get_structure_constraint(request.tool_choice)

-            # Use chat template
-            if self.template_manager.chat_template_name is None:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_jinja_template(request, tools, is_multimodal)
-                )
-            else:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_conversation_template(request, is_multimodal)
-                )
+        # Use chat template
+        if self.template_manager.chat_template_name is None:
+            result = self._apply_jinja_template(request, tools, is_multimodal)
         else:
-            # Use raw prompt
-            prompt_ids = request.messages
-            stop = request.stop or []
-            image_data = None
-            audio_data = None
-            modalities = []
-            prompt = request.messages
-
-        return (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        )
+            result = self._apply_conversation_template(request, is_multimodal)
+
+        result.tool_call_constraint = tool_call_constraint
+        return result

     def _apply_jinja_template(
         self,
         request: ChatCompletionRequest,
         tools: Optional[List[Dict]],
         is_multimodal: bool,
-    ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
+    ) -> MessageProcessingResult:
         """Apply Jinja chat template"""
         prompt = ""
         prompt_ids = []
@@ -253,13 +215,20 @@ class OpenAIServingChat(OpenAIServingBase):
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
         modalities = modalities if modalities else []
-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

     def _apply_conversation_template(
         self,
         request: ChatCompletionRequest,
         is_multimodal: bool,
-    ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
+    ) -> MessageProcessingResult:
         """Apply conversation template"""
         prompt = ""
         prompt_ids = []
@@ -304,7 +273,14 @@ class OpenAIServingChat(OpenAIServingBase):
         if not is_multimodal:
             prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)

-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

     def _build_sampling_params(
         self,
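The MessageProcessingResult container itself is defined in sglang/srt/entrypoints/openai/protocol.py (+47 -4), whose hunk is not shown in this excerpt. For orientation only, here is a minimal sketch of a container with the fields the serving code above reads and writes; the field names come from this diff, but the actual definition in protocol.py may be a pydantic model with different defaults.

from dataclasses import dataclass, field
from typing import Any, List, Optional, Union


@dataclass
class MessageProcessingResult:
    # Sketch only: mirrors the attributes used by serving_chat.py above;
    # the real class shipped in sglang.srt.entrypoints.openai.protocol may differ.
    prompt: str = ""
    prompt_ids: Union[str, List[int]] = field(default_factory=list)
    image_data: Optional[Any] = None
    audio_data: Optional[Any] = None
    modalities: List[str] = field(default_factory=list)
    stop: List[str] = field(default_factory=list)
    tool_call_constraint: Optional[Any] = None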
sglang/srt/entrypoints/openai/serving_completions.py

@@ -87,6 +87,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
+            rid=request.rid,
         )

         return adapted_request, request
sglang/srt/entrypoints/openai/serving_embedding.py

@@ -119,6 +119,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):

         adapted_request = EmbeddingReqInput(
             **prompt_kwargs,
+            rid=request.rid,
         )

         return adapted_request, request
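All three OpenAI-compatible endpoints now forward request.rid into the internal request objects, so a caller-supplied request ID survives into the scheduler. A hedged usage sketch follows; only the "rid" field name comes from this diff, while the endpoint URL, port, and the exact acceptance rules for the value are assumptions.

import requests

# Hypothetical call against a locally running sglang server; "rid" is the
# client-chosen request ID that the hunks above now plumb through.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "rid": "my-trace-id-0001",  # assumed to be a plain string
    },
)
print(resp.json())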
sglang/srt/layers/activation.py

@@ -48,6 +48,9 @@ if _is_cuda:

 logger = logging.getLogger(__name__)

+if is_npu():
+    import torch_npu
+

 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -70,6 +73,10 @@ class SiluAndMul(CustomOp):
         else:
             return self.forward_native(x)

+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch_npu.npu_swiglu(x)
+        return out
+

 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
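The new forward_npu routes SiluAndMul through the fused torch_npu.npu_swiglu kernel on Ascend devices. For reference, the unfused computation that forward_native performs, and that the NPU path is expected to match, looks like the sketch below (restated here outside the class).

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: the first half is the gate, the second
    # half is the up-projection; apply SiLU to the gate and multiply elementwise.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


# e.g. x of shape [num_tokens, 2 * intermediate_size]
out = silu_and_mul_reference(torch.randn(4, 16))
assert out.shape == (4, 8)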
sglang/srt/layers/attention/flashattention_backend.py

@@ -657,12 +657,16 @@ class FlashAttentionBackend(AttentionBackend):
         )
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
+            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         causal = not layer.is_cross_attention

         # Check if we should use local attention
@@ -776,8 +780,8 @@

                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -790,8 +794,8 @@
                 # MHA for extend part of sequence without attending prefix kv cache
                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=metadata.cu_seqlens_q,
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -803,7 +807,9 @@
                 return output, lse
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                layer.layer_id
+            ).to(q.dtype)
             k_rope = kv_cache[:, :, layer.v_head_dim :]
             c_kv = kv_cache[:, :, : layer.v_head_dim]
             k_rope_cache = k_rope.view(
@@ -933,14 +939,16 @@

         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-        # has corresponding quantization method so that layer.k_scale is not None
-        if self.kv_cache_dtype_str != "auto":
+        # has corresponding quantization method so that layer.k_scale is not None,
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)
                 v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
-
+            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         if not self.use_mla:
             # Do multi-head attention

@@ -1048,7 +1056,9 @@
                 o = result
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
             k_rope = kv_cache[:, :, layer.v_head_dim :]
             c_kv = kv_cache[:, :, : layer.v_head_dim]
             k_rope_cache = k_rope.view(
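Both the extend and decode paths above now gate FP8 KV descaling on layer.head_dim <= 256 and cast q_rope/k_rope to the KV-cache dtype. As a small illustration of the descale expansion that stays inside that gate, layer.k_scale is a one-element per-layer scale that gets broadcast to one value per (batch, kv head); shapes below are hypothetical.

import torch

# Toy shapes; in the backend these come from forward_batch and the layer.
batch_size, tp_k_head_num = 2, 8
k_scale = torch.tensor([0.023])             # per-layer scale (assumed 1-element here)
descale_shape = (batch_size, tp_k_head_num)
k_descale = k_scale.expand(descale_shape)   # broadcast view, no copy
assert k_descale.shape == (2, 8)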
sglang/srt/layers/layernorm.py

@@ -52,6 +52,9 @@ elif _is_hip:

 logger = logging.getLogger(__name__)

+if is_npu():
+    import torch_npu
+

 class RMSNorm(CustomOp):
     def __init__(
@@ -76,6 +79,18 @@ class RMSNorm(CustomOp):
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            return out, residual_out
+        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+
     def forward_aiter(
         self,
         x: torch.Tensor,
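The new forward_npu maps RMSNorm onto torch_npu.npu_add_rms_norm (fused residual-add plus norm) and torch_npu.npu_rms_norm. For reference, the semantics the NPU calls are expected to reproduce are roughly the following sketch of a plain RMSNorm with an optional fused residual add.

import torch
from typing import Optional, Tuple, Union


def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Optionally add the residual first; the pre-norm sum becomes the new residual.
    if residual is not None:
        x = x + residual
        residual = x
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if residual is None else (out, residual)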
sglang/srt/layers/linear.py

@@ -30,7 +30,12 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)

 logger = logging.getLogger(__name__)

@@ -52,6 +57,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "IPEXAWQLinearMethod",
 ]

+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+

 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,6 +173,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["weight"])
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -172,6 +184,11 @@
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True  # is_vnni
+            )
+
         return F.linear(x, layer.weight, bias)

sglang/srt/layers/logits_processor.py

@@ -442,11 +442,20 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
-                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
-            )
+            if getattr(lm_head, "use_intel_amx_backend", False):
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None,  # bias
+                    True,  # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
         else:
             # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
             logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)

         if self.logit_scale is not None:
sglang/srt/layers/moe/ep_moe/layer.py

@@ -54,10 +54,16 @@ from sglang.srt.utils import (

 _is_hip = is_hip()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant

+if _use_aiter:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)

@@ -1046,6 +1052,15 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 w2_weight_scale, requires_grad=False
             )
            layer.w2_input_scale = None
+            if _use_aiter:
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
             return

     def apply(
@@ -1117,18 +1132,36 @@ class DeepEPMoE(EPMoE):
         assert (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
         ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
+        if _use_aiter:
+            # expert_mask is of size (self.num_experts_per_partition + 1),
+            # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
+            # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
+            # self.expert_mask = [1, 1, 1, 1, 0]
+            # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
+            self.expert_mask = torch.zeros(
+                (self.num_experts_per_partition + 1),
+                device=torch.cuda.current_device(),
+                dtype=torch.int,
+            )
+            # the last one is invalid rank_id
+            self.expert_mask[:-1] = 1
+        else:
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                (
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                (
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+            )

     def forward(
         self,
@@ -1142,6 +1175,9 @@ class DeepEPMoE(EPMoE):
         num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
+        if _use_aiter:
+            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
+            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
@@ -1274,6 +1310,37 @@
             )
         return down_output

+    def forward_aiter(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states
+        # in original deepep, idx == -1 meaning invalid and will not be processed.
+        # aiter does not accept -1, we use a expert mask to make these idx invalid
+        # (idx == num_experts_per_partition) meaning not used in aiter fused_moe
+        topk_idx_copy = topk_idx.to(torch.int32)
+        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+
+        return fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            topk_weights,
+            topk_idx_copy,
+            w1_scale=self.w13_weight_scale_inv,
+            w2_scale=self.w2_weight_scale_inv,
+            quant_type=QuantType.per_128x128,
+            activation=(
+                ActivationType.Silu
+                if self.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            expert_mask=self.expert_mask,
+        )
+
     def forward_deepgemm_contiguous(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
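The comments in the hunk above describe the -1 remapping in prose; here is a toy rendering of the same bookkeeping with hypothetical values (4 experts on the local rank), mirroring what forward_aiter and the expert_mask set up in __init__ do.

import torch

num_experts_per_partition = 4

# One extra slot at the end acts as the "invalid" expert that aiter masks out.
expert_mask = torch.zeros(num_experts_per_partition + 1, dtype=torch.int)
expert_mask[:-1] = 1                      # -> tensor([1, 1, 1, 1, 0])

# DeepEP marks tokens routed to other ranks with -1; aiter cannot take -1,
# so those slots are redirected to the masked-out index.
topk_idx = torch.tensor([[0, 2], [-1, 3]])
topk_idx_copy = topk_idx.to(torch.int32)
topk_idx_copy[topk_idx_copy == -1] = num_experts_per_partition
print(topk_idx_copy)                      # -> [[0, 2], [4, 3]]; index 4 hits the masked slot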
sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -6,7 +6,13 @@ from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import DeepEPMode, get_int_env_var, load_json_config
+from sglang.srt.utils import (
+    DeepEPMode,
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    load_json_config,
+)

 try:
     from deep_ep import Buffer, Config
@@ -32,6 +38,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode

+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+
 logger = logging.getLogger(__name__)

@@ -376,6 +384,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
         https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
         """
+        if _use_aiter:
+            # skip permutation here as aiter fused_moe has fused inside
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states

         reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
             topk_idx, self.num_experts
409
426
  topk_idx: torch.Tensor,
410
427
  topk_weights: torch.Tensor,
411
428
  ):
412
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
429
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
413
430
  output = hidden_states
414
431
  else:
415
432
  if hidden_states.shape[0] > 0:
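Both ep_moe/layer.py and token_dispatcher.py gate the aiter path on the same check: SGLANG_USE_AITER must be enabled and the build must be ROCm/HIP. A minimal sketch of that gating follows, assuming get_bool_env_var accepts the usual truthy strings; the exact set it accepts is not shown in this diff.

import os


def use_aiter_moe(is_hip_device: bool) -> bool:
    # Mirrors the module-level `_use_aiter` flag: the aiter fused-MoE path is
    # only taken on HIP devices with SGLANG_USE_AITER explicitly enabled.
    value = os.environ.get("SGLANG_USE_AITER", "false").strip().lower()
    return value in ("1", "true", "yes") and is_hip_device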
sglang/srt/layers/moe/fused_moe_native.py

@@ -77,8 +77,15 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    inplace: bool = True,
+    no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError()
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -750,9 +750,11 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
-        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
+    sorted_ids.fill_(topk_ids.numel())
+
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -768,6 +770,9 @@
             num_tokens_post_pad,
         )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,