sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/kv_cache.py +43 -52

@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_hip

 _is_hip = is_hip()
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)

 class BaseKVCacheMethod(QuantizeMethodBase):
     """
-    Quant method that adds `_k_scale` and `_v_scale` attributes to the
+    Quant method that adds `k_scale` and `v_scale` attributes to the
     Attention layer to support loading those scaling factors from checkpoints.
     The k/v_scale will be used to:
     - quantize k/v_cache entries before saving them to the cache
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # Initialize the KV cache scales to -1.0, which is an invalid value.
         # If the k/v_scale appears in the checkpoint, it will be
         # overwritten when loading weights.
-        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
-        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+        layer.k_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )

     @classmethod
     def is_fp8_fnuz(cls) -> bool:
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")

-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
-        # regardless whether the kv-scale is available in the checkpoint.
-        # No need to process kv scales after loading if we are going to
-        # calculate them on the fly.
-        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
-            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
-                # We prefer to use separate k_scale and v_scale if present
-                k_scale = layer.k_scale.to("cpu").tolist()
-                v_scale = layer.v_scale.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
-                # If no scales were loaded (both scales are invalid negative
-                # values), use the default value of 1.0
-                k_scale = 1.0
-                v_scale = 1.0
-            else:
-                # If we find a single kv_scale in the checkpoint, we remap
-                # kv_scale to k_scale during weight loading, and duplicate
-                # k_scale to v_scale here
-                assert layer.k_scale > 0.0
-                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
-                k_scale = scale_to_duplicate.to("cpu").tolist()
-                v_scale = scale_to_duplicate.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-
-            if not isinstance(k_scale, float) or not isinstance(v_scale, float):
-                raise ValueError(
-                    "Only support per-tensor scaling factor " "for fp8 KV cache"
-                )
-
-            # These are used in the final Attention.forward()
-            layer._k_scale.copy_(k_scale)
-            layer._v_scale.copy_(v_scale)
-            layer._k_scale_float = k_scale
-            layer._v_scale_float = v_scale
-            if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
-                logger.warning(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                    "may cause accuracy issues. Please make sure k/v_scale "
-                    "scaling factors are available in the fp8 checkpoint."
-                )
-
-        del layer.k_scale
-        del layer.v_scale
+    def process_weights_after_loading(self, layer: RadixAttention) -> None:
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+
+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError(
+                "Only support per-tensor scaling factor " "for fp8 KV cache"
+            )
+
+        # These are used in the final Attention.forward()
+        layer.k_scale.copy_(k_scale)
+        layer.v_scale.copy_(v_scale)
+        layer.k_scale_float = k_scale
+        layer.v_scale_float = v_scale
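
Note: the resolved k/v scales above are per-tensor multipliers that, per the class docstring, quantize KV entries before they are written to an fp8 cache and dequantize them on read. A toy round trip of that idea, as a sketch only (this is illustrative Python, not sglang's kernel path):

import torch

# Illustrative per-tensor fp8 round trip with a k_scale.
k = torch.randn(4, 8, dtype=torch.float32)
k_scale = k.abs().max() / torch.finfo(torch.float8_e4m3fn).max

k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)   # quantize before caching
k_dequant = k_fp8.to(torch.float32) * k_scale   # dequantize on read
print((k - k_dequant).abs().max())              # small quantization error
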
sglang/srt/layers/quantization/modelopt_quant.py +271 -4

@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter

-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -22,6 +21,11 @@ from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

 # Initialize logger for the module
 logger = logging.getLogger(__name__)
@@ -33,12 +37,19 @@ ACTIVATION_SCHEMES = ["static"]
 class ModelOptFp8Config(QuantizationConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

-    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        kv_cache_quant_method: Optional[str] = None,
+        exclude_modules: Optional[List[str]] = None,
+    ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -63,6 +74,12 @@ class ModelOptFp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
         quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
+            "kv_cache_quant_algo"
+        )
+        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
+            "exclude_modules"
+        )

         if "FP8" not in quant_method:
             raise ValueError(
@@ -70,15 +87,23 @@ class ModelOptFp8Config(QuantizationConfig):
                 "Check the `hf_quant_config.json` file for your model's configuration."
             )

-        return cls(is_checkpoint_fp8_serialized=True)
+        return cls(
+            is_checkpoint_fp8_serialized=True,
+            kv_cache_quant_method=kv_cache_quant_method,
+            exclude_modules=exclude_modules,
+        )

     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None

         if isinstance(layer, LinearBase):
             return ModelOptFp8LinearMethod(self)
-        if isinstance(layer, AttentionBackend):
+        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)

         return None
@@ -194,3 +219,245 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):

     def __init__(self, quant_config: ModelOptFp8Config):
         super().__init__(quant_config)
+
+
+class ModelOptFp4Config(QuantizationConfig):
+    """Config class for FP4."""
+
+    def __init__(
+        self,
+        is_checkpoint_nvfp4_serialized: bool = False,
+        kv_cache_quant_algo: str = None,
+        group_size: int = None,
+        exclude_modules: List[str] = None,
+    ) -> None:
+        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
+        if is_checkpoint_nvfp4_serialized:
+            logger.warning(
+                "Detected nvfp4 checkpoint. Please note that the "
+                "format is experimental and subject to change."
+            )
+        self.group_size = group_size
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "modelopt_fp4"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 100
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
+        quant_config = cls.get_from_keys(config, ["quantization"])
+        quant_method = quant_config["quant_algo"]
+        if not quant_method in ["FP8", "NVFP4"]:
+            raise ValueError(
+                f"ModelOpt currently only supports: FP8, NVFP4"
+                " quantizations in sglang. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration."
+            )
+        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+        group_size = quant_config["group_size"]
+        exclude_modules = quant_config["exclude_modules"]
+        if not (group_size and kv_cache_quant_algo and exclude_modules):
+            raise ValueError(
+                "NVFP4 quantization requires group size and "
+                "kv_cache_quant_algo specified in "
+                "hf_quant_config.json"
+            )
+        return cls(
+            is_checkpoint_nvfp4_serialized,
+            kv_cache_quant_algo,
+            group_size,
+            exclude_modules,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None
+
+        if isinstance(layer, LinearBase):
+            return ModelOptFp4LinearMethod(self)
+        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp4LinearMethod(LinearMethodBase):
+    """Linear method for NVFP4.
+    Supports loading NVFP4 checkpoints with the following structure:
+
+    |Tensor Name    | datatype      | shape       |
+    |----------------------------------------------------|
+    |input_scale    | torch.float32 | scalar      |
+    |weight         | NVFP4(SE2M1)  | [1, X, y/2] |
+    |weight_scale   | FP8-E4M3      | [X, Y]      |
+    |weight_scale_2 | torch.float32 | scalar      |
+
+    The weights are quantized per block of 16 elements.
+    Args: quant_config: The ModelOpt quantization config.
+    """
+
+    def __init__(self, quant_config: ModelOptFp4Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError(
+                "NVFP4 quantization was selected, "
+                " dynamic quantization is not supported."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        if input_size_per_partition % 16 != 0:
+            raise ValueError(
+                "Unsupported model when in features size is " "not multiple of 16"
+            )
+
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_nvfp4_serialized
+            else params_dtype
+        )
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                # 2 fp4 data is packed in one uint8 in the input dimension
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        input_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("input_scale", input_scale)
+
+        weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale_2", weight_scale_2)
+
+        weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.quant_config.group_size,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        input_scale_2 = layer.input_scale.max().to(torch.float32)
+        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
+        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
+        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
+        layer.alpha = Parameter(
+            layer.input_scale * layer.weight_scale_2, requires_grad=False
+        )
+
+        # Pad and blockwise interleave weight_scale
+        scales = layer.weight_scale
+        scale_ndim = scales.ndim
+        if scale_ndim == 2:
+            scales = scales.unsqueeze(0)
+        assert scales.ndim == 3
+        B, M, K = scales.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
+        padded_scales[:B, :M, :K] = scales
+        batches, rows, cols = padded_scales.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
+        padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
+        padded_scales = padded_scales.contiguous().cuda()
+        padded_scales = (
+            padded_scales.reshape(M, K)
+            if scale_ndim == 2
+            else padded_scales.reshape(B, M, K)
+        )
+        layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        output_dtype = x.dtype
+        x_m, _ = x.shape
+        w_n, _ = layer.weight.shape
+        output_shape = [x_m, w_n]
+
+        # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
+
+        assert x_fp4.dtype == torch.uint8
+        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
+        assert layer.weight.dtype == torch.uint8
+        assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
+        assert layer.alpha.dtype == torch.float32
+
+        out = cutlass_scaled_fp4_mm(
+            x_fp4,
+            layer.weight,
+            x_scale_interleaved,
+            layer.weight_scale_interleaved,
+            layer.alpha,
+            output_dtype,
+        )
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)
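
For context, the new ModelOptFp4Config.from_config above reads the "quantization" block of hf_quant_config.json. A minimal sketch of the shape it expects is shown below; the key names come from the code above, while the concrete values (the FP8 KV-cache algorithm and the lm_head exclusion) are illustrative assumptions, not taken from this diff.

from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config

# Hypothetical parsed hf_quant_config.json contents.
hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",           # must be "FP8" or "NVFP4"
        "kv_cache_quant_algo": "FP8",    # assumed value; enables ModelOptFp8KVCacheMethod on RadixAttention
        "group_size": 16,                # NVFP4 weights are quantized per block of 16 elements
        "exclude_modules": ["lm_head"],  # assumed value; matching prefixes get no quant method
    }
}

fp4_config = ModelOptFp4Config.from_config(hf_quant_config)
assert fp4_config.is_checkpoint_nvfp4_serialized
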
sglang/srt/layers/quantization/moe_wna16.py +2 -0

@@ -344,6 +344,7 @@ class MoeWNA16Method:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -374,6 +375,7 @@ class MoeWNA16Method:
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=inplace,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
             w1_scale=layer.w13_scales,
sglang/srt/layers/quantization/w8a8_fp8.py +154 -4

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_hip, set_weight_attrs
 _is_hip = is_hip()


@@ -62,7 +62,9 @@ class W8A8Fp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])
-        is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
+        is_checkpoint_fp8_serialized = (
+            "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
+        )
         return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)

     def get_quant_method(
@@ -71,9 +73,12 @@ class W8A8Fp8Config(QuantizationConfig):
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
             return W8A8Fp8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8FP8MoEMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -131,7 +136,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-        **extra_weight_attrs
+        **extra_weight_attrs,
     ):
         weight_dtype = (
             torch.float8_e4m3fn
@@ -177,3 +182,148 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
         )
+
+
+class W8A8FP8MoEMethod:
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=fp8_dtype
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, intermediate_size, dtype=fp8_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+        )
+
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        w13_input_scale = None
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+
+        w2_input_scale = None
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        layer.w13_weight_scale = Parameter(
+            layer.w13_weight_scale.data, requires_grad=False
+        )
+        layer.w2_weight_scale = Parameter(
+            layer.w2_weight_scale.data, requires_grad=False
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=True,
+            w1_scale=(layer.w13_weight_scale),
+            w2_scale=(layer.w2_weight_scale),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
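
Aside: W8A8FP8MoEMethod above is declared without a base class, and its __new__ grafts FusedMoEMethodBase on at instantiation time, which keeps the module importable without pulling in fused_moe_triton at import time. A stand-alone sketch of that late-binding pattern follows; Base and LateBound are hypothetical stand-ins for illustration, not sglang classes.

class Base:
    """Stand-in for FusedMoEMethodBase, whose real module would create an import cycle."""


class LateBound:
    def __new__(cls, *args, **kwargs):
        # Rebuild the class with Base as its parent, copying this class's body
        # (minus the __dict__/__weakref__ descriptors), then construct that class.
        namespace = {
            k: v
            for k, v in cls.__dict__.items()
            if k not in ("__dict__", "__weakref__")
        }
        new_cls = type(cls.__name__, (Base,), namespace)
        obj = super(new_cls, new_cls).__new__(new_cls)
        # __init__ is called explicitly: the returned object is not an instance
        # of LateBound, so Python will not invoke it automatically.
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, quant_config=None):
        self.quant_config = quant_config


method = LateBound("w8a8_fp8")
assert isinstance(method, Base) and method.quant_config == "w8a8_fp8"
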
sglang/srt/layers/quantization/w8a8_int8.py +3 -0

@@ -230,6 +230,7 @@ class W8A8Int8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -257,7 +258,9 @@ class W8A8Int8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
+            per_channel_quant=True,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,