sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional

 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
-from triton_kernels.matmul_ogs import
+from triton_kernels.matmul_ogs import (
+    FlexCtx,
+    FnSpecs,
+    FusedActivation,
+    PrecisionConfig,
+    matmul_ogs,
+)
+from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
-
-from sglang.srt.utils import direct_register_custom_op
+from triton_kernels.swiglu import swiglu_fn

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput


+def quantize(w, dtype, dev, **opt):
+    if dtype == "bf16":
+        return w.to(torch.bfloat16), InFlexData()
+    elif dtype == "fp8":
+        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
+        return (
+            wq,
+            InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
+            MicroscalingCtx(),
+        )
+    else:
+        assert dtype == "mx4", f"{dtype=}"
+        swizzle_mx_scale = opt["swizzle_mx_scale"]
+        swizzle_axis = 2 if swizzle_mx_scale else None
+        w = w.to(torch.bfloat16)
+        w, mx_scales, weight_scale_shape = downcast_to_mxfp(
+            w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
+        )
+        return (
+            w,
+            InFlexData(),
+            MicroscalingCtx(
+                weight_scale=mx_scales,
+                swizzle_mx=swizzle_mx_scale,
+                actual_weight_scale_shape=weight_scale_shape,
+            ),
+        )
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
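The `quantize()` helper added above pairs a weight tensor with the `triton_kernels` flex-precision metadata. Only the `"bf16"` branch is exercised by the bias-aware kernel later in this file; the `"fp8"` and `"mx4"` branches reference `MicroscalingCtx` and `downcast_to_mxfp`, which are not among the imports added in this hunk. The sketch below shows the bf16 call pattern as it is used further down; it is illustrative only, not part of the package.

```python
# Minimal sketch (illustrative, not from the diff) of the bf16 path of quantize():
# cast to bfloat16, wrap the returned InFlexData in a FlexCtx, and hand the result
# to matmul_ogs through a PrecisionConfig.
import torch
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import quantize

w1 = torch.randn(8, 512, 1024)                      # hypothetical [num_experts, K, N] weight
w1, w1_flex = quantize(w1, "bf16", "cuda")          # bf16 branch returns (tensor, InFlexData)
w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
# w1 and w1_pcg are then passed to matmul_ogs(..., precision_config=w1_pcg).
```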
@@ -148,16 +183,17 @@ def triton_kernel_fused_experts(
     return intermediate_cache3


-def
+def triton_kernel_moe_with_bias_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
+    w1_pcg,
+    b1: torch.Tensor,
     w2: torch.Tensor,
-
-
-
+    w2_pcg,
+    b2: torch.Tensor,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
@@ -167,13 +203,131 @@ def triton_kernel_moe_forward_fake(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
 ) -> torch.Tensor:
-
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output
+
+    return triton_kernel_fused_experts_with_bias(
+        hidden_states,
+        w1=w1,
+        w1_pcg=w1_pcg,
+        b1=b1,
+        w2=w2,
+        w2_pcg=w2_pcg,
+        b2=b2,
+        routing_data=routing_data,
+        gather_indx=gather_idx,
+        scatter_indx=scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
+    )


-
-
-
-
-
-
+def triton_kernel_fused_experts_with_bias(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_pcg,
+    b1: torch.Tensor,
+    w2: torch.Tensor,
+    w2_pcg,
+    b2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+    # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    for w in (w1, w2):
+        # TODO assert bf16 or mxfp4
+        # assert (w.dtype == torch.bfloat16) or check-is-mxfp4, f"w must be bfloat16 or mxfp4 {w1.dtype=}"
+        pass
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    E, _, _ = w1.shape
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    # TODO maybe completely remove this branch
+    if w1.dtype == torch.bfloat16:
+        device = "cuda"
+        optg = dict()
+        w1, w1_flex = quantize(w1, "bf16", device, **optg)
+        w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
+
+        w2, w2_flex = quantize(w2, "bf16", device, **optg)
+        w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
+
+    act = FusedActivation(
+        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+        (activation_alpha, swiglu_limit),
+        2,
+    )
+
+    intermediate_cache = matmul_ogs(
+        hidden_states,
+        w1,
+        b1,
+        routing_data,
+        gather_indx=gather_indx,
+        precision_config=w1_pcg,
+        gammas=None,
+        fused_activation=act,
+    )
+
+    return matmul_ogs(
+        intermediate_cache,
+        w2,
+        b2,
+        routing_data,
+        scatter_indx=scatter_indx,
+        precision_config=w2_pcg,
+        gammas=routing_data.gate_scal,
+    )
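The new `triton_kernel_fused_experts_with_bias` above is a two-GEMM MoE: the first `matmul_ogs` gathers each token into its expert's `w1`/`b1` and fuses the SwiGLU activation, the second projects through `w2`/`b2`, scatters results back to token order, and applies the routing weights via `gammas=routing_data.gate_scal`. The dense PyTorch reference below illustrates that per-token math only; names are assumptions and the exact `swiglu_fn(alpha, limit)` semantics of `triton_kernels` are not reproduced.

```python
# Illustrative dense reference (not from the diff) of what the two matmul_ogs calls
# compute for one routed (token, expert) pair. `act` stands in for the fused SwiGLU;
# its alpha/limit handling lives in triton_kernels.swiglu.swiglu_fn.
import torch

def moe_expert_ref(x, w1, b1, w2, b2, expert_id, gate_weight, act):
    # First GEMM: x against w1[expert_id] plus bias, then the activation
    # (matmul_ogs fuses this step via FusedActivation).
    h = act(x @ w1[expert_id] + b1[expert_id])
    # Second GEMM: project back through w2[expert_id] and b2[expert_id]; the routing
    # weight is applied here, matching gammas=routing_data.gate_scal above.
    return gate_weight * (h @ w2[expert_id] + b2[expert_id])
```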
@@ -0,0 +1,23 @@
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+    DeepEPConfig,
+    DeepEPDispatcher,
+    DeepEPLLOutput,
+    DeepEPNormalOutput,
+)
+
+__all__ = [
+    "BaseDispatcher",
+    "BaseDispatcherConfig",
+    "DispatchOutput",
+    "DispatchOutputFormat",
+    "DeepEPConfig",
+    "DeepEPDispatcher",
+    "DeepEPNormalOutput",
+    "DeepEPLLOutput",
+]
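The new `token_dispatcher/__init__.py` above re-exports the dispatcher API, so callers no longer need the old `ep_moe/token_dispatcher.py` path (moved to `token_dispatcher/deepep.py` per the file list at the top of this diff). A sketch of the resulting import surface, for illustration:

```python
# Illustration of the package-level imports made possible by the __init__.py above.
from sglang.srt.layers.moe.token_dispatcher import (
    BaseDispatcher,
    DeepEPDispatcher,
    DispatchOutput,
    DispatchOutputFormat,
)
```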
@@ -2,11 +2,22 @@ from __future__ import annotations

 from abc import ABC, abstractmethod
 from enum import Enum, auto
-from typing import
+from typing import Protocol, runtime_checkable

 import torch


+class MoEA2ABackend(Enum):
+    none = "none"
+    deepep = "deepep"
+
+    def is_none(self):
+        return self == MoEA2ABackend.none
+
+    def is_deepep(self):
+        return self == MoEA2ABackend.deepep
+
+
 class DispatchOutputFormat(Enum):
     standard = auto()
     deepep_normal = auto()
@@ -1,5 +1,3 @@
-# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
-
 from __future__ import annotations

 import logging
@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import
-    DeepEPMode,
-    get_bool_env_var,
-    get_int_env_var,
-    is_hip,
-    load_json_config,
-)
+from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config

 try:
     from deep_ep import Buffer, Config
@@ -150,9 +143,9 @@ class DeepEPBuffer:
             num_rdma_bytes,
         )

-        if deepep_mode == DeepEPMode.
+        if deepep_mode == DeepEPMode.NORMAL:
             num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
-        elif deepep_mode in [DeepEPMode.
+        elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
             num_qps_per_rank = num_experts // group.size()
         else:
             raise NotImplementedError
@@ -161,7 +154,7 @@ class DeepEPBuffer:
             device="cuda"
         ).multi_processor_count
         if (
-            (deepep_mode != DeepEPMode.
+            (deepep_mode != DeepEPMode.LOW_LATENCY)
             and not global_server_args_dict["enable_two_batch_overlap"]
             and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
         ):
@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
-        deepep_mode: DeepEPMode = DeepEPMode.
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
         async_finish: bool = False,
         return_recv_hook: bool = False,
     ):
@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
         resolved_deepep_mode = self.deepep_mode.resolve(
             forward_batch.is_extend_in_batch
         )
-        if resolved_deepep_mode == DeepEPMode.
+        if resolved_deepep_mode == DeepEPMode.NORMAL:
             return self._normal_dispatcher
-        elif resolved_deepep_mode == DeepEPMode.
+        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
sglang/srt/layers/moe/topk.py CHANGED

@@ -185,8 +185,9 @@ class TopK(CustomOp):
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
         if self.use_triton_kernels:
+            # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
-                router_logits, self.top_k, self.renormalize
+                router_logits, self.top_k, sm_first=not self.renormalize
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
         else:
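In the `TopK` hunk above, the positional `self.renormalize` argument to `routing()` is replaced by `sm_first=not self.renormalize`: with `sm_first=True` the softmax is taken over all logits before selection, while `renormalize=True` corresponds to selecting first and normalizing only the chosen experts, hence the negation (as the added comment states). A plain-torch sketch of the two orderings, illustrative only and not the `triton_kernels` implementation:

```python
# Illustration only: the two routing-weight orderings that the sm_first flag distinguishes.
import torch

def route_sm_first(logits: torch.Tensor, k: int):
    # softmax over all experts first, then take top-k (weights are NOT renormalized)
    probs = logits.softmax(dim=-1)
    weights, ids = probs.topk(k, dim=-1)
    return weights, ids

def route_renormalize(logits: torch.Tensor, k: int):
    # take top-k of the raw logits, then softmax over just the selected ones,
    # so the chosen weights sum to 1 (the "renormalize" behaviour)
    vals, ids = logits.topk(k, dim=-1)
    weights = vals.softmax(dim=-1)
    return weights, ids
```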
@@ -397,8 +398,12 @@
         .reshape(num_token, -1)
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    # TODO: NPU can't support directly evaluating a comparison for now
     topk_weights, topk_ids = torch.topk(
-        tmp_scores,
+        tmp_scores,
+        k=topk,
+        dim=-1,
+        sorted=(True if num_fused_shared_experts > 0 else False),
     )
     if num_fused_shared_experts:
         topk_ids[:, -1] = torch.randint(
@@ -488,8 +493,12 @@
     tmp_scores = scores_for_choice.masked_fill(
         ~score_mask.bool(), float("-inf")
     )  # [n, e]
+    # TODO: NPU can't support directly evaluating a comparison for now
     _, topk_ids = torch.topk(
-        tmp_scores,
+        tmp_scores,
+        k=topk,
+        dim=-1,
+        sorted=(True if num_fused_shared_experts > 0 else False),
     )
     topk_weights = scores.gather(1, topk_ids)

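Both `torch.topk` calls above now spell out `k`, `dim`, and `sorted`, and only request a sorted result when fused shared experts are enabled, presumably so that the slot the follow-up code overwrites (`topk_ids[:, -1]`) is the lowest-scoring pick; without fused shared experts the sort is skipped. A small illustration of the flag's effect:

```python
# Illustration only: torch.topk with sorted=False still returns the k largest values,
# just in unspecified order, which is cheaper; sorted=True guarantees descending order
# so that index -1 is the weakest of the selected experts.
import torch

scores = torch.tensor([[0.1, 0.9, 0.4, 0.7]])
vals_sorted, ids_sorted = torch.topk(scores, k=3, dim=-1, sorted=True)
# ids_sorted[..., -1] is the lowest-scoring of the three picks
vals_any, ids_any = torch.topk(scores, k=3, dim=-1, sorted=False)
# same set of experts, but the order (and hence the meaning of index -1) is unspecified
```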
@@ -0,0 +1,59 @@
+import importlib.util
+from enum import Enum
+from functools import lru_cache
+
+from packaging import version as pkg_version
+
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+    return result
+
+
+class MoeA2ABackend(Enum):
+
+    STANDARD = ("standard", "none")
+    DEEPEP = "deepep"
+
+    @classmethod
+    def _missing_(cls, value):
+        if value is None:
+            return cls.STANDARD
+        for member in cls:
+            if value in member.value:
+                return member
+        raise ValueError(f"No {cls.__name__} member for value {value}")
+
+    def is_deepep(self):
+        return self == MoeA2ABackend.DEEPEP
+
+    def is_standard(self):
+        return self == MoeA2ABackend.STANDARD
+
+
+class DeepEPMode(Enum):
+    NORMAL = "normal"
+    LOW_LATENCY = "low_latency"
+    AUTO = "auto"
+
+    def enable_normal(self):
+        return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
+
+    def enable_low_latency(self):
+        return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
+
+    def resolve(self, is_extend_in_batch: bool):
+        if self != DeepEPMode.AUTO:
+            return self
+
+        if is_extend_in_batch:
+            return DeepEPMode.NORMAL
+        else:
+            return DeepEPMode.LOW_LATENCY
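The new `sglang/srt/layers/moe/utils.py` above centralizes the MoE enums (`DeepEPMode` previously came from `sglang.srt.utils`, as the import change in the deepep dispatcher hunk shows). The snippet below exercises the behaviour exactly as defined; it is an illustration, not part of the diff.

```python
# Illustration only: behaviour of the enums defined in the new module above.
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

# AUTO resolves per batch: prefill/extend batches use NORMAL, decode uses LOW_LATENCY.
assert DeepEPMode.AUTO.resolve(is_extend_in_batch=True) == DeepEPMode.NORMAL
assert DeepEPMode.AUTO.resolve(is_extend_in_batch=False) == DeepEPMode.LOW_LATENCY
assert DeepEPMode.NORMAL.enable_normal() and DeepEPMode.AUTO.enable_low_latency()

# _missing_ maps legacy values onto members: None and "none" both fall back to STANDARD.
assert MoeA2ABackend(None) == MoeA2ABackend.STANDARD
assert MoeA2ABackend("none") == MoeA2ABackend.STANDARD
assert MoeA2ABackend("deepep").is_deepep()
```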
@@ -47,6 +47,12 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
+from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
+
+is_mxfp_supported = mxfp_supported()
+if is_mxfp_supported:
+    from sglang.srt.layers.quantization.fp4 import MxFp4Config
+
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import (
     GPTQConfig,
@@ -60,6 +66,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.utils import get_linear_quant_method
@@ -85,6 +92,21 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "petit_nvfp4": PetitNvFp4Config,
 }

+
+if is_cuda():
+    BASE_QUANTIZATION_METHODS.update(
+        {
+            "quark": Mxfp4Config,
+            "mxfp4": Mxfp4Config,
+        }
+    )
+elif is_mxfp_supported and is_hip():
+    BASE_QUANTIZATION_METHODS.update(
+        {
+            "quark": MxFp4Config,
+            "mxfp4": MxFp4Config,
+        }
+    )
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
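The registry hunk above makes the `quark` and `mxfp4` method names resolve to different config classes per platform: `Mxfp4Config` on CUDA, the fp4 `MxFp4Config` on ROCm builds where `mxfp_supported()` is true, and neither elsewhere. A brief lookup sketch, for illustration only (the error message is assumed):

```python
# Illustration of the platform-conditional registration above: the same method name
# maps to different config classes, or is absent, depending on the build.
from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS

quant_cls = BASE_QUANTIZATION_METHODS.get("mxfp4")
if quant_cls is None:
    raise ValueError("mxfp4 quantization is not registered on this platform")
```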
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
@@ -189,7 +190,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_input_scale = None
         layer.w2_input_scale = None

-    def process_weights_after_loading(self, layer:
+    def process_weights_after_loading(self, layer: FusedMoE) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -246,7 +247,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         assert layer.w13_weight_scale is not None
         shard_size = layer.intermediate_size_per_partition
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.
+        for expert_id in range(layer.num_local_experts):
             start = 0
             for shard_id in range(2):
                 dq_weight = per_tensor_dequantize(
@@ -148,7 +148,7 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
         "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
         "N": n,
         "K": k,
-        "NUM_GROUPS":
+        "NUM_GROUPS": num_groups,
         "BLOCK_M": block_m,
         "BLOCK_N": block_n,
         "BLOCK_K": block_k,