sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
CHANGED
@@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import (
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -40,7 +48,7 @@ if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
 
 
-def fused_topk_native(
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
     return topk_weights, topk_ids
 
 
+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(
 
 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def grouped_topk(
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
     return topk_weights, topk_ids
 
 
+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess(
     return topk_ids
 
 
-def biased_grouped_topk(
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -322,6 +370,45 @@ def biased_grouped_topk(
     )
 
 
+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
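The new `*_cpu` variants above are bound to the public names once at import time, so call sites keep using `grouped_topk`, `biased_grouped_topk`, and `fused_topk_native` unchanged. A minimal sketch of that import-time dispatch pattern; the `HAS_FUSED_KERNEL` flag and `softmax_topk*` names below are illustrative, not sglang APIs:

import torch

def _softmax_topk_reference(gating: torch.Tensor, k: int):
    # Plain PyTorch path: softmax over experts, then per-token top-k.
    probs = torch.softmax(gating, dim=-1)
    return torch.topk(probs, k, dim=-1)

def _softmax_topk_fused(gating: torch.Tensor, k: int):
    # Stand-in for a fused custom op; reuses the reference path so the
    # sketch stays self-contained and runnable.
    return _softmax_topk_reference(gating, k)

# Bind the public name once, the way topk.py now selects between its
# _cpu and _gpu implementations based on a capability check.
HAS_FUSED_KERNEL = False  # assumption: normally set by a capability probe
softmax_topk = _softmax_topk_fused if HAS_FUSED_KERNEL else _softmax_topk_reference

weights, ids = softmax_topk(torch.randn(4, 8), k=2)  # both shaped (4, 2)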
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
CHANGED
@@ -14,14 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
+    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
 
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -64,9 +64,12 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     get_bool_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
@@ -74,6 +77,9 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 _is_fp8_fnuz = is_fp8_fnuz()
 
@@ -82,10 +88,11 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_hip:
     from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant
 
 
@@ -1045,15 +1052,15 @@ class Fp8MoEMethod:
         if _use_hip_int4:
             # TODO: add triton kernel and add check _use_aiter
             assert not no_combine, f"{no_combine=} is not supported."
-            return
+            return fused_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
-                QuantType.per_Token,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
+                quant_type=QuantType.per_Token,
+                w1_scale=layer.w13_weight_scale1,
+                w2_scale=layer.w2_weight_scale1,
                 activation=(
                     ActivationType.Silu if activation == "silu" else ActivationType.Gelu
                 ),
@@ -1062,31 +1069,32 @@ class Fp8MoEMethod:
         if _use_aiter:
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
-
-                assert (
-                    activation == "silu"
-                ), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
-                return asm_moe(
+                return fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-
+                    w1_scale=layer.w13_weight_scale_inv,
+                    w2_scale=layer.w2_weight_scale_inv,
+                    quant_type=QuantType.per_128x128,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                     expert_mask=None,
                 )
             else:
-                return
+                return fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
+                    quant_type=QuantType.per_Token,
+                    w1_scale=layer.w13_weight_scale1,
+                    w2_scale=layer.w2_weight_scale1,
                     activation=(
                         ActivationType.Silu
                         if activation == "silu"
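These hunks move the ROCm path from `asm_moe` to `aiter.fused_moe` and pass the quantization type and weight scales by keyword rather than by position. A small illustrative sketch of why naming such arguments makes these call sites safer; `fused_moe_like` below is a hypothetical wrapper, not the aiter API:

from typing import Optional

import torch

def fused_moe_like(
    x: torch.Tensor,
    w13: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    *,
    quant_type: str = "per_token",   # keyword-only: callers must name it
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    activation: str = "silu",
) -> torch.Tensor:
    # Hypothetical stand-in: only checks that the named arguments arrived.
    assert quant_type in ("per_token", "per_128x128")
    assert w1_scale is not None and w2_scale is not None
    return x  # placeholder output

# Scales passed by name cannot silently bind to the wrong positional slot
# if the upstream signature later gains new parameters.
out = fused_moe_like(
    torch.randn(2, 16),
    torch.randn(4, 32, 16),
    torch.randn(4, 16, 16),
    torch.rand(2, 2),
    torch.zeros(2, 2, dtype=torch.int64),
    quant_type="per_token",
    w1_scale=torch.ones(4),
    w2_scale=torch.ones(4),
)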
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
@@ -29,11 +29,17 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, next_power_of_2
 
 if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
+try:
+    from flashinfer import fp4_quantize as fp4_quantize
+    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
+except ImportError:
+    flashinfer_cutlass_fused_moe = None
+
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
@@ -429,6 +435,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         layer.alpha = Parameter(
             layer.input_scale * layer.weight_scale_2, requires_grad=False
         )
+        layer.input_scale_inv = Parameter(
+            (1 / input_scale_2).to(torch.float32), requires_grad=False
+        )
 
         # Pad and blockwise interleave weight_scale
         scales = layer.weight_scale
@@ -467,7 +476,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         output_shape = [x_m, w_n]
 
         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x,
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)
 
         assert x_fp4.dtype == torch.uint8
         assert x_scale_interleaved.dtype == torch.float8_e4m3fn
@@ -521,6 +530,7 @@ class ModelOptNvFp4FusedMoEMethod:
                 " quantization. Please use Blackwell and"
                 " above."
             )
+        self.enable_flashinfer_moe = False
 
     def create_weights(
         self,
@@ -674,7 +684,10 @@ class ModelOptNvFp4FusedMoEMethod:
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-
+        if self.enable_flashinfer_moe:
+            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
@@ -700,14 +713,19 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # GEMM 2
+        if self.enable_flashinfer_moe:
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        else:
+            w2_input_scale = layer.w2_input_scale
+
         layer.g2_alphas = Parameter(
-            (
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
 
         # This is for quantization, so we need to invert it.
         layer.w2_input_scale_quant = Parameter(
-            (1 /
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )
 
         assert (
@@ -727,11 +745,16 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.cutlass_moe_params = CutlassMoEParams(
             CutlassMoEType.BlockscaledFP4,
             device,
-            num_experts=layer.num_experts,
+            num_experts=layer.num_experts,  # global num experts
             intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
             hidden_size=layer.w13_weight.shape[2] * 2,
         )  # k
 
+    @property
+    def load_up_proj_weight_first(self) -> bool:
+        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
+        return self.enable_flashinfer_moe
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -750,11 +773,13 @@ class ModelOptNvFp4FusedMoEMethod:
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        ep_rank: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> torch.Tensor:
 
         assert activation == "silu", "Only SiLU activation is supported."
-
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
 
         topk_weights, topk_ids = select_experts(
@@ -771,6 +796,35 @@ class ModelOptNvFp4FusedMoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if self.enable_flashinfer_moe:
+            assert (
+                not apply_router_weight_on_input
+            ), "apply_router_weight_on_input is not supported for Flashinfer"
+            # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
+            # and fp4 quantized weights loaded from the checkpoint
+            output = flashinfer_cutlass_fused_moe(
+                x,
+                topk_ids.to(torch.int),
+                topk_weights,
+                layer.w13_weight.view(torch.long),
+                layer.w2_weight.view(torch.long),
+                x.dtype,
+                quant_scales=[
+                    layer.w13_input_scale_quant,
+                    layer.w13_blockscale_swizzled.view(torch.int32),
+                    layer.g1_alphas,
+                    layer.w2_input_scale_quant,
+                    layer.w2_blockscale_swizzled.view(torch.int32),
+                    layer.g2_alphas,
+                ],
+                ep_size=ep_size,
+                ep_rank=ep_rank,
+                tp_size=tp_size,
+                tp_rank=tp_rank,
+                tune_max_num_tokens=next_power_of_2(x.shape[0]),
+            )
+            return output[0]
+
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
         return cutlass_moe_fp4(
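modelopt_quant.py now treats FlashInfer as an optional dependency: the import is wrapped in try/except, the module-level name is set to None on failure, and the fast path is keyed off an `enable_flashinfer_moe` flag. A minimal sketch of that optional-dependency pattern, with a hypothetical `fast_gemm` backend standing in for the real kernel:

import torch

try:
    # Optional accelerator; the module must stay importable without it.
    from fictional_backend import fast_gemm  # hypothetical dependency
except ImportError:
    fast_gemm = None


class QuantLinearMethod:
    def __init__(self, enable_fast_path: bool = False):
        # Only honor the flag when the optional import actually succeeded.
        self.enable_fast_path = enable_fast_path and fast_gemm is not None

    def apply(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        if self.enable_fast_path:
            return fast_gemm(x, w)
        # Reference path used whenever the optional backend is unavailable.
        return x @ w.t()


out = QuantLinearMethod().apply(torch.randn(2, 8), torch.randn(4, 8))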
sglang/srt/layers/quantization/utils.py
CHANGED
@@ -6,11 +6,14 @@ from typing import List, Mapping, Tuple, Union
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    from vllm._custom_ops import scaled_fp8_quant
 
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -8,10 +8,13 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -84,7 +87,9 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)
 
-        if
+        if (
+            not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
+        ) and not (_is_cpu and _is_cpu_amx_available):
             from vllm._custom_ops import rotary_embedding
 
             self.vllm_rotary_embedding = rotary_embedding
@@ -147,6 +152,26 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
     def forward_cuda(
         self,
         positions: torch.Tensor,
@@ -696,6 +721,21 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key = key_rot
         return query.to(dtype), key.to(dtype)
 
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions, query, key, self.head_size, self.cos_sin_cache, False
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
 
 class Llama3RotaryEmbedding(RotaryEmbedding):
 
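Both new `forward_cpu` methods follow the same shape: call the fused `torch.ops.sgl_kernel.rotary_embedding_cpu` op when AMX support was detected, otherwise reuse `forward_native`. A small sketch of that per-backend-method-with-fallback idea, using a made-up capability flag and op name rather than any sglang or sgl_kernel API:

import torch
import torch.nn as nn

_HAS_FUSED_CPU_KERNEL = False  # assumption: would come from a capability probe


class PositionalOp(nn.Module):
    """Illustrative op with a native reference path and an optional fused CPU path."""

    def forward_native(self, q: torch.Tensor, k: torch.Tensor):
        # Reference implementation in plain PyTorch; always available.
        return q + 1.0, k + 1.0

    def forward_cpu(self, q: torch.Tensor, k: torch.Tensor):
        if _HAS_FUSED_CPU_KERNEL:
            # A real backend would call a registered custom op here, e.g.
            # torch.ops.my_kernels.fused_positional_cpu(q, k)  # hypothetical
            raise NotImplementedError("fused path not wired up in this sketch")
        # No fused kernel detected: fall back to the native path.
        return self.forward_native(q, k)


q_out, k_out = PositionalOp().forward_cpu(torch.zeros(2, 4), torch.zeros(2, 4))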