sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +0 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +62 -6
- sglang/srt/disaggregation/mini_lb.py +5 -1
- sglang/srt/disaggregation/mooncake/conn.py +32 -62
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/prefill.py +40 -4
- sglang/srt/disaggregation/utils.py +15 -0
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +114 -71
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -57
- sglang/srt/layers/quantization/fp8_utils.py +187 -262
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +3 -2
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +2 -4
- sglang/srt/managers/scheduler.py +12 -71
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +7 -2
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +20 -27
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +289 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +29 -201
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +34 -32
- sglang/srt/speculative/eagle_worker.py +4 -7
- sglang/srt/utils.py +16 -1
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/layers/quantization/fp8.py:

@@ -8,15 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
-from sglang.srt.layers.quantization.utils import (
-    all_close_1d,
-    convert_to_channelwise,
-    is_layer_skipped,
-    per_tensor_dequantize,
-    requantize_with_max_scale,
-)
-
 try:
     from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
         apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
 except ImportError:
     MARLIN_FP8_AVAILABLE = False
 
-    def
-        raise ImportError(
+    def dummy_func(*args, **kwargs):
+        raise ImportError(
+            "marlin FP8 requires some operators from vllm. Please install vllm."
+        )
 
-
-        raise ImportError("vllm is not installed")
+    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
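The hunk above replaces a bare module-level `raise ImportError(...)` in the fallback path with a `dummy_func` binding, so the module now imports cleanly without vllm and only fails when a marlin FP8 code path is actually invoked. A minimal sketch of the pattern (illustrative, not the package source; the names mirror the hunk above):

```python
MARLIN_FP8_AVAILABLE = True
try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    # Deferred failure: importing this module no longer requires vllm, but
    # calling either marlin helper without vllm installed raises immediately.
    def dummy_func(*args, **kwargs):
        raise ImportError(
            "marlin FP8 requires some operators from vllm. Please install vllm."
        )

    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
```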
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    scaled_fp8_quant,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    convert_to_channelwise,
+    is_layer_skipped,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
     is_hip,
-    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
-ACTIVATION_SCHEMES = ["static", "dynamic"]
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
 if _is_hip:
     from aiter import ActivationType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
-
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
-
-
-else:
-    from vllm import _custom_ops as vllm_ops
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = logging.getLogger(__name__)
 
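After this hunk, `scaled_fp8_quant` and `per_token_group_quant_fp8` come from sglang's own `fp8_kernel` module, with `vllm._custom_ops.scaled_fp8_quant` imported only as a fallback on non-CUDA builds; the `permute_weight` import is dropped. For orientation, the per-tensor operation `scaled_fp8_quant` provides can be sketched in plain PyTorch (illustrative only, not the actual kernel):

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def naive_scaled_fp8_quant(x: torch.Tensor):
    """Toy per-tensor FP8 quantization: returns (fp8 tensor, float32 scale)."""
    scale = x.abs().amax().clamp(min=1e-12) / FP8_E4M3_MAX
    x_q = (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return x_q, scale

w = torch.randn(128, 256)
w_fp8, w_scale = naive_scaled_fp8_quant(w)
w_dequant = w_fp8.to(torch.float32) * w_scale  # approximate reconstruction
```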
@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
         layer.logical_widths = output_partition_sizes
-
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.weight_scale_inv.data, requires_grad=False
             )
             return
+
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             if self.cutlass_fp8_supported or self.use_marlin:
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
             )
 
         if self.use_marlin:
-
-
-
-                del layer.input_scale
-            except ImportError:
-                self.use_marlin = False
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
 
     def apply(
         self,
@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if self.use_marlin:
-
-
-
-
-
-
-
-
-
-                )
-            except ImportError:
-                self.use_marlin = False
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
 
 
         if self.block_quant:
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
         )
 
         # WEIGHTS
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
            self.process_weights_hip_int4(layer)
            return
 
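The three hunks above tighten the `USE_INT4_WEIGHT` switch: the environment flag now only takes effect on ROCm (`_is_hip`) instead of being honored unconditionally. A hedged sketch of this kind of combined platform/env-var gating (the helper below only approximates `sglang.srt.utils.get_bool_env_var`):

```python
import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Approximation of sglang.srt.utils.get_bool_env_var, for illustration only.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

_is_hip = False  # stand-in for sglang.srt.utils.is_hip()

if _is_hip and get_bool_env_var_sketch("USE_INT4_WEIGHT"):
    print("ROCm INT4-weight MoE path")
else:
    print("default FP8 MoE path")
```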
@@ -706,20 +701,12 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             for expert in range(layer.num_experts):
-
-                    w13_weight[expert, :, :]
-
-
-                    w2_weight[expert, :, :]
-
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
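With the unified `scaled_fp8_quant` import, the per-expert requantization loop above no longer branches between the sglang and `vllm_ops` kernels. A self-contained sketch of what the loop computes, with the kernel replaced by inline PyTorch math (illustrative only):

```python
import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max
num_experts, n, k = 4, 8, 16
w13 = torch.randn(num_experts, n, k)

w13_q = torch.empty(num_experts, n, k, dtype=torch.float8_e4m3fn)
w13_scale = torch.empty(num_experts)
for expert in range(num_experts):
    # One scale per expert, as in the hunk above.
    w13_scale[expert] = w13[expert].abs().amax() / fp8_max
    w13_q[expert] = (w13[expert] / w13_scale[expert]).to(torch.float8_e4m3fn)
```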
@@ -796,18 +783,10 @@ class Fp8MoEMethod:
                         layer.w13_weight[expert_id][start : start + shard_size, :],
                         layer.w13_weight_scale[expert_id][shard_id],
                     )
-
-
-
-
-                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    else:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = vllm_ops.scaled_fp8_quant(
-                            dq_weight, max_w13_scales[expert_id]
-                        )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                 start += shard_size
 
             layer.w13_weight_scale = torch.nn.Parameter(
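This hunk keeps only the sglang `scaled_fp8_quant` call when per-shard scales are folded onto one maximum scale per expert (`max_w13_scales`). A worked PyTorch sketch of that rescaling step (illustrative; `per_tensor_dequantize` and the real kernel are replaced by inline math):

```python
import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max

# Two shards (e.g. gate and up projections) of one expert, quantized separately.
shards = torch.randn(2, 8, 16)
shard_scales = shards.abs().amax(dim=(1, 2)) / fp8_max
max_scale = shard_scales.max()

requantized = torch.empty_like(shards, dtype=torch.float8_e4m3fn)
for i in range(shards.shape[0]):
    q = (shards[i] / shard_scales[i]).to(torch.float8_e4m3fn)  # original shard quantization
    dq_weight = q.to(torch.float32) * shard_scales[i]          # dequantize the shard
    # Requantize with the shared max scale, as the loop above does with scaled_fp8_quant.
    requantized[i] = (dq_weight / max_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
```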
@@ -913,6 +892,7 @@ class Fp8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
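`routed_scaling_factor` is a new optional argument on the FP8 MoE `apply()` signature, forwarded to `select_experts` in the next hunk. A hedged sketch of the usual meaning of such a factor in DeepSeek-style routing, where the selected expert weights are multiplied by a model-specific constant (the real logic lives in `sglang/srt/layers/moe/topk.py`, which this release also touches):

```python
import torch
from typing import Optional

def scale_topk_weights(
    topk_weights: torch.Tensor, routed_scaling_factor: Optional[float] = None
) -> torch.Tensor:
    # Assumed semantics, for illustration: scale routing weights when a factor is given.
    if routed_scaling_factor is not None:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights

router_probs = torch.softmax(torch.randn(2, 8), dim=-1)
topk = router_probs.topk(2, dim=-1)
print(scale_topk_weights(topk.values, routed_scaling_factor=2.5))
```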
@@ -928,43 +908,14 @@ class Fp8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
-        if _is_hip
-
-
-
-
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        if _is_hip and get_bool_env_var("CK_MOE"):
-            assert not no_combine, f"{no_combine=} is not supported."
-            if self.block_quant:
-                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
-                assert (
-                    activation == "silu"
-                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
-                return asm_moe(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
-                    expert_mask=None,
-                )
-            else:
-                return ck_moe_2stages(
+        if _is_hip:
+            if get_bool_env_var("USE_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+                assert not no_combine, f"{no_combine=} is not supported."
+                return ck_moe_2stages_win4(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -978,33 +929,65 @@ class Fp8MoEMethod:
                         else ActivationType.Gelu
                     ),
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            if get_bool_env_var("CK_MOE"):
+                assert not no_combine, f"{no_combine=} is not supported."
+                if self.block_quant:
+                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    assert (
+                        activation == "silu"
+                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    return asm_moe(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale_inv,
+                        layer.w2_weight_scale_inv,
+                        block_shape=tuple(self.quant_config.weight_block_size),
+                        expert_mask=None,
+                    )
+                else:
+                    return ck_moe_2stages(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale1,
+                        layer.w2_weight_scale1,
+                        activation=(
+                            ActivationType.Silu
+                            if activation == "silu"
+                            else ActivationType.Gelu
+                        ),
+                    )
+
+        # Expert fusion with FP8 quantization
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace and not no_combine,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_fp8_w8a8=True,
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            ),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
+            no_combine=no_combine,
+        )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
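Taken together, the last two hunks restructure the ROCm branch of `Fp8MoEMethod.apply()` into a single `if _is_hip:` block: INT4-packed weights go to `ck_moe_2stages_win4`, `CK_MOE` selects `asm_moe` (block quant) or `ck_moe_2stages`, and everything else falls through to the Triton `fused_experts` path. A condensed, runnable sketch of that dispatch (stubs stand in for the actual kernels):

```python
import os

def get_bool_env_var(name: str) -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, "false").lower() in ("1", "true", "yes")

def moe_apply_dispatch(is_hip: bool, block_quant: bool) -> str:
    if is_hip:
        if get_bool_env_var("USE_INT4_WEIGHT"):
            return "ck_moe_2stages_win4"   # INT4-packed MoE weights (aiter)
        if get_bool_env_var("CK_MOE"):
            return "asm_moe" if block_quant else "ck_moe_2stages"
    return "fused_experts"                 # default Triton FP8 expert fusion

print(moe_apply_dispatch(is_hip=True, block_quant=True))
```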