sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/ep_moe/layer.py

@@ -1,41 +1,24 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
-    gelu_and_mul_triton_kernel,
-    grouped_gemm_triton,
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
-    pre_reorder_triton_kernel,
-    pre_reorder_triton_kernel_for_cutlass_moe,
-    run_cutlass_moe_ep_preproess,
-    run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
-    silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FlashInferFusedMoE,
-    FusedMoE,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-)
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import (
     Fp8Config,
     Fp8MoEMethod,
@@ -44,23 +27,13 @@ from sglang.srt.layers.quantization.fp8 import (
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
-    sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
-from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    DeepEPMode,
-    ceil_div,
-    dispose_tensor,
-    get_bool_env_var,
-    is_hip,
-    is_npu,
-)
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -71,7 +44,6 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
@@ -83,6 +55,22 @@ if _use_aiter:
 logger = logging.getLogger(__name__)
 
 
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
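
For reference, the new `_cast_to_e8m0_with_rounding_up` helper rounds each per-block scale *up* to the nearest power of two and keeps only a biased 8-bit exponent (e8m0). The self-contained sketch below reimplements just the bit manipulation on a flat tensor; it omits the uint8-to-int packing and layout transposes that the real helper performs for its (expert, m, k) scale tensors, so treat it as an illustration rather than the shipped code.

```python
import torch

def cast_to_e8m0_round_up(x: torch.Tensor) -> torch.Tensor:
    # Reinterpret the fp32 bits, then split into biased exponent and mantissa.
    bits = x.to(torch.float32).view(torch.int32)
    exp = torch.bitwise_right_shift(bits, 23)   # biased exponent (sign bit assumed 0 for scales)
    mant = torch.bitwise_and(bits, 0x7FFFFF)    # 23-bit mantissa
    # Round up whenever the mantissa is non-zero, except at the exponent ceiling
    # and for very small subnormals -- the same conditions as in the diff above.
    round_up = torch.logical_and(
        torch.logical_and(mant > 0, exp != 0xFE),
        ~torch.logical_and(exp == 0, mant <= 0x400000),
    )
    return torch.where(round_up, exp + 1, exp).to(torch.uint8)

scales = torch.tensor([0.75, 1.0, 1.5, 3.0])
e8m0 = cast_to_e8m0_round_up(scales)
# Decoding the biased exponent shows each scale rounded up to a power of two.
print(torch.exp2(e8m0.to(torch.float32) - 127.0))  # tensor([1., 1., 2., 4.])
```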
@@ -104,6 +92,9 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -119,7 +110,9 @@ class EPMoE(FusedMoE):
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
+            with_bias=with_bias,
         )
 
         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
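
The new `activation_alpha`, `swiglu_limit`, and `with_bias` arguments are simply forwarded to `FusedMoE`. As a hedged illustration only: the parameter names suggest the clamped SwiGLU used by the gpt-oss reference code (alpha around 1.702, limit around 7.0), which fits the new `gpt_oss.py` model in this release, but the fused kernel sglang actually dispatches to may differ. A rough sketch of that activation:

```python
import torch

def clamped_swiglu(gate: torch.Tensor, up: torch.Tensor,
                   alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    # Clamp both branches before gating, then apply the sigmoid-weighted gate.
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return (gate * torch.sigmoid(alpha * gate)) * (up + 1)

print(clamped_swiglu(torch.randn(4, 16), torch.randn(4, 16)).shape)  # torch.Size([4, 16])
```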
@@ -227,10 +220,22 @@ class EPMoE(FusedMoE):
 
         dispose_tensor(hidden_states)
 
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
         # GroupGemm-0
         gateup_input_fp8 = (
             gateup_input,
-
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
         )
         num_groups, m, k = gateup_input_fp8[0].size()
         n = self.w13_weight.size(1)
@@ -238,7 +243,12 @@ class EPMoE(FusedMoE):
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8,
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -269,6 +279,7 @@ class EPMoE(FusedMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output
 
@@ -276,13 +287,24 @@ class EPMoE(FusedMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8,
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -304,6 +326,8 @@ class EPMoE(FusedMoE):
             m_max * self.start_expert_id,
             BLOCK_SIZE=512,
         )
+        if self.routed_scaling_factor is not None:
+            output *= self.routed_scaling_factor
         return output
 
 
@@ -328,7 +352,7 @@ class DeepEPMoE(EPMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        deepep_mode: DeepEPMode = DeepEPMode.
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -348,7 +372,6 @@ class DeepEPMoE(EPMoE):
 
         # TODO: move to the beginning of the file
         from sglang.srt.distributed.parallel_state import get_tp_group
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
         from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 
         self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
@@ -701,72 +724,29 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
-class FlashInferEPMoE(EPMoE):
-    def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.use_flashinfer_trtllm_moe
-        assert (
-            self.activation == "silu"
-        ), "Only silu is supported for flashinfer blockscale fp8 moe"
-        assert (
-            self.renormalize
-        ), "Renormalize is required for flashinfer blockscale fp8 moe"
-        assert (
-            self.num_fused_shared_experts == 0
-        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
-        # NOTE: scales of hidden states have to be transposed!
-        a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
-        return trtllm_fp8_block_scale_moe(
-            routing_logits=router_logits.to(torch.float32),
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
-            hidden_states=a_q,
-            hidden_states_scale=a_sf_t,
-            gemm1_weights=self.w13_weight,
-            gemm1_weights_scale=self.w13_weight_scale_inv,
-            gemm2_weights=self.w2_weight,
-            gemm2_weights_scale=self.w2_weight_scale_inv,
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            intermediate_size=self.w2_weight.shape[2],
-            local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_experts
-            ),
-            routing_method_type=2,  # DeepSeek-styled routing method
-            use_shuffled_weight=False,
-        )
-
-
 def get_moe_impl_class():
-    if global_server_args_dict["
+    if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
+
+    # NEW: Direct FP4 detection (bypasses EP requirements)
+    # Check for FP4 quantization with TRTLLM flag, regardless of EP
+    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+        try:
+            # Check the quantization argument directly
+            quantization = global_server_args_dict.get("quantization")
+            if quantization == "modelopt_fp4":
+                from sglang.srt.layers.moe.fused_moe_triton.layer import (
+                    FlashInferFP4MoE,
+                )
+
+                return FlashInferFP4MoE
+        except:
+            pass
+
+    if should_use_flashinfer_trtllm_moe():
+        return FlashInferFusedMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
-    if
-        return
-    return
+    if get_moe_expert_parallel_world_size() > 1:
+        return EPMoE
+    return FusedMoE
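
To make the new selection precedence easier to follow, here is a self-contained sketch that mirrors the branch order above, with plain strings and a dict standing in for the real layer classes and `global_server_args_dict`; it is illustrative only and omits the try/except guard around the FP4 import.

```python
def select_moe_impl(args: dict, ep_world_size: int, use_trtllm_moe: bool) -> str:
    # Order matters: DeepEP first, then the FP4/TRTLLM path, then the generic backends.
    if args.get("moe_a2a_backend") == "deepep":
        return "DeepEPMoE"
    if args.get("enable_flashinfer_trtllm_moe") and args.get("quantization") == "modelopt_fp4":
        return "FlashInferFP4MoE"
    if use_trtllm_moe:                        # stands in for should_use_flashinfer_trtllm_moe()
        return "FlashInferFusedMoE"
    if args.get("enable_flashinfer_cutlass_moe"):
        return "FusedMoE"
    if ep_world_size > 1:                     # stands in for get_moe_expert_parallel_world_size()
        return "EPMoE"
    return "FusedMoE"

print(select_moe_impl({"quantization": "fp8"}, ep_world_size=8, use_trtllm_moe=False))  # EPMoE
```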
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    }
+}
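
This file is one of the pre-tuned Triton kernel configurations for the fused MoE path, keyed by the number of tokens (M) the kernel will process for this E=128, N=352, fp8_w8a8 shape on an RTX 6000 Ada. As a minimal sketch only (the actual lookup lives in `fused_moe.py` and may differ in detail), the launcher typically picks the entry whose key is nearest to the runtime M:

```python
import json

# A small inline subset of the table above, just for demonstration.
configs = json.loads(
    '{"1": {"BLOCK_SIZE_M": 16}, "256": {"BLOCK_SIZE_M": 32}, "4096": {"BLOCK_SIZE_M": 128}}'
)

def nearest_config(configs: dict, m: int) -> dict:
    # Choose the tuning entry whose token-count key is closest to the actual M.
    key = min((int(k) for k in configs), key=lambda k: abs(k - m))
    return configs[str(key)]

print(nearest_config(configs, 300))  # {'BLOCK_SIZE_M': 32} -- falls back to the "256" entry
```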