sglang-0.4.4.post4-py3-none-any.whl → sglang-0.4.5.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
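Among the files listed below, `sglang/version.py` changes by a single line (the version string bump). A minimal sketch for confirming which build is active after upgrading, assuming the package keeps re-exporting `__version__` from `sglang/version.py` at the top level as in prior releases:

```python
# Minimal sketch (assumption: sglang re-exports __version__ from sglang/version.py).
import sglang

print(sglang.__version__)  # expected to print "0.4.5.post1" after the upgrade
```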
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/kv_cache.py

@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_hip
 
 _is_hip = is_hip()
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
 
 class BaseKVCacheMethod(QuantizeMethodBase):
     """
-    Quant method that adds `
+    Quant method that adds `k_scale` and `v_scale` attributes to the
     Attention layer to support loading those scaling factors from checkpoints.
     The k/v_scale will be used to:
     - quantize k/v_cache entries before saving them to the cache
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # Initialize the KV cache scales to -1.0, which is an invalid value.
         # If the k/v_scale appears in the checkpoint, it will be
         # overwritten when loading weights.
-        layer.k_scale = torch.nn.Parameter(
-
+        layer.k_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
 
     @classmethod
     def is_fp8_fnuz(cls) -> bool:
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
-    def process_weights_after_loading(self, layer:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # These are used in the final Attention.forward()
-        layer._k_scale.copy_(k_scale)
-        layer._v_scale.copy_(v_scale)
-        layer._k_scale_float = k_scale
-        layer._v_scale_float = v_scale
-        if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
-            logger.warning(
-                "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                "may cause accuracy issues. Please make sure k/v_scale "
-                "scaling factors are available in the fp8 checkpoint."
-            )
-
-        del layer.k_scale
-        del layer.v_scale
+    def process_weights_after_loading(self, layer: RadixAttention) -> None:
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+
+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError(
+                "Only support per-tensor scaling factor " "for fp8 KV cache"
+            )
+
+        # These are used in the final Attention.forward()
+        layer.k_scale.copy_(k_scale)
+        layer.v_scale.copy_(v_scale)
+        layer.k_scale_float = k_scale
+        layer.v_scale_float = v_scale
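The rewritten `process_weights_after_loading` above resolves the checkpoint scales in three cases: separate `k_scale`/`v_scale` present, neither present (fall back to 1.0), or a single `kv_scale` that was remapped to `k_scale` at load time and must be duplicated. For illustration only, a standalone sketch of that branch structure using plain floats (the ROCm fp8-fnuz doubling and the tensor handling are omitted); `resolve_kv_scales` is a hypothetical helper, not part of sglang:

```python
def resolve_kv_scales(k_scale: float, v_scale: float) -> tuple[float, float]:
    # -1.0 marks a scale that was never loaded from the checkpoint.
    if k_scale > 0.0 and v_scale > 0.0:
        return k_scale, v_scale  # separate scales found in the checkpoint
    if k_scale < 0.0 and v_scale < 0.0:
        return 1.0, 1.0  # nothing loaded: default to 1.0
    # A single kv_scale was remapped to k_scale at load time; duplicate it.
    scale = max(k_scale, v_scale)
    return scale, scale


print(resolve_kv_scales(-1.0, -1.0))  # (1.0, 1.0)
print(resolve_kv_scales(0.02, -1.0))  # (0.02, 0.02)
```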
sglang/srt/layers/quantization/modelopt_quant.py

@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -22,6 +21,11 @@ from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
@@ -33,12 +37,19 @@ ACTIVATION_SCHEMES = ["static"]
 class ModelOptFp8Config(QuantizationConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
 
-    def __init__(
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        kv_cache_quant_method: Optional[str] = None,
+        exclude_modules: Optional[List[str]] = None,
+    ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -63,6 +74,12 @@ class ModelOptFp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
         quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
+            "kv_cache_quant_algo"
+        )
+        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
+            "exclude_modules"
+        )
 
         if "FP8" not in quant_method:
             raise ValueError(
@@ -70,15 +87,23 @@ class ModelOptFp8Config(QuantizationConfig):
                 "Check the `hf_quant_config.json` file for your model's configuration."
             )
 
-        return cls(
+        return cls(
+            is_checkpoint_fp8_serialized=True,
+            kv_cache_quant_method=kv_cache_quant_method,
+            exclude_modules=exclude_modules,
+        )
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None
 
         if isinstance(layer, LinearBase):
             return ModelOptFp8LinearMethod(self)
-        if isinstance(layer,
+        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
         return None
@@ -194,3 +219,245 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
 
     def __init__(self, quant_config: ModelOptFp8Config):
         super().__init__(quant_config)
+
+
+class ModelOptFp4Config(QuantizationConfig):
+    """Config class for FP4."""
+
+    def __init__(
+        self,
+        is_checkpoint_nvfp4_serialized: bool = False,
+        kv_cache_quant_algo: str = None,
+        group_size: int = None,
+        exclude_modules: List[str] = None,
+    ) -> None:
+        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
+        if is_checkpoint_nvfp4_serialized:
+            logger.warning(
+                "Detected nvfp4 checkpoint. Please note that the "
+                "format is experimental and subject to change."
+            )
+        self.group_size = group_size
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "modelopt_fp4"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 100
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
+        quant_config = cls.get_from_keys(config, ["quantization"])
+        quant_method = quant_config["quant_algo"]
+        if not quant_method in ["FP8", "NVFP4"]:
+            raise ValueError(
+                f"ModelOpt currently only supports: FP8, NVFP4"
+                " quantizations in sglang. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration."
+            )
+        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+        group_size = quant_config["group_size"]
+        exclude_modules = quant_config["exclude_modules"]
+        if not (group_size and kv_cache_quant_algo and exclude_modules):
+            raise ValueError(
+                "NVFP4 quantization requires group size and "
+                "kv_cache_quant_algo specified in "
+                "hf_quant_config.json"
+            )
+        return cls(
+            is_checkpoint_nvfp4_serialized,
+            kv_cache_quant_algo,
+            group_size,
+            exclude_modules,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None
+
+        if isinstance(layer, LinearBase):
+            return ModelOptFp4LinearMethod(self)
+        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp4LinearMethod(LinearMethodBase):
+    """Linear method for NVFP4.
+    Supports loading NVFP4 checkpoints with the following structure:
+
+    |Tensor Name    | datatype      | shape       |
+    |----------------------------------------------------|
+    |input_scale    | torch.float32 | scalar      |
+    |weight         | NVFP4(SE2M1)  | [1, X, y/2] |
+    |weight_scale   | FP8-E4M3      | [X, Y]      |
+    |weight_scale_2 | torch.float32 | scalar      |
+
+    The weights are quantized per block of 16 elements.
+    Args: quant_config: The ModelOpt quantization config.
+    """
+
+    def __init__(self, quant_config: ModelOptFp4Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError(
+                "NVFP4 quantization was selected, "
+                " dynamic quantization is not supported."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        if input_size_per_partition % 16 != 0:
+            raise ValueError(
+                "Unsupported model when in features size is " "not multiple of 16"
+            )
+
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_nvfp4_serialized
+            else params_dtype
+        )
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                # 2 fp4 data is packed in one uint8 in the input dimension
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        input_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("input_scale", input_scale)
+
+        weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale_2", weight_scale_2)
+
+        weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.quant_config.group_size,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        input_scale_2 = layer.input_scale.max().to(torch.float32)
+        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
+        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
+        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
+        layer.alpha = Parameter(
+            layer.input_scale * layer.weight_scale_2, requires_grad=False
+        )
+
+        # Pad and blockwise interleave weight_scale
+        scales = layer.weight_scale
+        scale_ndim = scales.ndim
+        if scale_ndim == 2:
+            scales = scales.unsqueeze(0)
+        assert scales.ndim == 3
+        B, M, K = scales.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
+        padded_scales[:B, :M, :K] = scales
+        batches, rows, cols = padded_scales.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
+        padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
+        padded_scales = padded_scales.contiguous().cuda()
+        padded_scales = (
+            padded_scales.reshape(M, K)
+            if scale_ndim == 2
+            else padded_scales.reshape(B, M, K)
+        )
+        layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        output_dtype = x.dtype
+        x_m, _ = x.shape
+        w_n, _ = layer.weight.shape
+        output_shape = [x_m, w_n]
+
+        # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
+
+        assert x_fp4.dtype == torch.uint8
+        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
+        assert layer.weight.dtype == torch.uint8
+        assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
+        assert layer.alpha.dtype == torch.float32
+
+        out = cutlass_scaled_fp4_mm(
+            x_fp4,
+            layer.weight,
+            x_scale_interleaved,
+            layer.weight_scale_interleaved,
+            layer.alpha,
+            output_dtype,
+        )
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)
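The new `ModelOptFp4Config.from_config` above pulls everything it needs from the `quantization` section of the checkpoint's `hf_quant_config.json`. A minimal sketch of the expected shape, with placeholder values (the algorithm names and excluded modules depend on how the checkpoint was exported and are assumptions here):

```python
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config

# Placeholder values; the keys mirror exactly what from_config reads.
hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",           # must contain "FP8" or "NVFP4"
        "kv_cache_quant_algo": "FP8",    # required alongside group_size for NVFP4
        "group_size": 16,                # weights are quantized per block of 16
        "exclude_modules": ["lm_head"],  # prefixes skipped by get_quant_method
    }
}

config = ModelOptFp4Config.from_config(hf_quant_config)
print(config.get_name())  # "modelopt_fp4"
```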
sglang/srt/layers/quantization/moe_wna16.py

@@ -344,6 +344,7 @@ class MoeWNA16Method:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -374,6 +375,7 @@ class MoeWNA16Method:
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=inplace,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
             w1_scale=layer.w13_scales,
sglang/srt/layers/quantization/w8a8_fp8.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, set_weight_attrs
 
 _is_hip = is_hip()
 
@@ -62,7 +62,9 @@ class W8A8Fp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])
-        is_checkpoint_fp8_serialized =
+        is_checkpoint_fp8_serialized = (
+            "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
+        )
         return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
 
     def get_quant_method(
@@ -71,9 +73,12 @@ class W8A8Fp8Config(QuantizationConfig):
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
             return W8A8Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8FP8MoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -131,7 +136,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-        **extra_weight_attrs
+        **extra_weight_attrs,
     ):
         weight_dtype = (
             torch.float8_e4m3fn
@@ -177,3 +182,148 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
         )
+
+
+class W8A8FP8MoEMethod:
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+        )
+
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        w13_input_scale = None
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+
+        w2_input_scale = None
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        layer.w13_weight_scale = Parameter(
+            layer.w13_weight_scale.data, requires_grad=False
+        )
+        layer.w2_weight_scale = Parameter(
+            layer.w2_weight_scale.data, requires_grad=False
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=True,
+            w1_scale=(layer.w13_weight_scale),
+            w2_scale=(layer.w2_weight_scale),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
sglang/srt/layers/quantization/w8a8_int8.py

@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -257,7 +258,9 @@ class W8A8Int8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
+            per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,