sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +3 -2
- sglang/srt/disaggregation/utils.py +12 -11
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/openai/protocol.py +47 -4
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/attention/flashattention_backend.py +24 -14
- sglang/srt/layers/layernorm.py +15 -0
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +12 -3
- sglang/srt/layers/moe/ep_moe/layer.py +79 -12
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
- sglang/srt/layers/moe/topk.py +26 -0
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/rotary_embedding.py +103 -11
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +10 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +9 -1
- sglang/srt/managers/scheduler.py +42 -6
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -2
- sglang/srt/model_loader/loader.py +45 -10
- sglang/srt/model_loader/weight_utils.py +89 -0
- sglang/srt/models/deepseek_nextn.py +7 -4
- sglang/srt/models/deepseek_v2.py +147 -4
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/server_args.py +16 -2
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +71 -0
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_chat.py
CHANGED

@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
     ErrorResponse,
     FunctionResponse,
     LogProbs,
+    MessageProcessingResult,
     ToolCall,
     TopLogprob,
 )
@@ -62,120 +63,81 @@ class OpenAIServingChat(OpenAIServingBase):
         is_multimodal = self.tokenizer_manager.model_config.is_multimodal
 
         # Process messages and apply chat template
-        (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        ) = self._process_messages(request, is_multimodal)
+        processed_messages = self._process_messages(request, is_multimodal)
 
         # Build sampling parameters
         sampling_params = self._build_sampling_params(
-            request, stop, tool_call_constraint
+            request, processed_messages.stop, processed_messages.tool_call_constraint
         )
 
         # Handle single vs multiple requests
         if is_multimodal:
-            prompt_kwargs = {"text": prompt}
+            prompt_kwargs = {"text": processed_messages.prompt}
         else:
-            if isinstance(prompt_ids, str):
-                prompt_kwargs = {"text": prompt_ids}
+            if isinstance(processed_messages.prompt_ids, str):
+                prompt_kwargs = {"text": processed_messages.prompt_ids}
             else:
-                prompt_kwargs = {"input_ids": prompt_ids}
+                prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
 
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
-            image_data=image_data,
-            audio_data=audio_data,
+            image_data=processed_messages.image_data,
+            audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
             logprob_start_len=-1,
             top_logprobs_num=request.top_logprobs or 0,
             stream=request.stream,
             return_text_in_logprobs=True,
-            modalities=modalities,
+            modalities=processed_messages.modalities,
             lora_path=request.lora_path,
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
             return_hidden_states=request.return_hidden_states,
+            rid=request.rid,
         )
 
         return adapted_request, request
 
     def _process_messages(
         self, request: ChatCompletionRequest, is_multimodal: bool
-    ) -> tuple[
-        str,
-        Union[str, List[int]],
-        Optional[Any],
-        Optional[Any],
-        List[str],
-        List[str],
-        Optional[Any],
-    ]:
+    ) -> MessageProcessingResult:
         """Process chat messages and apply chat template"""
         tool_call_constraint = None
-        prompt = ""
-        prompt_ids = []
 
-        if not isinstance(request.messages, str):
-            # Apply chat template and its stop strings
-            tools = None
-            if request.tools and request.tool_choice != "none":
-                request.skip_special_tokens = False
-                if not isinstance(request.tool_choice, str):
-                    tools = [
-                        item.function.model_dump()
-                        for item in request.tools
-                        if item.function.name == request.tool_choice.function.name
-                    ]
-                else:
-                    tools = [item.function.model_dump() for item in request.tools]
+        # Apply chat template and its stop strings
+        tools = None
+        if request.tools and request.tool_choice != "none":
+            request.skip_special_tokens = False
+            if not isinstance(request.tool_choice, str):
+                tools = [
+                    item.function.model_dump()
+                    for item in request.tools
+                    if item.function.name == request.tool_choice.function.name
+                ]
+            else:
+                tools = [item.function.model_dump() for item in request.tools]
 
-                tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
-                parser = FunctionCallParser(request.tools, tool_call_parser)
-                tool_call_constraint = parser.get_structure_constraint(
-                    request.tool_choice
-                )
+            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+            parser = FunctionCallParser(request.tools, tool_call_parser)
+            tool_call_constraint = parser.get_structure_constraint(request.tool_choice)
 
-            if self.template_manager.chat_template_name is None:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_jinja_template(request, tools, is_multimodal)
-                )
-            else:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_conversation_template(request, is_multimodal)
-                )
+        # Use chat template
+        if self.template_manager.chat_template_name is None:
+            result = self._apply_jinja_template(request, tools, is_multimodal)
         else:
-            audio_data = None
-            modalities = []
-            prompt = request.messages
-
-        return (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        )
+            result = self._apply_conversation_template(request, is_multimodal)
+
+        result.tool_call_constraint = tool_call_constraint
+        return result
 
     def _apply_jinja_template(
         self,
         request: ChatCompletionRequest,
         tools: Optional[List[Dict]],
         is_multimodal: bool,
-    ) ->
+    ) -> MessageProcessingResult:
         """Apply Jinja chat template"""
         prompt = ""
         prompt_ids = []
@@ -253,13 +215,20 @@ class OpenAIServingChat(OpenAIServingBase):
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
         modalities = modalities if modalities else []
-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )
 
     def _apply_conversation_template(
         self,
         request: ChatCompletionRequest,
         is_multimodal: bool,
-    ) ->
+    ) -> MessageProcessingResult:
         """Apply conversation template"""
         prompt = ""
         prompt_ids = []
@@ -304,7 +273,14 @@ class OpenAIServingChat(OpenAIServingBase):
         if not is_multimodal:
             prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
 
-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )
 
     def _build_sampling_params(
         self,
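
The refactor above replaces the seven-element tuple previously returned by _process_messages with a MessageProcessingResult object imported from sglang/srt/entrypoints/openai/protocol.py. Its definition is not part of this section; a minimal sketch consistent with the fields used above (whether the real class is a dataclass or a pydantic model is not visible here) would be:

from dataclasses import dataclass, field
from typing import Any, List, Optional, Union

@dataclass
class MessageProcessingResult:
    # Hypothetical sketch; the real definition lives in sglang/srt/entrypoints/openai/protocol.py.
    prompt: str = ""
    prompt_ids: Union[str, List[int]] = field(default_factory=list)
    image_data: Optional[Any] = None
    audio_data: Optional[Any] = None
    modalities: List[str] = field(default_factory=list)
    stop: List[str] = field(default_factory=list)
    tool_call_constraint: Optional[Any] = None
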
sglang/srt/layers/activation.py
CHANGED

@@ -48,6 +48,9 @@ if _is_cuda:
 
 logger = logging.getLogger(__name__)
 
+if is_npu():
+    import torch_npu
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -70,6 +73,10 @@ class SiluAndMul(CustomOp):
         else:
             return self.forward_native(x)
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch_npu.npu_swiglu(x)
+        return out
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
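
The new forward_npu above delegates the fused SwiGLU to torch_npu.npu_swiglu. For orientation, a plain-PyTorch sketch of the operation SiluAndMul computes, assuming the usual split of the last dimension into gate and up halves as in forward_native:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into [gate, up] halves and return silu(gate) * up.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(2, 8)            # toy input; last dim must be even
out = silu_and_mul_reference(x)  # shape (2, 4)
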
sglang/srt/layers/attention/flashattention_backend.py
CHANGED

@@ -657,12 +657,16 @@ class FlashAttentionBackend(AttentionBackend):
             )
             k_descale, v_descale = None, None
             # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-            # has corresponding quantization method so that layer.k_scale is not None
-            if layer.k_scale is not None:
-                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-                k_descale = layer.k_scale.expand(descale_shape)
-                v_descale = layer.v_scale.expand(descale_shape)
+            # has corresponding quantization method so that layer.k_scale is not None,
+            # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+            if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+                if layer.k_scale is not None:
+                    descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                    k_descale = layer.k_scale.expand(descale_shape)
+                    v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
+            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
             causal = not layer.is_cross_attention
 
             # Check if we should use local attention
@@ -776,8 +780,8 @@ class FlashAttentionBackend(AttentionBackend):
 
                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -790,8 +794,8 @@ class FlashAttentionBackend(AttentionBackend):
                 # MHA for extend part of sequence without attending prefix kv cache
                 output, lse, *rest = flash_attn_varlen_func(
                     q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
                     cu_seqlens_q=metadata.cu_seqlens_q,
                     cu_seqlens_k=metadata.cu_seqlens_q,
                     max_seqlen_q=metadata.max_seq_len_q,
@@ -803,7 +807,9 @@ class FlashAttentionBackend(AttentionBackend):
                 return output, lse
             else:
                 # Do absorbed multi-latent attention
-                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+                kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
+                    layer.layer_id
+                ).to(q.dtype)
                 k_rope = kv_cache[:, :, layer.v_head_dim :]
                 c_kv = kv_cache[:, :, : layer.v_head_dim]
                 k_rope_cache = k_rope.view(
@@ -933,14 +939,16 @@ class FlashAttentionBackend(AttentionBackend):
 
             k_descale, v_descale = None, None
             # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
-            # has corresponding quantization method so that layer.k_scale is not None
-            if self.kv_cache_dtype_str != "auto":
+            # has corresponding quantization method so that layer.k_scale is not None,
+            # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+            if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
                 if layer.k_scale is not None:
                     descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                     k_descale = layer.k_scale.expand(descale_shape)
                     v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)
-
+            q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
+            k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         if not self.use_mla:
             # Do multi-head attention
 
@@ -1048,7 +1056,9 @@ class FlashAttentionBackend(AttentionBackend):
             o = result
         else:
             # Do absorbed multi-latent attention
-            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
+                q.dtype
+            )
            k_rope = kv_cache[:, :, layer.v_head_dim :]
            c_kv = kv_cache[:, :, : layer.v_head_dim]
            k_rope_cache = k_rope.view(
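
The added .to(q.dtype) and .to(self.kv_cache_dtype) casts follow a simple pattern: tensors materialized from the KV cache can carry a different dtype than q, and the FA3 varlen kernel expects matching fp16/bf16 inputs. A toy illustration (shapes and dtypes below are made up):

import torch

q = torch.randn(16, 8, 128, dtype=torch.bfloat16)  # query in the model's compute dtype
k = torch.randn(16, 8, 128, dtype=torch.float16)   # e.g. tensors coming out of a cache buffer
v = torch.randn(16, 8, 128, dtype=torch.float16)

k, v = k.to(q.dtype), v.to(q.dtype)                 # same idea as the added .to(q.dtype) calls
assert q.dtype == k.dtype == v.dtype
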
sglang/srt/layers/layernorm.py
CHANGED

@@ -52,6 +52,9 @@ elif _is_hip:
 
 logger = logging.getLogger(__name__)
 
+if is_npu():
+    import torch_npu
+
 
 class RMSNorm(CustomOp):
     def __init__(
@@ -76,6 +79,18 @@ class RMSNorm(CustomOp):
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            return out, residual_out
+        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+
     def forward_aiter(
         self,
         x: torch.Tensor,
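
Alongside the new forward_npu, a simplified pure-PyTorch sketch of the RMSNorm and fused add + RMSNorm computations that torch_npu.npu_rms_norm / npu_add_rms_norm provide (the real forward_native may additionally upcast to float32 before normalizing):

import torch
from typing import Optional, Tuple, Union

def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Optionally fuse the residual add, then scale by the inverse RMS of x.
    if residual is not None:
        x = x + residual
        residual_out = x
    out = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
    return out if residual is None else (out, residual_out)
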
sglang/srt/layers/linear.py
CHANGED

@@ -30,7 +30,12 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -52,6 +57,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "IPEXAWQLinearMethod",
 ]
 
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+
 
 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,6 +173,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["weight"])
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -172,6 +184,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True  # is_vnni
+            )
+
         return F.linear(x, layer.weight, bias)
 
 
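
Reference semantics for the two branches of UnquantizedLinearMethod.apply above: both are expected to produce the same affine map y = x @ W^T + b, with the AMX branch assuming the weight was repacked into a VNNI layout by _process_weight_after_loading. Illustrative sketch using the plain layout only:

import torch
import torch.nn.functional as F

x = torch.randn(4, 64)
weight = torch.randn(128, 64)  # (out_features, in_features), plain (non-VNNI) layout
bias = torch.randn(128)

y = F.linear(x, weight, bias)  # the fallback branch; the AMX kernel targets the same result
assert y.shape == (4, 128)
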
sglang/srt/layers/logits_processor.py
CHANGED

@@ -442,11 +442,20 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
-                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
-            )
+            if getattr(lm_head, "use_intel_amx_backend", False):
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None,  # bias
+                    True,  # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
         else:
             # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
             logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
 
         if self.logit_scale is not None:
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED

@@ -54,10 +54,16 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
+if _use_aiter:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
 
 
@@ -1046,6 +1052,15 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 w2_weight_scale, requires_grad=False
             )
             layer.w2_input_scale = None
+            if _use_aiter:
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
             return
 
     def apply(
@@ -1117,18 +1132,36 @@ class DeepEPMoE(EPMoE):
         assert (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
         ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+        if _use_aiter:
+            # expert_mask is of size (self.num_experts_per_partition + 1),
+            # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
+            # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
+            # self.expert_mask = [1, 1, 1, 1, 0]
+            # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
+            self.expert_mask = torch.zeros(
+                (self.num_experts_per_partition + 1),
+                device=torch.cuda.current_device(),
+                dtype=torch.int,
+            )
+            # the last one is invalid rank_id
+            self.expert_mask[:-1] = 1
+        else:
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                (
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                (
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+            )
 
     def forward(
         self,
@@ -1142,6 +1175,9 @@ class DeepEPMoE(EPMoE):
         num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
+        if _use_aiter:
+            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
+            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
@@ -1274,6 +1310,37 @@ class DeepEPMoE(EPMoE):
         )
         return down_output
 
+    def forward_aiter(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states
+        # in original deepep, idx == -1 meaning invalid and will not be processed.
+        # aiter does not accept -1, we use a expert mask to make these idx invalid
+        # (idx == num_experts_per_partition) meaning not used in aiter fused_moe
+        topk_idx_copy = topk_idx.to(torch.int32)
+        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+
+        return fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            topk_weights,
+            topk_idx_copy,
+            w1_scale=self.w13_weight_scale_inv,
+            w2_scale=self.w2_weight_scale_inv,
+            quant_type=QuantType.per_128x128,
+            activation=(
+                ActivationType.Silu
+                if self.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            expert_mask=self.expert_mask,
+        )
+
     def forward_deepgemm_contiguous(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
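
A toy illustration of the expert-mask remapping described in the comments above: deepep marks experts that are not on this rank with -1, which aiter's fused_moe does not accept, so those ids are remapped to one extra slot that the mask zeroes out (values below are made up):

import torch

num_experts_per_partition = 4

# -1 means "expert not on this rank" in deepep's convention.
topk_idx = torch.tensor([[0, 2, -1], [3, -1, 1]])

topk_idx_copy = topk_idx.to(torch.int32)
topk_idx_copy[topk_idx_copy == -1] = num_experts_per_partition  # remap -1 -> 4

expert_mask = torch.zeros(num_experts_per_partition + 1, dtype=torch.int)
expert_mask[:-1] = 1  # [1, 1, 1, 1, 0]; slot 4 (the remapped -1) is masked out
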
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
CHANGED

@@ -6,7 +6,13 @@ from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    DeepEPMode,
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    load_json_config,
+)
 
 try:
     from deep_ep import Buffer, Config
@@ -32,6 +38,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+
 logger = logging.getLogger(__name__)
 
 
@@ -376,6 +384,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
         https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
         """
+        if _use_aiter:
+            # skip permutation here as aiter fused_moe has fused inside
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
 
         reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
             topk_idx, self.num_experts
@@ -409,7 +426,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
sglang/srt/layers/moe/fused_moe_native.py
CHANGED

@@ -77,8 +77,15 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    inplace: bool = True,
+    no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError()
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED

@@ -750,9 +750,11 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
+    sorted_ids.fill_(topk_ids.numel())
+
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -768,6 +770,9 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,