sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +36 -2
- sglang/srt/conversation.py +56 -3
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +50 -18
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +20 -5
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
- sglang/srt/layers/moe/ep_moe/layer.py +141 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +35 -3
- sglang/srt/managers/mm_utils.py +59 -96
- sglang/srt/managers/schedule_batch.py +17 -6
- sglang/srt/managers/scheduler.py +38 -6
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +176 -101
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +78 -19
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +372 -82
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +63 -61
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +26 -4
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +191 -48
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import importlib
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -19,6 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
@@ -29,8 +31,15 @@ from sglang.srt.utils import (
     use_intel_amx_backend,
 )
 
+has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+    if has_triton_kernels:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_moe_forward,
+        )
 else:
     fused_experts = None  # type: ignore
 
@@ -87,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
+    def __init__(self, use_triton_kernels: bool = False):
+        super().__init__()
+        self.use_triton_kernels = use_triton_kernels
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -97,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         **extra_weight_attrs,
     ):
         # Fused gate_up_proj (column parallel)
+        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        if self.use_triton_kernels:
+            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
+        w2_weight_n, w2_weight_k = (
+            hidden_size,
+            intermediate_size,
+        )
+        if self.use_triton_kernels:
+            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
         w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
            requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
@@ -192,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
-        if _use_aiter:
-            assert not no_combine, "unsupported"
-            if apply_router_weight_on_input:
-                assert (
-                    topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-                _, topk = topk_weights.shape
-                assert (
-                    topk == 1
-                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-                x = x * topk_weights.to(x.dtype)
-                topk_weights = torch.ones_like(
-                    topk_weights, dtype=torch.float32
-                )  # topk_weights must be FP32 (float32)
-
-            return fused_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        else:
-            return fused_experts(
+        if self.use_triton_kernels:
+            return triton_kernel_moe_forward(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
                 routed_scaling_factor=routed_scaling_factor,
             )
 
+            if _use_aiter:
+                assert not no_combine, "unsupported"
+                if apply_router_weight_on_input:
+                    assert (
+                        topk_weights.dim() == 2
+                    ), "`topk_weights` should be in shape (num_tokens, topk)"
+                    _, topk = topk_weights.shape
+                    assert (
+                        topk == 1
+                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                    x = x * topk_weights.to(x.dtype)
+                    topk_weights = torch.ones_like(
+                        topk_weights, dtype=torch.float32
+                    )  # topk_weights must be FP32 (float32)
+
+                return fused_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+            else:
+                return fused_experts(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    inplace=inplace and not no_combine,
+                    activation=activation,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    no_combine=no_combine,
+                    routed_scaling_factor=routed_scaling_factor,
+                )
+
     def forward_cpu(
         self,
         layer: torch.nn.Module,
@@ -286,9 +317,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x,
             layer.w13_weight,
            layer.w2_weight,
-            topk_weights.to(
-                torch.float
-            ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+            topk_weights,
             topk_ids,
            False,  # inplace # See [Note] inplace should be False in fused_experts.
             False,  # use_int8_w8a8
@@ -475,9 +504,13 @@ class FusedMoE(torch.nn.Module):
         self.inplace = inplace
         self.no_combine = no_combine
 
+        self.use_triton_kernels = (
+            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
+        )
+
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
+            )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
@@ -485,6 +518,7 @@ class FusedMoE(torch.nn.Module):
             self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
 
+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.local_num_experts,
@@ -597,6 +631,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -612,6 +648,31 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.tensor,
         tp_rank: int,
     ):
+        """Load w2 weights for down projection.
+
+        Args:
+            expert_data: The expert data tensor to load into
+            shard_dim: The dimension to shard along
+            shard_id: The shard ID (must be "w2")
+            loaded_weight: The weight tensor to load from
+            tp_rank: The tensor parallel rank
+        """
+        if not isinstance(expert_data, torch.Tensor) or not isinstance(
+            loaded_weight, torch.Tensor
+        ):
+            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
+
+        if (
+            self.quant_config is not None
+            and "modelopt" in self.quant_config.get_name()
+            and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
+        ):
+            raise ValueError(
+                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
+            )
+
+        if shard_id != "w2":
+            raise ValueError(f"shard_id must be 'w2', got {shard_id}")
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
@@ -630,6 +691,12 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
+                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                    raise ValueError(
+                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                    )
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -716,6 +783,8 @@ class FusedMoE(torch.nn.Module):
         # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if self.use_triton_kernels:
+            is_transposed = True
         if is_transposed:
             shard_dim = int(not shard_dim)
 
@@ -754,8 +823,21 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+
         if "ModelOpt" in self.quant_method.__class__.__name__:
-            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+            # Determine per-tensor weight scale patterns based on variant
+            is_fp4_variant = (
+                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
+            )
+
+            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name
+                if is_fp4_variant
+                else "weight_scale" in weight_name
+            ) or "input_scale" in weight_name
+
+            if per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
@@ -773,7 +855,7 @@ class FusedMoE(torch.nn.Module):
             return
 
         # Case weight scales and zero_points
-        if "scale" in weight_name or "zero" in weight_name:
+        if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
             # load the weight scales and zp based on the quantization scheme
             # supported weight scales/zp can be found in
             # FusedMoeWeightScaleSupported
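For orientation, the functional core of the layer.py change is a layout swap: when the optional triton_kernels backend is enabled, the expert weights are allocated with their last two dimensions transposed so they match what the external kernels expect. A minimal sketch of that selection logic, with made-up sizes and a plain boolean standing in for the `global_server_args_dict["enable_triton_kernel_moe"]` lookup:

```python
import torch

# Sketch only: mirrors the new create_weights() shape selection above.
# Sizes are made-up example values.
num_experts, intermediate_size, hidden_size = 8, 1408, 2048
use_triton_kernels = True

# Default layout: w13 is [E, 2 * intermediate, hidden]; with triton kernels
# the last two dims are swapped to [E, hidden, 2 * intermediate].
w13_n, w13_k = 2 * intermediate_size, hidden_size
if use_triton_kernels:
    w13_n, w13_k = w13_k, w13_n

# Default layout: w2 is [E, hidden, intermediate]; swapped to [E, intermediate, hidden].
w2_n, w2_k = hidden_size, intermediate_size
if use_triton_kernels:
    w2_n, w2_k = w2_k, w2_n

w13_weight = torch.empty(num_experts, w13_n, w13_k, dtype=torch.bfloat16)
w2_weight = torch.empty(num_experts, w2_n, w2_k, dtype=torch.bfloat16)
print(w13_weight.shape, w2_weight.shape)
# torch.Size([8, 2048, 2816]) torch.Size([8, 1408, 2048])
```

The weight loaders above compensate for the same swap by transposing checkpoint shards and flipping `is_transposed` when the triton path is active.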
|
@@ -0,0 +1,176 @@
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from sgl_kernel import gelu_and_mul, silu_and_mul
|
6
|
+
from triton_kernels.matmul_ogs import matmul_ogs
|
7
|
+
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
8
|
+
|
9
|
+
from sglang.srt.utils import direct_register_custom_op
|
10
|
+
|
11
|
+
|
12
|
+
def triton_kernel_moe_forward(
|
13
|
+
hidden_states: torch.Tensor,
|
14
|
+
w1: torch.Tensor,
|
15
|
+
w2: torch.Tensor,
|
16
|
+
gating_output: torch.Tensor,
|
17
|
+
topk: int,
|
18
|
+
renormalize: bool,
|
19
|
+
inplace: bool = False,
|
20
|
+
activation: str = "silu",
|
21
|
+
apply_router_weight_on_input: bool = False,
|
22
|
+
use_fp8_w8a8: bool = False,
|
23
|
+
per_channel_quant: bool = False,
|
24
|
+
global_num_experts: int = -1,
|
25
|
+
expert_map: Optional[torch.Tensor] = None,
|
26
|
+
w1_scale: Optional[torch.Tensor] = None,
|
27
|
+
w2_scale: Optional[torch.Tensor] = None,
|
28
|
+
a1_scale: Optional[torch.Tensor] = None,
|
29
|
+
a2_scale: Optional[torch.Tensor] = None,
|
30
|
+
block_shape: Optional[list[int]] = None,
|
31
|
+
) -> torch.Tensor:
|
32
|
+
|
33
|
+
if not renormalize:
|
34
|
+
gating_output = torch.softmax(gating_output, dim=-1)
|
35
|
+
routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
|
36
|
+
|
37
|
+
return triton_kernel_fused_experts(
|
38
|
+
hidden_states,
|
39
|
+
w1,
|
40
|
+
w2,
|
41
|
+
routing_data,
|
42
|
+
gather_idx,
|
43
|
+
scatter_idx,
|
44
|
+
inplace=inplace,
|
45
|
+
activation=activation,
|
46
|
+
apply_router_weight_on_input=apply_router_weight_on_input,
|
47
|
+
use_fp8_w8a8=use_fp8_w8a8,
|
48
|
+
per_channel_quant=per_channel_quant,
|
49
|
+
global_num_experts=global_num_experts,
|
50
|
+
expert_map=expert_map,
|
51
|
+
w1_scale=w1_scale,
|
52
|
+
w2_scale=w2_scale,
|
53
|
+
a1_scale=a1_scale,
|
54
|
+
a2_scale=a2_scale,
|
55
|
+
block_shape=block_shape,
|
56
|
+
)
|
57
|
+
|
58
|
+
|
59
|
+
# This is a triton implementation of the fused_experts function
|
60
|
+
def triton_kernel_fused_experts(
|
61
|
+
hidden_states: torch.Tensor,
|
62
|
+
w1: torch.Tensor,
|
63
|
+
w2: torch.Tensor,
|
64
|
+
routing_data: RoutingData,
|
65
|
+
gather_indx: GatherIndx,
|
66
|
+
scatter_indx: ScatterIndx,
|
67
|
+
inplace: bool = False,
|
68
|
+
activation: str = "silu",
|
69
|
+
apply_router_weight_on_input: bool = False,
|
70
|
+
use_fp8_w8a8: bool = False,
|
71
|
+
per_channel_quant: bool = False,
|
72
|
+
global_num_experts: int = -1,
|
73
|
+
expert_map: Optional[torch.Tensor] = None,
|
74
|
+
w1_scale: Optional[torch.Tensor] = None,
|
75
|
+
w2_scale: Optional[torch.Tensor] = None,
|
76
|
+
a1_scale: Optional[torch.Tensor] = None,
|
77
|
+
a2_scale: Optional[torch.Tensor] = None,
|
78
|
+
block_shape: Optional[list[int]] = None,
|
79
|
+
) -> torch.Tensor:
|
80
|
+
|
81
|
+
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
|
82
|
+
assert per_channel_quant == False, "per_channel_quant is not supported"
|
83
|
+
assert expert_map == None, "expert_map is not supported"
|
84
|
+
assert w1_scale == None, "w1_scale is not supported"
|
85
|
+
assert w2_scale == None, "w2_scale is not supported"
|
86
|
+
assert a1_scale == None, "a1_scale is not supported"
|
87
|
+
assert a2_scale == None, "a2_scale is not supported"
|
88
|
+
assert block_shape == None, "block_shape is not supported"
|
89
|
+
|
90
|
+
# type check
|
91
|
+
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
|
92
|
+
assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
|
93
|
+
assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
|
94
|
+
|
95
|
+
# Shape check
|
96
|
+
assert hidden_states.ndim == 2, "hidden_states must be 2D"
|
97
|
+
assert (
|
98
|
+
hidden_states.shape[-1] == w1.shape[-2]
|
99
|
+
), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
|
100
|
+
assert (
|
101
|
+
w2.shape[-1] == w1.shape[1]
|
102
|
+
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
|
103
|
+
|
104
|
+
# feature check
|
105
|
+
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
|
106
|
+
|
107
|
+
M, K = hidden_states.shape
|
108
|
+
E, _, N = w1.shape
|
109
|
+
n_expts_act = routing_data.n_expts_act
|
110
|
+
dtype = hidden_states.dtype
|
111
|
+
|
112
|
+
if global_num_experts == -1:
|
113
|
+
global_num_experts = E
|
114
|
+
|
115
|
+
# consistent with default implementation
|
116
|
+
intermediate_cache2 = torch.empty(
|
117
|
+
(M * n_expts_act, N // 2), device="cuda", dtype=dtype
|
118
|
+
)
|
119
|
+
|
120
|
+
intermediate_cache1 = matmul_ogs(
|
121
|
+
hidden_states,
|
122
|
+
w1,
|
123
|
+
None,
|
124
|
+
routing_data,
|
125
|
+
gather_indx=gather_indx,
|
126
|
+
gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
|
127
|
+
)
|
128
|
+
|
129
|
+
if activation == "silu":
|
130
|
+
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
131
|
+
elif activation == "gelu":
|
132
|
+
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
133
|
+
else:
|
134
|
+
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
135
|
+
|
136
|
+
intermediate_cache3 = matmul_ogs(
|
137
|
+
intermediate_cache2,
|
138
|
+
w2,
|
139
|
+
None,
|
140
|
+
routing_data,
|
141
|
+
scatter_indx=scatter_indx,
|
142
|
+
gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
|
143
|
+
)
|
144
|
+
|
145
|
+
return intermediate_cache3
|
146
|
+
|
147
|
+
|
148
|
+
def triton_kernel_moe_forward_fake(
|
149
|
+
hidden_states: torch.Tensor,
|
150
|
+
w1: torch.Tensor,
|
151
|
+
w2: torch.Tensor,
|
152
|
+
gating_output: torch.Tensor,
|
153
|
+
topk: int,
|
154
|
+
renormalize: bool,
|
155
|
+
inplace: bool = False,
|
156
|
+
activation: str = "silu",
|
157
|
+
apply_router_weight_on_input: bool = False,
|
158
|
+
use_fp8_w8a8: bool = False,
|
159
|
+
per_channel_quant: bool = False,
|
160
|
+
global_num_experts: int = -1,
|
161
|
+
expert_map: Optional[torch.Tensor] = None,
|
162
|
+
w1_scale: Optional[torch.Tensor] = None,
|
163
|
+
w2_scale: Optional[torch.Tensor] = None,
|
164
|
+
a1_scale: Optional[torch.Tensor] = None,
|
165
|
+
a2_scale: Optional[torch.Tensor] = None,
|
166
|
+
block_shape: Optional[list[int]] = None,
|
167
|
+
) -> torch.Tensor:
|
168
|
+
return torch.empty_like(hidden_states)
|
169
|
+
|
170
|
+
|
171
|
+
direct_register_custom_op(
|
172
|
+
op_name="forward_cuda_triton",
|
173
|
+
op_func=triton_kernel_moe_forward,
|
174
|
+
mutates_args=[],
|
175
|
+
fake_impl=triton_kernel_moe_forward_fake,
|
176
|
+
)
|
sglang/srt/layers/moe/topk.py
CHANGED
@@ -83,13 +83,18 @@ def fused_topk_cpu(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
-    return torch.ops.sgl_kernel.topk_softmax_cpu(
+    topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
         gating_output=gating_output,
         topk=topk,
         renormalize=renormalize,
     )
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids
 
 
 def fused_topk(
@@ -303,7 +308,7 @@ def biased_grouped_topk_gpu(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    compiled: bool = True,
+    compiled: bool = not _is_npu,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -411,6 +416,7 @@ if _is_cpu and _is_cpu_amx_available:
     biased_grouped_topk = biased_grouped_topk_cpu
     grouped_topk = grouped_topk_cpu
     fused_topk_native = fused_topk_cpu
+    fused_topk = fused_topk_cpu
 else:
     biased_grouped_topk = biased_grouped_topk_gpu
     grouped_topk = grouped_topk_gpu
sglang/srt/layers/parameter.py
CHANGED
@@ -187,10 +187,26 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, shard_id * shard_size, shard_size
+
+        if _is_cpu:
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
+            )
+
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_id * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
             )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, shard_id * shard_size, shard_size
+                )
 
         assert (
             param_data.shape == loaded_weight.shape
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
@@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
+    "w4afp8": W4AFp8Config,
 }
 
 # VLLM-dependent quantization methods
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -88,7 +88,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if _is_hip:
+if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
@@ -200,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -286,7 +286,10 @@ class Fp8LinearMethod(LinearMethodBase):
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
             if self.block_quant:
-                assert self.quant_config.activation_scheme == "dynamic"
+                if hasattr(self.quant_config, "activation_scheme"):
+                    assert self.quant_config.activation_scheme == "dynamic"
+                elif hasattr(self.quant_config, "linear_activation_scheme"):
+                    assert self.quant_config.linear_activation_scheme == "dynamic"
                 scale = BlockQuantScaleParameter(
                     data=torch.empty(
                         (output_size_per_partition + block_n - 1) // block_n,
@@ -308,7 +311,13 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 scale = PerTensorScaleParameter(
                     data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                     weight_loader=weight_loader,
@@ -371,7 +380,13 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight_scale = torch.nn.Parameter(
                 layer.weight_scale.data, requires_grad=False
             )
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
@@ -405,7 +420,13 @@ class Fp8LinearMethod(LinearMethodBase):
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            if self.quant_config.activation_scheme == "static":
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 layer.input_scale = Parameter(
                     layer.input_scale.max(), requires_grad=False
                 )
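The condition repeated three times in fp8.py probes two attribute names because Fp8Config exposes activation_scheme while the new W4AFp8 config presumably names its linear-layer scheme linear_activation_scheme. A small sketch of an equivalent check; the helper name and the SimpleNamespace stand-ins are made up for illustration:

```python
from types import SimpleNamespace


def is_static_activation(quant_config) -> bool:
    # Equivalent to the inline hasattr(...) and ... == "static" checks above:
    # true if either attribute exists and is set to "static".
    return (
        getattr(quant_config, "activation_scheme", None) == "static"
        or getattr(quant_config, "linear_activation_scheme", None) == "static"
    )


print(is_static_activation(SimpleNamespace(activation_scheme="dynamic")))        # False
print(is_static_activation(SimpleNamespace(linear_activation_scheme="static")))  # True
```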
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -160,8 +160,8 @@ def _per_token_group_quant_fp8_colmajor(
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
-    y_ptr += g_id * group_size
-    y_q_ptr += g_id * group_size
+    y_ptr += g_id.to(tl.int64) * group_size
+    y_q_ptr += g_id.to(tl.int64) * group_size
 
     # Convert g_id the flattened block coordinate to 2D so we can index
     # into the output y_scales matrix
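The fp8_kernel.py change is a 64-bit indexing fix: inside the Triton kernel the per-group element offset g_id * group_size is derived from a 32-bit program id, so for very large inputs the product can exceed the int32 range and wrap unless it is widened first. Rough arithmetic on where that would happen, assuming a typical group_size of 128 (an assumption, not a value taken from the diff):

```python
# Back-of-the-envelope only; group_size = 128 is an assumed typical value.
group_size = 128
int32_max = 2**31 - 1

# Smallest group id whose element offset no longer fits in a signed 32-bit int.
first_bad_group = int32_max // group_size + 1
print(first_bad_group)                   # 16777216
print(first_bad_group * group_size)      # 2147483648 elements into y_ptr
```

Casting g_id to tl.int64 before the multiplication keeps the pointer arithmetic correct well past that point.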