sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -64,9 +64,12 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     get_bool_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
@@ -74,6 +77,9 @@ from sglang.srt.utils import (

 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 _is_fp8_fnuz = is_fp8_fnuz()

@@ -82,10 +88,11 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_hip:
     from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant

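For orientation, the guard above keeps the vLLM fallback import only for platforms without a native kernel path (CUDA, NPU, and AMX-capable CPUs now skip it). A minimal sketch of the same dispatch pattern, with the platform probes stubbed as plain booleans (the real values come from sglang.srt.utils at import time):

# Sketch only: the probe values below are placeholders for is_cuda(),
# is_npu(), is_cpu(), and cpu_has_amx_support() from sglang.srt.utils.
_is_cuda = False
_is_npu = False
_is_cpu = True
_is_cpu_amx_available = True

if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    # Only platforms without a native sgl-kernel path pull in the vLLM op.
    from vllm._custom_ops import scaled_fp8_quant  # noqa: F401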
@@ -1045,15 +1052,15 @@ class Fp8MoEMethod:
         if _use_hip_int4:
             # TODO: add triton kernel and add check _use_aiter
             assert not no_combine, f"{no_combine=} is not supported."
-            return
+            return fused_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
-                QuantType.per_Token,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
+                quant_type=QuantType.per_Token,
+                w1_scale=layer.w13_weight_scale1,
+                w2_scale=layer.w2_weight_scale1,
                 activation=(
                     ActivationType.Silu if activation == "silu" else ActivationType.Gelu
                 ),
@@ -1062,31 +1069,32 @@
         if _use_aiter:
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
-
-                assert (
-                    activation == "silu"
-                ), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
-                return asm_moe(
+                return fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-
+                    w1_scale=layer.w13_weight_scale_inv,
+                    w2_scale=layer.w2_weight_scale_inv,
+                    quant_type=QuantType.per_128x128,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                     expert_mask=None,
                 )
             else:
-                return
+                return fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
+                    quant_type=QuantType.per_Token,
+                    w1_scale=layer.w13_weight_scale1,
+                    w2_scale=layer.w2_weight_scale1,
                     activation=(
                         ActivationType.Silu
                         if activation == "silu"
sglang/srt/layers/quantization/fp8_utils.py CHANGED
@@ -42,7 +42,10 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
-
+    import aiter
+    from aiter import gemm_a8w8_blockscale_CK, get_hip_quant
+
+    aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)

 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
@@ -271,9 +274,7 @@ def aiter_w8a8_block_fp8_linear(
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]

-    q_input, x_scale =
-        input_2d, block_size[1], column_major_scales=False
-    )
+    q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8)
     output = gemm_a8w8_blockscale_CK(
         q_input, weight, x_scale, weight_scale, dtype=input.dtype
     )
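The rewritten aiter_w8a8_block_fp8_linear path quantizes activations per (1 x 128) group via aiter's get_hip_quant kernel before the block-scaled GEMM. That kernel is ROCm-only, so the snippet below is only a rough, plain-PyTorch reference for the data layout such a quantization produces (one scale per 128-column group per row); the scale convention is an assumption, not aiter's exact behavior:

import torch

def per_1x128_quant_reference(x: torch.Tensor, group: int = 128):
    # Toy reference: FP8 values plus one float32 scale per (row, 128-column group).
    m, k = x.shape
    assert k % group == 0
    xg = x.float().view(m, k // group, group)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = xg.abs().amax(dim=-1).clamp(min=1e-12) / fp8_max  # (m, k // group)
    q = (xg / scale.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return q.view(m, k), scale

x = torch.randn(4, 256)
q, s = per_1x128_quant_reference(x)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])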
sglang/srt/layers/quantization/modelopt_quant.py CHANGED
@@ -29,11 +29,17 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, next_power_of_2

 if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

+try:
+    from flashinfer import fp4_quantize as fp4_quantize
+    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
+except ImportError:
+    flashinfer_cutlass_fused_moe = None
+
 # Initialize logger for the module
 logger = logging.getLogger(__name__)

@@ -429,6 +435,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         layer.alpha = Parameter(
             layer.input_scale * layer.weight_scale_2, requires_grad=False
         )
+        layer.input_scale_inv = Parameter(
+            (1 / input_scale_2).to(torch.float32), requires_grad=False
+        )

         # Pad and blockwise interleave weight_scale
         scales = layer.weight_scale
@@ -467,7 +476,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         output_shape = [x_m, w_n]

         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x,
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)

         assert x_fp4.dtype == torch.uint8
         assert x_scale_interleaved.dtype == torch.float8_e4m3fn
@@ -521,6 +530,7 @@ class ModelOptNvFp4FusedMoEMethod:
                 " quantization. Please use Blackwell and"
                 " above."
             )
+        self.enable_flashinfer_moe = False

     def create_weights(
         self,
@@ -674,7 +684,10 @@ class ModelOptNvFp4FusedMoEMethod:
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

-
+        if self.enable_flashinfer_moe:
+            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
@@ -700,14 +713,19 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

         # GEMM 2
+        if self.enable_flashinfer_moe:
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        else:
+            w2_input_scale = layer.w2_input_scale
+
         layer.g2_alphas = Parameter(
-            (
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )

         # This is for quantization, so we need to invert it.
         layer.w2_input_scale_quant = Parameter(
-            (1 /
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )

         assert (
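The two hunks above fold the input and weight global scales into per-expert dequant alphas and keep a reciprocal input scale for quantizing activations; the FlashInfer branch first collapses the per-expert input scales into one global maximum. A small standalone sketch of that bookkeeping with made-up scale tensors (names mirror the diff):

import torch

num_experts = 8
w2_input_scale = torch.rand(num_experts) + 0.5      # per-expert input scales (made up)
w2_weight_scale_2 = torch.rand(num_experts) + 0.5   # per-expert global weight scales

enable_flashinfer_moe = True
if enable_flashinfer_moe:
    # FlashInfer CUTLASS MoE expects a single global input scale.
    w2_input_scale = w2_input_scale.max().to(torch.float32)

g2_alphas = (w2_input_scale * w2_weight_scale_2).to(torch.float32)  # dequant alpha per expert
w2_input_scale_quant = (1 / w2_input_scale).to(torch.float32)       # applied when quantizing activations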
@@ -727,11 +745,16 @@ class ModelOptNvFp4FusedMoEMethod:
         layer.cutlass_moe_params = CutlassMoEParams(
             CutlassMoEType.BlockscaledFP4,
             device,
-            num_experts=layer.num_experts,
+            num_experts=layer.num_experts,  # global num experts
             intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
             hidden_size=layer.w13_weight.shape[2] * 2,
         )  # k

+    @property
+    def load_up_proj_weight_first(self) -> bool:
+        # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
+        return self.enable_flashinfer_moe
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -750,11 +773,13 @@ class ModelOptNvFp4FusedMoEMethod:
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        ep_rank: Optional[int] = None,
+        ep_size: Optional[int] = None,
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> torch.Tensor:

         assert activation == "silu", "Only SiLU activation is supported."
-
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts

         topk_weights, topk_ids = select_experts(
|
|
771
796
|
routed_scaling_factor=routed_scaling_factor,
|
772
797
|
)
|
773
798
|
|
799
|
+
if self.enable_flashinfer_moe:
|
800
|
+
assert (
|
801
|
+
not apply_router_weight_on_input
|
802
|
+
), "apply_router_weight_on_input is not supported for Flashinfer"
|
803
|
+
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
|
804
|
+
# and fp4 quantized weights loaded from the checkpoint
|
805
|
+
output = flashinfer_cutlass_fused_moe(
|
806
|
+
x,
|
807
|
+
topk_ids.to(torch.int),
|
808
|
+
topk_weights,
|
809
|
+
layer.w13_weight.view(torch.long),
|
810
|
+
layer.w2_weight.view(torch.long),
|
811
|
+
x.dtype,
|
812
|
+
quant_scales=[
|
813
|
+
layer.w13_input_scale_quant,
|
814
|
+
layer.w13_blockscale_swizzled.view(torch.int32),
|
815
|
+
layer.g1_alphas,
|
816
|
+
layer.w2_input_scale_quant,
|
817
|
+
layer.w2_blockscale_swizzled.view(torch.int32),
|
818
|
+
layer.g2_alphas,
|
819
|
+
],
|
820
|
+
ep_size=ep_size,
|
821
|
+
ep_rank=ep_rank,
|
822
|
+
tp_size=tp_size,
|
823
|
+
tp_rank=tp_rank,
|
824
|
+
tune_max_num_tokens=next_power_of_2(x.shape[0]),
|
825
|
+
)
|
826
|
+
return output[0]
|
827
|
+
|
774
828
|
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
775
829
|
|
776
830
|
return cutlass_moe_fp4(
|
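The FlashInfer call above rounds the token count up with next_power_of_2(x.shape[0]) for tune_max_num_tokens. For reference, a next-power-of-two helper is assumed to behave like this (a sketch, not necessarily sglang.srt.utils.next_power_of_2 verbatim):

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, for n >= 1.
    return 1 << (n - 1).bit_length()

assert [next_power_of_2(n) for n in (1, 2, 3, 5, 8, 100)] == [1, 2, 4, 8, 8, 128]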
sglang/srt/layers/quantization/utils.py CHANGED
@@ -6,11 +6,14 @@ from typing import List, Mapping, Tuple, Union
 import torch

 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant

sglang/srt/layers/rotary_embedding.py CHANGED
@@ -8,13 +8,29 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+if _use_aiter:
+    from aiter.rotary_embedding import get_rope as aiter_get_rope
+
+if is_npu():
+    import torch_npu


 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -84,7 +100,9 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)

-        if
+        if (
+            not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
+        ) and not (_is_cpu and _is_cpu_amx_available):
             from vllm._custom_ops import rotary_embedding

             self.vllm_rotary_embedding = rotary_embedding
@@ -147,6 +165,56 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """A PyTorch-npu implementation of forward()."""
+        import os
+
+        if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
+            return self.forward_native(positions, query, key, offsets)
+        else:
+            rotary_mode = "half"
+            if self.is_neox_style:
+                rotary_mode = "half"
+            else:
+                rotary_mode = "interleave"
+            mrope_section = [0, 0, 0]
+            query_out, key_out = torch_npu.npu_mrope(
+                positions,
+                query,
+                key,
+                self.cos_sin_cache,
+                self.head_size,
+                mrope_section=mrope_section,
+                rotary_mode=rotary_mode,
+            )
+            return query_out, key_out
+
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
     def forward_cuda(
         self,
         positions: torch.Tensor,
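The new forward_npu and forward_cpu overrides dispatch to platform kernels (torch_npu.npu_mrope, torch.ops.sgl_kernel.rotary_embedding_cpu) when available and otherwise fall back to forward_native. As a reminder of what these kernels compute, here is a minimal neox-style rotary application in plain PyTorch; this is a reference sketch, not the sgl-kernel implementation:

import torch

def apply_neox_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., rotary_dim); cos/sin: (..., rotary_dim // 2), gathered at the token positions.
    x1, x2 = x.chunk(2, dim=-1)  # neox style splits halves rather than interleaving
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

rotary_dim = 64
positions = torch.arange(8)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.einsum("i,j->ij", positions.float(), inv_freq)
q = torch.randn(8, rotary_dim)
q_rot = apply_neox_rope(q, freqs.cos(), freqs.sin())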
@@ -696,6 +764,21 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key = key_rot
         return query.to(dtype), key.to(dtype)

+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions, query, key, self.head_size, self.cos_sin_cache, False
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+

 class Llama3RotaryEmbedding(RotaryEmbedding):

@@ -807,6 +890,43 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         return query_out.type_as(query), key_out.type_as(key)


+class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_alpha: float,
+        dtype: torch.dtype,
+    ) -> None:
+        self.scaling_alpha = scaling_alpha
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        max_len = self.max_position_embeddings
+        base = self.base * self.scaling_alpha ** (
+            self.rotary_dim / (self.rotary_dim - 2)
+        )
+
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""

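The new DynamicNTKAlphaRotaryEmbedding differs from the factor-based variant only in how it rescales the RoPE base: base' = base * alpha ** (rotary_dim / (rotary_dim - 2)), while the cached position range stays at max_position_embeddings. A quick standalone check of that formula with illustrative numbers (the inverse-frequency expression mirrors the usual RotaryEmbedding._compute_inv_freq and is assumed here, not copied from the source):

import torch

base, rotary_dim, scaling_alpha = 10000.0, 128, 4.0

scaled_base = base * scaling_alpha ** (rotary_dim / (rotary_dim - 2))
inv_freq = 1.0 / (scaled_base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))

print(round(scaled_base))   # ~40889: a larger effective base
print(inv_freq[-1].item())  # slower-rotating high dimensions than with base=10000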
@@ -1151,15 +1271,26 @@ def get_rope(
         )
     elif scaling_type == "dynamic":
         scaling_factor = rope_scaling["factor"]
-        rotary_emb = DynamicNTKScalingRotaryEmbedding(
-            head_size,
-            rotary_dim,
-            max_position,
-            base,
-            is_neox_style,
-            scaling_factor,
-            dtype,
-        )
+        if "alpha" in rope_scaling:
+            rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                rope_scaling["alpha"],
+                dtype,
+            )
+        else:
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                scaling_factor,
+                dtype,
+            )
     elif scaling_type == "yarn":
         scaling_factor = rope_scaling["factor"]
         original_max_position = rope_scaling["original_max_position_embeddings"]
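In get_rope, the "dynamic" branch now chooses the implementation from the rope_scaling dict: an "alpha" key selects the NTK-alpha variant, otherwise the existing "factor"-based class is used. A tiny sketch of that selection with the embedding classes stubbed as names:

def pick_dynamic_rope_impl(rope_scaling: dict) -> str:
    # Mirrors the new "dynamic" branch above; classes stubbed as strings.
    if "alpha" in rope_scaling:
        return "DynamicNTKAlphaRotaryEmbedding"   # driven by rope_scaling["alpha"]
    return "DynamicNTKScalingRotaryEmbedding"     # driven by rope_scaling["factor"]

assert pick_dynamic_rope_impl({"type": "dynamic", "alpha": 1000.0}) == "DynamicNTKAlphaRotaryEmbedding"
assert pick_dynamic_rope_impl({"type": "dynamic", "factor": 4.0}) == "DynamicNTKScalingRotaryEmbedding"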
@@ -1348,7 +1479,8 @@ def get_rope_wrapper(
     device: Optional[str] = None,
 ):
     if device != "cpu":
-        return get_rope(
+        wrapper = aiter_get_rope if _use_aiter else get_rope
+        return wrapper(
             head_size,
             rotary_dim,
             max_position,
sglang/srt/layers/sampler.py CHANGED
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -20,10 +20,18 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    PackWeightMethod,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)

 DEFAULT_VOCAB_PADDING_SIZE = 64

+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+

 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -549,6 +557,11 @@ class ParallelLMHead(VocabParallelEmbedding):
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
+
+        # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
+        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)