sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py:

```diff
@@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FlashInferFusedMoE,
-    FusedMoE,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import (
@@ -48,7 +44,6 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
```
```diff
@@ -60,6 +55,22 @@ if _use_aiter:
 logger = logging.getLogger(__name__)
 
 
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
```
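The `_cast_to_e8m0_with_rounding_up` helper added above converts a float32 blockwise scale tensor to an e8m0 (exponent-only) code, rounding up to the next power of two whenever mantissa bits would otherwise be dropped; the final transpose/contiguous/transpose round-trip only fixes the memory layout the grouped GEMM expects. A minimal sketch of the same rounding rule on a flat tensor (hypothetical `cast_to_e8m0_rounding_up` name, subnormal special-casing omitted):

```python
import torch

def cast_to_e8m0_rounding_up(x: torch.Tensor) -> torch.Tensor:
    """Keep only a power-of-two exponent per element, rounding up.

    Returns the biased FP32 exponent (uint8) of the smallest power of two
    that is >= x. This mirrors the rounding rule of the kernel helper,
    without its layout trick or subnormal handling.
    """
    bits = x.to(torch.float32).view(torch.int32)
    exp = torch.bitwise_right_shift(bits, 23) & 0xFF
    mant = torch.bitwise_and(bits, 0x7FFFFF)
    # round up whenever mantissa bits would be dropped (saturate at exp 0xFE)
    exp = torch.where((mant > 0) & (exp != 0xFE), exp + 1, exp)
    return exp.to(torch.uint8)

x = torch.tensor([0.5, 1.0, 1.5, 3.0], dtype=torch.float32)
e = cast_to_e8m0_rounding_up(x)
decoded = torch.pow(2.0, e.to(torch.float32) - 127)  # decode: 2 ** (exp - 127)
print(decoded)  # tensor([0.5000, 1.0000, 2.0000, 4.0000])
```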
```diff
@@ -81,6 +92,9 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -96,6 +110,9 @@ class EPMoE(FusedMoE):
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
+            with_bias=with_bias,
         )
 
         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
```
```diff
@@ -203,10 +220,22 @@ class EPMoE(FusedMoE):
 
         dispose_tensor(hidden_states)
 
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
         # GroupGemm-0
         gateup_input_fp8 = (
             gateup_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
         )
         num_groups, m, k = gateup_input_fp8[0].size()
         n = self.w13_weight.size(1)
@@ -214,7 +243,12 @@ class EPMoE(FusedMoE):
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8,
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -245,6 +279,7 @@ class EPMoE(FusedMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output
 
```
```diff
@@ -252,13 +287,24 @@ class EPMoE(FusedMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8,
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
```
```diff
@@ -678,71 +724,29 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
-class FlashInferEPMoE(EPMoE):
-    def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.use_flashinfer_trtllm_moe
-        assert (
-            self.activation == "silu"
-        ), "Only silu is supported for flashinfer blockscale fp8 moe"
-        assert (
-            self.renormalize
-        ), "Renormalize is required for flashinfer blockscale fp8 moe"
-        assert (
-            self.num_fused_shared_experts == 0
-        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
-        # NOTE: scales of hidden states have to be transposed!
-        a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
-        return trtllm_fp8_block_scale_moe(
-            routing_logits=router_logits.to(torch.float32),
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
-            hidden_states=a_q,
-            hidden_states_scale=a_sf_t,
-            gemm1_weights=self.w13_weight,
-            gemm1_weights_scale=self.w13_weight_scale_inv,
-            gemm2_weights=self.w2_weight,
-            gemm2_weights_scale=self.w2_weight_scale_inv,
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            intermediate_size=self.w2_weight.shape[2],
-            local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_experts
-            ),
-            routing_method_type=2, # DeepSeek-styled routing method
-            use_shuffled_weight=False,
-        )
-
-
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
+
+    # NEW: Direct FP4 detection (bypasses EP requirements)
+    # Check for FP4 quantization with TRTLLM flag, regardless of EP
+    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+        try:
+            # Check the quantization argument directly
+            quantization = global_server_args_dict.get("quantization")
+            if quantization == "modelopt_fp4":
+                from sglang.srt.layers.moe.fused_moe_triton.layer import (
+                    FlashInferFP4MoE,
+                )
+
+                return FlashInferFP4MoE
+        except:
+            pass
+
+    if should_use_flashinfer_trtllm_moe():
+        return FlashInferFusedMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
-        return
-    return
+        return EPMoE
+    return FusedMoE
```
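The rewritten `get_moe_impl_class` resolves the MoE implementation in a fixed priority order: DeepEP A2A first, then a direct ModelOpt-FP4 + TRT-LLM path, then the FlashInfer TRT-LLM and Cutlass backends, then expert parallelism, and finally the plain Triton `FusedMoE`. A minimal sketch of the same cascade, with a hypothetical `args` dict standing in for `global_server_args_dict` and plain strings instead of the real layer classes:

```python
from typing import Any, Dict

def pick_moe_impl(args: Dict[str, Any]) -> str:
    # Hypothetical stand-in for get_moe_impl_class(); returns a label
    # instead of a layer class, in the same priority order as above.
    if args.get("moe_a2a_backend") == "deepep":
        return "DeepEPMoE"
    if args.get("enable_flashinfer_trtllm_moe") and args.get("quantization") == "modelopt_fp4":
        return "FlashInferFP4MoE"  # direct FP4 path, bypasses EP checks
    if args.get("enable_flashinfer_trtllm_moe"):
        return "FlashInferFusedMoE"
    if args.get("enable_flashinfer_cutlass_moe"):
        return "FusedMoE"
    if args.get("moe_ep_size", 1) > 1:
        return "EPMoE"
    return "FusedMoE"

print(pick_moe_impl({"quantization": "modelopt_fp4", "enable_flashinfer_trtllm_moe": True}))
# -> FlashInferFP4MoE
```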
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py:

```diff
@@ -319,6 +319,7 @@ def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
     b_ptr,
+    bias_ptr,
     c_ptr,
     a_scale_ptr,
     b_scale_ptr,
@@ -340,6 +341,8 @@ def fused_moe_kernel(
     stride_be,
     stride_bk,
     stride_bn,
+    stride_bias_e,
+    stride_bias_n,
     stride_cm,
     stride_cn,
     stride_asm,
@@ -449,6 +452,10 @@ def fused_moe_kernel(
         + off_experts * stride_be
         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     )
+    if bias_ptr is not None:
+        bias = tl.load(
+            bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
+        )
     if use_int8_w8a16:
         b_scale_ptrs = (
             b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
```
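Each program instance loads one row of the selected expert's bias: `off_experts` indexes along the expert stride and the output-column offsets index along the N stride, so the loaded `bias` has shape `[1, BLOCK_SIZE_N]` and broadcasts over the token rows of the accumulator. A small PyTorch sketch of that broadcast, with hypothetical toy shapes:

```python
import torch

BLOCK_M, BLOCK_N, num_experts = 4, 8, 2
accumulator = torch.zeros(BLOCK_M, BLOCK_N)   # per-tile GEMM result
bias = torch.randn(num_experts, BLOCK_N)      # one bias row per expert
off_experts = 1                               # expert handled by this tile

# Equivalent of tl.load(bias_ptr + off_experts*stride_bias_e + offs_bn*stride_bias_n)
tile_bias = bias[off_experts][None, :]        # shape [1, BLOCK_N]
out = accumulator + tile_bias                 # broadcasts over the M (token) rows
print(out.shape)                              # torch.Size([4, 8])
```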
```diff
@@ -526,18 +533,20 @@ def fused_moe_kernel(
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator = accumulator * moe_weight[:, None]
     if use_int8_w8a16:
-        accumulator
+        accumulator *= b_scale
     elif use_fp8_w8a8 or use_int8_w8a8:
-        if group_k
-        accumulator
-
-
-
-
+        if group_k == 0 or group_n == 0:
+            accumulator *= a_scale * b_scale
+
+    if bias_ptr is not None:
+        accumulator += bias
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator *= moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
```
(Several removed lines in this hunk were truncated by the page extraction and are kept as they appeared.)
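The reordered epilogue now applies the dequantization scales first, then the optional per-expert bias, and only then the routed top-k weight, so the bias is scaled by the routing weight exactly like the GEMM output, before the final cast. A hedged PyTorch reference of that ordering (toy tensors, no Triton):

```python
import torch

tokens, hidden = 4, 8
accumulator = torch.randn(tokens, hidden)   # raw GEMM accumulator for one tile
a_scale, b_scale = 0.02, 0.5                # activation / weight dequant scales
bias = torch.randn(hidden)                  # per-expert output bias row
moe_weight = torch.rand(tokens)             # routed top-k weight per token

out = accumulator * (a_scale * b_scale)     # 1) dequantize
out = out + bias                            # 2) add expert bias
out = out * moe_weight[:, None]             # 3) scale by routing weight
out = out.to(torch.bfloat16)                # 4) cast to the compute/output dtype
```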
```diff
@@ -622,6 +631,7 @@ def moe_align_block_size(
 def invoke_fused_moe_kernel(
     A: torch.Tensor,
     B: torch.Tensor,
+    bias: Optional[torch.Tensor],
     C: torch.Tensor,
     A_scale: Optional[torch.Tensor],
     B_scale: Optional[torch.Tensor],
@@ -711,6 +721,7 @@ def invoke_fused_moe_kernel(
     ):
         assert B_scale is not None and B_scale.ndim == 3
         assert B_zp is None or B_zp.ndim == 3
+        assert bias is None
         fused_moe_kernel_gptq_awq[grid](
             A,
             B,
@@ -754,6 +765,7 @@ def invoke_fused_moe_kernel(
         fused_moe_kernel[grid](
             A,
             B,
+            bias,
             C,
             A_scale,
             B_scale,
@@ -770,6 +782,8 @@ def invoke_fused_moe_kernel(
             B.stride(0),
             B.stride(2),
             B.stride(1),
+            bias.stride(0) if bias is not None else 0,
+            bias.stride(1) if bias is not None else 0,
             C.stride(1),
             C.stride(2),
             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
```
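`invoke_fused_moe_kernel` now threads the optional bias through to the Triton kernel: the GPTQ/AWQ path asserts it is unset, while the main path passes the tensor plus its expert and column strides (or zeros when no bias is given). From the kernel's addressing, the bias is expected to have one row per expert and one column per output feature of that GEMM. A small sketch of a compatible bias tensor and the stride values that would be forwarded (names and shapes are illustrative):

```python
import torch

num_experts, N = 8, 4096
# One bias row per local expert, matching the kernel's
# bias_ptr + expert * stride_bias_e + col * stride_bias_n addressing.
w13_bias = torch.zeros(num_experts, N, dtype=torch.bfloat16)

stride_bias_e = w13_bias.stride(0) if w13_bias is not None else 0
stride_bias_n = w13_bias.stride(1) if w13_bias is not None else 0
print(stride_bias_e, stride_bias_n)  # 4096 1
```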
```diff
@@ -994,6 +1008,8 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1009,6 +1025,8 @@ def inplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1016,6 +1034,8 @@ def inplace_fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        b1,
+        b2,
         True,
         activation,
         apply_router_weight_on_input,
@@ -1033,6 +1053,8 @@ def inplace_fused_experts(
         block_shape,
         False,
         routed_scaling_factor,
+        activation_alpha,
+        swiglu_limit,
     )
 
 
@@ -1042,6 +1064,8 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1057,6 +1081,8 @@ def inplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> None:
     pass
 
@@ -1075,6 +1101,8 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1091,6 +1119,8 @@ def outplace_fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1098,6 +1128,8 @@ def outplace_fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        b1,
+        b2,
         False,
         activation,
         apply_router_weight_on_input,
@@ -1115,6 +1147,8 @@ def outplace_fused_experts(
         block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
     )
 
 
@@ -1124,6 +1158,8 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1140,6 +1176,8 @@ def outplace_fused_experts_fake(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -1157,6 +1195,8 @@ def fused_experts(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1174,6 +1214,8 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
     if inplace:
@@ -1184,6 +1226,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1199,6 +1243,8 @@ def fused_experts(
             a2_scale,
             block_shape,
             routed_scaling_factor,
+            activation_alpha,
+            swiglu_limit,
         )
         return hidden_states
     else:
@@ -1208,6 +1254,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1224,6 +1272,8 @@ def fused_experts(
             block_shape,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
         )
 
 
```
```diff
@@ -1319,12 +1369,22 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
     out.mul_(routed_scaling_factor)
 
 
+@torch.compile
+def swiglu_with_alpha_and_limit(x, alpha, limit):
+    gate, up = x[..., ::2], x[..., 1::2]
+    gate = gate.clamp(min=None, max=limit)
+    up = up.clamp(min=-limit, max=limit)
+    return gate * torch.sigmoid(gate * alpha) * (up + 1)
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
```
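The new `swiglu_with_alpha_and_limit` implements a clamped, alpha-scaled SwiGLU: gate and up projections are interleaved in the last dimension, the gate is clipped from above and the up projection from both sides by `limit`, and the gate passes through a sigmoid scaled by `alpha` before multiplying `up + 1`. A small, hedged usage sketch in eager PyTorch (toy sizes; `alpha` and `limit` values are illustrative only):

```python
import torch

def swiglu_with_alpha_and_limit_ref(x, alpha, limit):
    # Same math as the @torch.compile version above, kept in eager mode.
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)
    return gate * torch.sigmoid(gate * alpha) * (up + 1)

tokens, intermediate = 4, 16
x = torch.randn(tokens, 2 * intermediate)   # interleaved [gate, up] pairs
y = swiglu_with_alpha_and_limit_ref(x, alpha=1.702, limit=7.0)
print(y.shape)                              # torch.Size([4, 16])
```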
```diff
@@ -1342,6 +1402,8 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -1353,7 +1415,7 @@ def fused_experts_impl(
     else:
         assert (
             hidden_states.shape[1] == w1.shape[2] - padded_size
-        ), "Hidden size mismatch"
+        ), f"Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -1449,6 +1511,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             curr_hidden_states,
             w1,
+            b1,
             intermediate_cache1,
             a1_scale,
             w1_scale,
@@ -1470,13 +1533,24 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
         if activation == "silu":
-            if _is_cuda:
+            if activation_alpha is not None:
+                assert swiglu_limit is not None
+                intermediate_cache2 = swiglu_with_alpha_and_limit(
+                    intermediate_cache1.view(-1, N),
+                    activation_alpha,
+                    swiglu_limit,
+                )
+            elif _is_cuda:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.silu_and_mul(
                     intermediate_cache2, intermediate_cache1.view(-1, N)
                 )
         elif activation == "gelu":
+            assert (
+                activation_alpha is None
+            ), "activation_alpha is not supported for gelu"
+            assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
             if _is_cuda:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
```
```diff
@@ -1489,6 +1563,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             intermediate_cache2,
             w2,
+            b2,
             (
                 intermediate_cache3
                 if not no_combine and topk_ids.shape[1] != 1
@@ -1567,6 +1642,8 @@ def fused_moe(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1584,6 +1661,8 @@ def fused_moe(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1594,6 +1673,8 @@ def fused_moe(
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
     - topk_output (TopKOutput): The top-k output of the experts.
+    - b1 (Optional[torch.Tensor]): Optional bias for w1.
+    - b2 (Optional[torch.Tensor]): Optional bias for w2.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
@@ -1615,6 +1696,10 @@ def fused_moe(
         a2.
     - block_shape: (Optional[List[int]]): Optional block size for block-wise
         quantization.
+    - activation_alpha (Optional[float]): Optional alpha for the activation
+        function.
+    - swiglu_limit (Optional[float]): Optional limit for the swiglu activation
+        function.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -1625,6 +1710,8 @@ def fused_moe(
         w1,
         w2,
         topk_output,
+        b1=b1,
+        b2=b2,
         inplace=inplace,
         activation=activation,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -1642,4 +1729,6 @@ def fused_moe(
         block_shape=block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
     )
```
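Taken together, the new parameters let a caller pass per-expert biases and the clamped-SwiGLU parameters straight through `fused_moe`. A hedged usage sketch follows; it assumes a CUDA device with sglang and its kernels installed, uses made-up shapes, and builds the `topk_output` tuple by hand instead of using sglang's real top-k routing helpers.

```python
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, H, I, T, K = 8, 1024, 2048, 4, 2   # experts, hidden, intermediate, tokens, top-k
x = torch.randn(T, H, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(E, 2 * I, H, device="cuda", dtype=torch.bfloat16)  # fused gate+up
w2 = torch.randn(E, H, I, device="cuda", dtype=torch.bfloat16)
b1 = torch.zeros(E, 2 * I, device="cuda", dtype=torch.bfloat16)     # per-expert biases
b2 = torch.zeros(E, H, device="cuda", dtype=torch.bfloat16)

topk_weights = torch.full((T, K), 1.0 / K, device="cuda", dtype=torch.float32)
topk_ids = torch.randint(0, E, (T, K), device="cuda", dtype=torch.int32)
topk_output = (topk_weights, topk_ids, None)  # matches `topk_weights, topk_ids, _ = topk_output`

out = fused_moe(
    x, w1, w2, topk_output,
    b1=b1, b2=b2,
    activation="silu",
    activation_alpha=1.702,  # illustrative clamped-SwiGLU parameters
    swiglu_limit=7.0,
)
print(out.shape)  # (4, 1024)
```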