sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +330 -200
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +12 -5
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +25 -13
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +1 -0
- sglang/srt/layers/radix_attention.py +13 -1
- sglang/srt/layers/rotary_embedding.py +12 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +48 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +1 -0
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/sglang/srt/layers/quantization/modelopt_quant.py
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -22,6 +21,11 @@ from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
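The new FP4 path only pulls in the `sgl_kernel` ops when CUDA is available, so the module still imports cleanly elsewhere. A minimal sketch of the same guard, using torch's own CUDA check in place of sglang's `is_cuda_available()`:

```python
# Sketch of the conditional-import guard above; torch.cuda.is_available()
# stands in for sglang's is_cuda_available(), and the flag lets callers
# fall back when the fused FP4 ops are absent.
import torch

_fp4_ops_available = False
if torch.cuda.is_available():
    try:
        from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
        _fp4_ops_available = True
    except ImportError:
        pass
```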
@@ -33,12 +37,19 @@ ACTIVATION_SCHEMES = ["static"]
 class ModelOptFp8Config(QuantizationConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
 
-    def __init__(
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        kv_cache_quant_method: Optional[str] = None,
+        exclude_modules: Optional[List[str]] = None,
+    ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -63,6 +74,12 @@ class ModelOptFp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
         quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
+            "kv_cache_quant_algo"
+        )
+        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
+            "exclude_modules"
+        )
 
         if "FP8" not in quant_method:
             raise ValueError(
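`from_config` now also reads `kv_cache_quant_algo` and `exclude_modules` from the checkpoint's `hf_quant_config.json`. A hypothetical config illustrating the keys it looks up (only the key names come from the diff; the values are made up):

```python
# Hypothetical "quantization" section of hf_quant_config.json, written as a
# Python dict; key names are taken from the diff above, values are illustrative.
hf_quant_config = {
    "quantization": {
        "quant_algo": "FP8",             # must contain "FP8" for ModelOptFp8Config
        "kv_cache_quant_algo": "FP8",    # optional: enables the FP8 KV-cache method
        "exclude_modules": ["lm_head"],  # prefixes left unquantized
    }
}

quant = hf_quant_config["quantization"]
print(quant.get("quant_algo"), quant.get("kv_cache_quant_algo"), quant.get("exclude_modules"))
```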
@@ -70,15 +87,23 @@ class ModelOptFp8Config(QuantizationConfig):
             "Check the `hf_quant_config.json` file for your model's configuration."
         )
 
-        return cls(
+        return cls(
+            is_checkpoint_fp8_serialized=True,
+            kv_cache_quant_method=kv_cache_quant_method,
+            exclude_modules=exclude_modules,
+        )
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None
 
         if isinstance(layer, LinearBase):
             return ModelOptFp8LinearMethod(self)
-        if isinstance(layer,
+        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
 
         return None
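The `exclude_modules` check in `get_quant_method` is a plain substring match against the layer's prefix, so any layer whose name contains an excluded token falls back to the unquantized path. A standalone restatement of that check (layer prefixes are illustrative):

```python
# Standalone restatement of the exclude_modules match used in
# get_quant_method() above; the layer prefixes are illustrative.
from typing import List, Optional


def is_excluded(prefix: str, exclude_modules: Optional[List[str]]) -> bool:
    return bool(exclude_modules) and any(m in prefix for m in exclude_modules)


print(is_excluded("lm_head", ["lm_head"]))                             # True
print(is_excluded("model.layers.0.self_attn.qkv_proj", ["lm_head"]))   # False
```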
@@ -194,3 +219,245 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
 
     def __init__(self, quant_config: ModelOptFp8Config):
         super().__init__(quant_config)
+
+
+class ModelOptFp4Config(QuantizationConfig):
+    """Config class for FP4."""
+
+    def __init__(
+        self,
+        is_checkpoint_nvfp4_serialized: bool = False,
+        kv_cache_quant_algo: str = None,
+        group_size: int = None,
+        exclude_modules: List[str] = None,
+    ) -> None:
+        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
+        if is_checkpoint_nvfp4_serialized:
+            logger.warning(
+                "Detected nvfp4 checkpoint. Please note that the "
+                "format is experimental and subject to change."
+            )
+        self.group_size = group_size
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "modelopt_fp4"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 100
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
+        quant_config = cls.get_from_keys(config, ["quantization"])
+        quant_method = quant_config["quant_algo"]
+        if not quant_method in ["FP8", "NVFP4"]:
+            raise ValueError(
+                f"ModelOpt currently only supports: FP8, NVFP4"
+                " quantizations in sglang. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration."
+            )
+        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+        group_size = quant_config["group_size"]
+        exclude_modules = quant_config["exclude_modules"]
+        if not (group_size and kv_cache_quant_algo and exclude_modules):
+            raise ValueError(
+                "NVFP4 quantization requires group size and "
+                "kv_cache_quant_algo specified in "
+                "hf_quant_config.json"
+            )
+        return cls(
+            is_checkpoint_nvfp4_serialized,
+            kv_cache_quant_algo,
+            group_size,
+            exclude_modules,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None
+
+        if isinstance(layer, LinearBase):
+            return ModelOptFp4LinearMethod(self)
+        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp4LinearMethod(LinearMethodBase):
+    """Linear method for NVFP4.
+    Supports loading NVFP4 checkpoints with the following structure:
+
+    |Tensor Name    | datatype      |  shape      |
+    |----------------------------------------------------|
+    |input_scale    | torch.float32 | scalar      |
+    |weight         | NVFP4(SE2M1)  | [1, X, y/2] |
+    |weight_scale   | FP8-E4M3      | [X, Y]      |
+    |weight_scale_2 | torch.float32 | scalar      |
+
+    The weights are quantized per block of 16 elements.
+    Args: quant_config: The ModelOpt quantization config.
+    """
+
+    def __init__(self, quant_config: ModelOptFp4Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError(
+                "NVFP4 quantization was selected, "
+                " dynamic quantization is not supported."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        if input_size_per_partition % 16 != 0:
+            raise ValueError(
+                "Unsupported model when in features size is " "not multiple of 16"
+            )
+
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_nvfp4_serialized
+            else params_dtype
+        )
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                # 2 fp4 data is packed in one uint8 in the input dimension
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        input_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("input_scale", input_scale)
+
+        weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale_2", weight_scale_2)
+
+        weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.quant_config.group_size,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        input_scale_2 = layer.input_scale.max().to(torch.float32)
+        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
+        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
+        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
+        layer.alpha = Parameter(
+            layer.input_scale * layer.weight_scale_2, requires_grad=False
+        )
+
+        # Pad and blockwise interleave weight_scale
+        scales = layer.weight_scale
+        scale_ndim = scales.ndim
+        if scale_ndim == 2:
+            scales = scales.unsqueeze(0)
+        assert scales.ndim == 3
+        B, M, K = scales.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
+        padded_scales[:B, :M, :K] = scales
+        batches, rows, cols = padded_scales.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
+        padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
+        padded_scales = padded_scales.contiguous().cuda()
+        padded_scales = (
+            padded_scales.reshape(M, K)
+            if scale_ndim == 2
+            else padded_scales.reshape(B, M, K)
+        )
+        layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        output_dtype = x.dtype
+        x_m, _ = x.shape
+        w_n, _ = layer.weight.shape
+        output_shape = [x_m, w_n]
+
+        # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
+
+        assert x_fp4.dtype == torch.uint8
+        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
+        assert layer.weight.dtype == torch.uint8
+        assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
+        assert layer.alpha.dtype == torch.float32
+
+        out = cutlass_scaled_fp4_mm(
+            x_fp4,
+            layer.weight,
+            x_scale_interleaved,
+            layer.weight_scale_interleaved,
+            layer.alpha,
+            output_dtype,
+        )
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)
--- a/sglang/srt/layers/quantization/w8a8_fp8.py
+++ b/sglang/srt/layers/quantization/w8a8_fp8.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, set_weight_attrs
 
 _is_hip = is_hip()
 
@@ -62,7 +62,9 @@ class W8A8Fp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])
-        is_checkpoint_fp8_serialized =
+        is_checkpoint_fp8_serialized = (
+            "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
+        )
         return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
 
     def get_quant_method(
@@ -71,9 +73,12 @@ class W8A8Fp8Config(QuantizationConfig):
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
             return W8A8Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8FP8MoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -131,7 +136,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-        **extra_weight_attrs
+        **extra_weight_attrs,
     ):
         weight_dtype = (
             torch.float8_e4m3fn
@@ -177,3 +182,148 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
         )
+
+
+class W8A8FP8MoEMethod:
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+        )
+
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        w13_input_scale = None
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+
+        w2_input_scale = None
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        layer.w13_weight_scale = Parameter(
+            layer.w13_weight_scale.data, requires_grad=False
+        )
+        layer.w2_weight_scale = Parameter(
+            layer.w2_weight_scale.data, requires_grad=False
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=True,
+            w1_scale=(layer.w13_weight_scale),
+            w2_scale=(layer.w2_weight_scale),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
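`W8A8FP8MoEMethod` is written as a plain class and only rebased onto `FusedMoEMethodBase` inside `__new__`, so that base class (and the `fused_moe_triton` module) is not imported until the first instance is created. A minimal, self-contained sketch of that deferred-subclassing trick, with `Base` standing in for `FusedMoEMethodBase`:

```python
# Minimal sketch of the deferred-subclassing trick used by W8A8FP8MoEMethod
# above: the concrete class is rebuilt on top of the (lazily available) base
# class the first time it is instantiated. `Base` stands in for
# FusedMoEMethodBase; this sketch also drops __weakref__ for portability.
class Base:
    pass


class Method:
    def __new__(cls, *args, **kwargs):
        if not hasattr(cls, "_initialized"):
            new_cls = type(
                cls.__name__,
                (Base,),
                {
                    k: v
                    for k, v in cls.__dict__.items()
                    if k not in ("__dict__", "__weakref__")
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)

    def __init__(self, quant_config=None):
        self.quant_config = quant_config


m = Method(quant_config="w8a8_fp8")
print(isinstance(m, Base), m.quant_config)  # True w8a8_fp8
```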
--- a/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/sglang/srt/layers/quantization/w8a8_int8.py
@@ -260,6 +260,7 @@ class W8A8Int8MoEMethod:
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
+            per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
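The added `per_channel_quant=True` tells the fused MoE kernel call above that `w13_weight_scale` and `w2_weight_scale` carry one value per output channel, which is why those scale parameters have a trailing dimension of 1. A generic illustration of per-output-channel int8 scaling (not sglang code):

```python
# Generic per-output-channel int8 quantization, to illustrate what an
# [out_channels, 1] weight scale represents; this is not sglang code.
import torch

w = torch.randn(4, 8)                               # [out_channels, in_features]
scale = w.abs().amax(dim=1, keepdim=True) / 127.0   # one scale per output channel
w_int8 = torch.clamp((w / scale).round(), -128, 127).to(torch.int8)
w_dequant = w_int8.float() * scale                  # approximate reconstruction
print(scale.shape, (w - w_dequant).abs().max().item())
```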
--- a/sglang/srt/layers/radix_attention.py
+++ b/sglang/srt/layers/radix_attention.py
@@ -13,8 +13,12 @@
 # ==============================================================================
 """Radix attention."""
 
+from typing import Optional
+
 from torch import nn
 
+from sglang.srt.layers.linear import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
@@ -34,6 +38,7 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -49,9 +54,16 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.use_irope = use_irope
         self.k_scale = None
         self.v_scale = None
-        self.
+        self.k_scale_float = None
+        self.v_scale_float = None
+        self.quant_method = None
+        if quant_config is not None:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if self.quant_method is not None:
+            self.quant_method.create_weights(self)
 
     def forward(
         self,
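With the new `quant_config` argument, each attention layer asks the quantization config for a per-layer method at construction time and lets it register its weights (the FP8 KV-cache scales) on the layer. A self-contained sketch of that contract, with `FakeQuantConfig` / `FakeKVCacheMethod` standing in for `ModelOptFp8Config` / `ModelOptFp8KVCacheMethod`:

```python
# Self-contained sketch of the contract RadixAttention now relies on:
# quant_config.get_quant_method(layer, prefix) returns an object whose
# create_weights(layer) registers KV-cache scale attributes. The Fake*
# classes are stand-ins, not sglang classes.
class FakeKVCacheMethod:
    def create_weights(self, layer):
        layer.k_scale = None  # the real method registers k/v scale parameters
        layer.v_scale = None


class FakeQuantConfig:
    def get_quant_method(self, layer, prefix):
        return FakeKVCacheMethod()


class AttnLayer:
    def __init__(self, quant_config=None, prefix=""):
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if self.quant_method is not None:
            self.quant_method.create_weights(self)


layer = AttnLayer(FakeQuantConfig(), prefix="model.layers.0.self_attn.attn")
print(type(layer.quant_method).__name__, layer.k_scale, layer.v_scale)
```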
--- a/sglang/srt/layers/rotary_embedding.py
+++ b/sglang/srt/layers/rotary_embedding.py
@@ -645,7 +645,18 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def
+    def forward_hip(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda_available:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,