sglang 0.5.0rc1__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/disaggregation/decode.py +0 -1
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/http_server.py +64 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -0
- sglang/srt/entrypoints/openai/serving_chat.py +1 -0
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/layers/attention/flashinfer_backend.py +3 -0
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +10 -3
- sglang/srt/layers/communicator.py +7 -7
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +5 -32
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +52 -30
- sglang/srt/layers/quantization/mxfp4.py +16 -2
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/managers/cache_controller.py +4 -1
- sglang/srt/managers/io_struct.py +14 -0
- sglang/srt/managers/schedule_batch.py +18 -39
- sglang/srt/managers/scheduler.py +3 -4
- sglang/srt/managers/tokenizer_manager.py +28 -18
- sglang/srt/mem_cache/allocator.py +8 -157
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +8 -21
- sglang/srt/model_executor/forward_batch_info.py +8 -10
- sglang/srt/model_executor/model_runner.py +57 -53
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +5 -3
- sglang/srt/models/glm4_moe.py +2 -2
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/gpt_oss.py +7 -2
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -5
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +33 -7
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/two_batch_overlap.py +4 -8
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +5 -5
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +75 -63
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -55,13 +55,7 @@ if is_mxfp_supported:
     from sglang.srt.layers.quantization.fp4 import MxFp4Config

 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import (
-    GPTQConfig,
-    GPTQLinearMethod,
-    GPTQMarlinConfig,
-    GPTQMarlinLinearMethod,
-    GPTQMarlinMoEMethod,
-)
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
@@ -70,7 +64,6 @@ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -86,6 +79,10 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "awq": AWQConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
@@ -111,19 +108,15 @@ elif is_mxfp_supported and is_hip():
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
-    "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "awq_marlin": AWQMarlinConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "gptq": GPTQConfig,
 }

 QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
@@ -145,23 +138,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]


-def gptq_get_quant_method(self, layer, prefix):
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
-
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
-    return None
-
-
 original_isinstance = builtins.isinstance


@@ -239,10 +215,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):

 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)

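With the GPTQ and AWQ entries moved into BASE_QUANTIZATION_METHODS, looking them up no longer routes through the vLLM-dependent table or the removed gptq_get_quant_method monkey patch. A minimal sketch of what this enables (illustrative only; assumes a CUDA build where the native sgl_kernel-backed configs are importable):

    # Illustrative sketch: "gptq"/"awq"/"gptq_marlin"/"awq_marlin" now resolve
    # from the base table, so this lookup does not depend on the vLLM integration.
    from sglang.srt.layers.quantization import get_quantization_config

    gptq_marlin_cls = get_quantization_config("gptq_marlin")  # -> GPTQMarlinConfig
    awq_cls = get_quantization_config("awq")                  # -> AWQConfig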
sglang/srt/layers/quantization/awq.py
CHANGED
@@ -29,29 +29,25 @@ from sglang.srt.layers.quantization.marlin_utils import (
     verify_marlin_supported,
     verify_marlin_supports_shape,
 )
-from sglang.srt.layers.quantization.scalar_type import scalar_types
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.quantization.utils import replace_parameter
+from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-try:
-    from vllm import _custom_ops as ops
-
-    warnings.warn(
-        f"Using kernels directly from vllm. This might lead to performance degradation or "
-        f"missing functionalities as certain kernels may not be optimized. "
-    )
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 if _is_cuda:
-    from sgl_kernel import
+    from sgl_kernel import (
+        awq_dequantize,
+        awq_marlin_moe_repack,
+        awq_marlin_repack,
+        fused_marlin_moe,
+    )
+
+
 elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
@@ -64,6 +60,9 @@ else:
 logger = logging.getLogger(__name__)


+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)

@@ -516,7 +515,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         layer.workspace = marlin_make_workspace(device)

         # Repack weights from AWQ format to marlin format.
-        marlin_qweight =
+        marlin_qweight = awq_marlin_repack(
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
@@ -684,7 +683,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
             requires_grad=False,
         )

-        marlin_w13_qweight =
+        marlin_w13_qweight = awq_marlin_moe_repack(
             layer.w13_qweight,
             layer.w13_g_idx_sort_indices,
             size_k=layer.w13_qweight.shape[1],
@@ -693,7 +692,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

-        marlin_w2_qweight =
+        marlin_w2_qweight = awq_marlin_moe_repack(
             layer.w2_qweight,
             layer.w2_g_idx_sort_indices,
             size_k=layer.w2_qweight.shape[1],
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
CHANGED
@@ -16,7 +16,6 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
-    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
sglang/srt/layers/quantization/gptq.py
CHANGED
@@ -36,9 +36,9 @@ from sglang.srt.layers.quantization.marlin_utils import (
     marlin_zero_points,
     verify_marlin_supported,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
 from sglang.srt.layers.quantization.utils import (
     get_linear_quant_method,
+    get_scalar_types,
     replace_parameter,
     unpack_cols,
 )
@@ -46,20 +46,16 @@ from sglang.srt.layers.quantization.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

-try:
-    from vllm import _custom_ops as ops
-except ImportError:
-    ops = None
-
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

 if _is_cuda:
-    from sgl_kernel import fused_marlin_moe
+    from sgl_kernel import fused_marlin_moe, gptq_gemm, gptq_marlin_repack, gptq_shuffle


 logger = logging.getLogger(__name__)
+ScalarType, scalar_types = get_scalar_types()


 def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
@@ -85,9 +81,7 @@ def gptq_marlin_moe_repack(
         dtype=b_q_weight.dtype,
     )
     for e in range(num_experts):
-        output[e] =
-            b_q_weight[e], perm[e], size_k, size_n, num_bits
-        )
+        output[e] = gptq_marlin_repack(b_q_weight[e], perm[e], size_k, size_n, num_bits)
     return output


@@ -204,11 +198,12 @@ class GPTQConfig(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

-        if isinstance(layer,
-            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
-        elif isinstance(layer, FusedMoE):
+        if isinstance(layer, FusedMoE):
             raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
-
+        else:
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )


 class GPTQMarlinConfig(QuantizationConfig):
@@ -530,7 +525,7 @@ class GPTQLinearMethod(LinearMethodBase):
             layer.g_idx.data = torch.empty(
                 (0,), dtype=torch.int, device=layer.g_idx.device
             )
-
+        gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)

     def apply(
         self,
@@ -541,7 +536,7 @@ class GPTQLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
         reshaped_x = x.reshape(-1, x.shape[-1])

-        output =
+        output = gptq_gemm(
             reshaped_x,
             layer.qweight,
             layer.qzeros,
@@ -726,7 +721,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
-            x.data =
+            x.data = gptq_marlin_repack(
                 x.data.contiguous(),
                 perm=layer.g_idx_sort_indices,
                 size_k=c.partition_weight_shape[0],
sglang/srt/layers/quantization/marlin_utils.py
CHANGED
@@ -19,9 +19,12 @@ from sglang.srt.layers.quantization.base_config import (
     LinearMethodBase,
     QuantizationConfig,
 )
-from sglang.srt.layers.quantization.
-
-
+from sglang.srt.layers.quantization.utils import (
+    get_scalar_types,
+    pack_cols,
+    unpack_cols,
+)
+from sglang.srt.utils import get_device_capability, is_cuda

 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
@@ -31,8 +34,15 @@ try:
 except ImportError:
     ops = None

+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm
+
 logger = logging.getLogger(__name__)

+ScalarType, scalar_types = get_scalar_types()
+
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -453,7 +463,7 @@ def apply_gptq_marlin_linear(
         dtype=input.dtype,
     )

-    output =
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
@@ -504,7 +514,7 @@ def apply_awq_marlin_linear(
         dtype=input.dtype,
     )

-    output =
+    output = gptq_marlin_gemm(
         reshaped_x,
         None,
         weight,
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
@@ -737,6 +737,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " above."
             )
         self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+        self._cache_permute_indices = {}

     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
@@ -900,10 +901,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             e2m1_and_ufp8sf_scale_to_float,
             fp4_quantize,
             next_positive_power_of_2,
+            nvfp4_block_scale_interleave,
             reorder_rows_for_gated_act_gemm,
             shuffle_matrix_a,
             shuffle_matrix_sf_a,
         )
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices,
+        )

         """Prepare quantized weights for kernel (done offline with weights)."""
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
@@ -927,50 +933,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             num_experts, hidden_size, intermediate_size // 16
         )  # fp8 scaling factors

-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
-            )
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
-            )
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm1_weights_fp4_shuffled.append(
-
-
-            )
+                gemm1_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm1_scales_fp4_shuffled.append(
-
-
+                nvfp4_block_scale_interleave(
+                    gemm1_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )

+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm2_weights_fp4_shuffled.append(
-
-
-            )
+                gemm2_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm2_scales_fp4_shuffled.append(
-
-                gemm2_scales_linear_fp4[i]
+                nvfp4_block_scale_interleave(
+                    gemm2_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )

sglang/srt/layers/quantization/mxfp4.py
CHANGED
@@ -1,5 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/mxfp4.py

 from __future__ import annotations

@@ -209,6 +222,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

         super().__init__()

+        self.prefix = prefix
         self.topk_indices_dtype = None
         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
         self.with_bias = False
@@ -332,7 +346,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.use_flashinfer:
             log_info_on_rank0(
                 logger,
-                "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
+                f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
             )
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
sglang/srt/layers/quantization/utils.py
CHANGED
@@ -11,13 +11,39 @@ import numpy
 import torch

 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import is_cuda

 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig


+def get_scalar_types():
+    """
+    Returns:
+        tuple: (ScalarType, scalar_types)
+    """
+    try:
+        from sgl_kernel.scalar_type import ScalarType, scalar_types
+
+        return ScalarType, scalar_types
+    except ImportError:
+
+        class MockScalarType:
+            pass
+
+        class MockScalarTypes:
+            uint4b8 = "uint4b8"
+            uint8b128 = "uint8b128"
+
+            def __getattr__(self, name):
+                return f"mock_{name}"
+
+        return MockScalarType, MockScalarTypes()
+
+
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
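The get_scalar_types() helper above replaces the old module-level import of the scalar_type definitions and falls back to mock placeholders when sgl_kernel is not installed. A minimal usage sketch (illustrative only, not part of the diff):

    # Illustrative sketch: resolve the scalar types once at import time; on a
    # build without sgl_kernel this returns the mock objects instead of raising.
    from sglang.srt.layers.quantization.utils import get_scalar_types

    ScalarType, scalar_types = get_scalar_types()
    weight_type = scalar_types.uint4b8  # real ScalarType member, or the "uint4b8" mock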
@@ -295,6 +321,30 @@ def pack_cols(
     return q_res


+def pack_rows(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_k % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[i::pack_factor, :] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    return q_res
+
+
 def unpack_cols(
     packed_q_w: torch.Tensor,
     num_bits: int,
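The new pack_rows mirrors the existing pack_cols helper but packs along the k dimension: assuming the usual pack factor of 32 // num_bits, a 4-bit matrix folds eight consecutive rows into each int32 entry. A small usage sketch (hypothetical shapes, illustrative only):

    # Illustrative sketch: pack a small 4-bit quantized weight matrix along its rows.
    import torch

    from sglang.srt.layers.quantization.utils import pack_rows

    size_k, size_n, num_bits = 16, 8, 4
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)
    packed = pack_rows(q_w, num_bits, size_k, size_n)
    assert packed.shape == (size_k // 8, size_n)  # 8 four-bit values per int32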
sglang/srt/layers/sampler.py
CHANGED
@@ -6,7 +6,10 @@ import torch.distributed as dist
 from torch import nn

 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -32,7 +35,7 @@ class Sampler(nn.Module):
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group

-        if
+        if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group

     def forward(
sglang/srt/lora/layers.py
CHANGED
@@ -253,7 +253,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return lora_output

-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -270,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)

-        if
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
sglang/srt/managers/cache_controller.py
CHANGED
@@ -296,6 +296,9 @@ class HiCacheController:
         self.prefetch_tp_group = torch.distributed.new_group(
             group_ranks, backend="gloo"
         )
+        self.prefetch_io_tp_group = torch.distributed.new_group(
+            group_ranks, backend="gloo"
+        )
         self.backup_tp_group = torch.distributed.new_group(
             group_ranks, backend="gloo"
         )
@@ -602,7 +605,7 @@

         if self.tp_world_size > 1:
             # to ensure all TP workers release the host memory at the same time
-            torch.distributed.barrier(group=self.
+            torch.distributed.barrier(group=self.prefetch_io_tp_group)
             # operation terminated by controller, release pre-allocated memory
             self.mem_pool_host.free(
                 operation.host_indices[operation.completed_tokens :]
sglang/srt/managers/io_struct.py
CHANGED
@@ -798,6 +798,8 @@ class UpdateWeightFromDiskReqInput:
     load_format: Optional[str] = None
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None


 @dataclass
@@ -819,6 +821,8 @@ class UpdateWeightsFromDistributedReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None


 @dataclass
@@ -842,6 +846,8 @@ class UpdateWeightsFromTensorReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None


 @dataclass
@@ -872,6 +878,14 @@ class InitWeightsUpdateGroupReqOutput:
     message: str


+@dataclass
+class UpdateWeightVersionReqInput:
+    # The new weight version
+    new_version: str
+    # Whether to abort all running requests before updating
+    abort_all_requests: bool = True
+
+
 @dataclass
 class GetWeightsByNameReqInput:
     name: str
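The new weight_version field and the UpdateWeightVersionReqInput dataclass let callers tag a weight update with a version string, or bump the version on its own. A minimal construction sketch (illustrative only; the path is hypothetical and model_path is assumed to be the existing required field of UpdateWeightFromDiskReqInput):

    # Illustrative sketch: attach a version string to a disk reload, or update
    # the version by itself with the new request type.
    from sglang.srt.managers.io_struct import (
        UpdateWeightFromDiskReqInput,
        UpdateWeightVersionReqInput,
    )

    reload_req = UpdateWeightFromDiskReqInput(
        model_path="/path/to/checkpoint",  # hypothetical path
        weight_version="v2",               # new optional field
    )
    bump_req = UpdateWeightVersionReqInput(new_version="v2", abort_all_requests=False)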