sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/topk.py

@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
-import os
+import math
 from typing import Callable, Optional
 
 import torch
@@ -25,6 +25,12 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
+if _is_cuda:
+    from sgl_kernel import moe_fused_gate
+
+if _is_cuda or _is_hip:
+    from sgl_kernel import topk_softmax
+
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
@@ -56,11 +62,6 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    if _is_cuda or _is_hip:
-        from sgl_kernel import topk_softmax
-    else:
-        from vllm import _custom_ops as vllm_ops
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     M, _ = hidden_states.shape
@@ -73,20 +74,12 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
 
-    if _is_cuda or _is_hip:
-        topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
-    else:
-        vllm_ops.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
+    topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
     del token_expert_indicies
 
     if renormalize:
@@ -105,6 +98,7 @@ def grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -134,9 +128,7 @@ def grouped_topk(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
@@ -158,6 +150,7 @@ def biased_grouped_topk_impl(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -194,9 +187,7 @@ def biased_grouped_topk_impl(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
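
For reference, a toy illustration of the weighting change above: the shared-expert slot is now scaled by the routed_scaling_factor argument rather than the hard-coded 2.5. This is a self-contained PyTorch sketch, not sglang code.

    import torch

    routed_scaling_factor = 2.5  # the value that used to be hard-coded
    topk_weights = torch.tensor([[0.4, 0.3, 0.2, 0.0]])  # last slot reserved for the fused shared expert

    topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
    print(topk_weights)  # tensor([[0.4000, 0.3000, 0.2000, 0.3600]])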
@@ -209,6 +200,10 @@ def biased_grouped_topk_impl(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+def is_power_of_two(n):
+    return n > 0 and math.log2(n).is_integer()
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -219,24 +214,46 @@ def biased_grouped_topk(
     topk_group: int = 0,
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
-    biased_grouped_topk_fn = (
-        torch.compile(
-            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+    assert (
+        routed_scaling_factor is not None
+    ), "routed_scaling_factor is required for biased_grouped_topk"
+    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+    if (
+        _is_cuda
+        and gating_output.shape[1] // num_expert_group
+        <= 32  # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
+        and is_power_of_two(correction_bias.shape[0])
+    ):
+        return moe_fused_gate(
+            gating_output,
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+            n_share_experts_fusion,
+            routed_scaling_factor,
+        )
+    else:
+        biased_grouped_topk_fn = (
+            torch.compile(
+                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+            )
+            if compiled
+            else biased_grouped_topk_impl
+        )
+        return biased_grouped_topk_fn(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
         )
-        if compiled
-        else biased_grouped_topk_impl
-    )
-    return biased_grouped_topk_fn(
-        hidden_states,
-        gating_output,
-        correction_bias,
-        topk,
-        renormalize,
-        num_expert_group,
-        topk_group,
-        n_share_experts_fusion=n_share_experts_fusion,
-    )
 
 
 def select_experts(
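
For reference, a minimal sketch of the dispatch that biased_grouped_topk now performs; the condition and helper are taken from the hunks above, and the snippet is plain Python with no CUDA or sgl_kernel dependency.

    import math

    def is_power_of_two(n):
        return n > 0 and math.log2(n).is_integer()

    def biased_grouped_topk_path(is_cuda, num_experts, num_expert_group, num_bias_experts):
        # The fused sgl_kernel gate is used only when experts-per-group fits MAX_VPT=32
        # and the correction-bias length is a power of two; otherwise the (optionally
        # torch.compile'd) reference implementation runs.
        if is_cuda and num_experts // num_expert_group <= 32 and is_power_of_two(num_bias_experts):
            return "moe_fused_gate"
        return "biased_grouped_topk_impl"

    print(biased_grouped_topk_path(True, 256, 8, 256))   # moe_fused_gate
    print(biased_grouped_topk_path(False, 256, 8, 256))  # biased_grouped_topk_impl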
@@ -250,10 +267,9 @@ def select_experts(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
-    n_share_experts_fusion = 0
-    if global_server_args_dict["n_share_experts_fusion"] is not None:
-        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeekSeek V2/V3/R1 serices models uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
@@ -267,6 +283,7 @@ def select_experts(
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
+                routed_scaling_factor=routed_scaling_factor,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -278,6 +295,7 @@ def select_experts(
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
+                routed_scaling_factor=routed_scaling_factor,
             )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(

sglang/srt/layers/parameter.py

@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
-
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",

sglang/srt/layers/quantization/__init__.py

@@ -59,20 +59,20 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    UnquantizedEmbeddingMethod,
-)
 
 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
     "modelopt": ModelOptFp8Config,
+    "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
@@ -176,6 +176,13 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
+    # Move import here to avoid circular import. This is only used in monkey patching
+    # of vllm's QuantizationConfig.
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -283,6 +290,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         assert activation == "silu"
         assert inplace and not no_combine

sglang/srt/layers/quantization/blockwise_int8.py

@@ -373,6 +373,7 @@ class BlockInt8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -388,6 +389,7 @@ class BlockInt8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         # Expert fusion with INT8 quantization

sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
@@ -39,7 +39,13 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
     is_activation_quantization_format,
     should_ignore_layer,
 )
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
 
@@ -77,6 +83,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
+        packed_modules_mapping: Dict[str, List[str]] = {},
     ):
         super().__init__()
         self.ignore = ignore
@@ -87,6 +94,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.sparsity_scheme_map = sparsity_scheme_map
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
+        self.packed_modules_mapping = packed_modules_mapping
 
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -136,6 +144,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config
        )
+        packed_modules_mapping = config.get("packed_modules_mapping", {})
 
         return cls(
             target_scheme_map=target_scheme_map,
@@ -144,6 +153,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             sparsity_scheme_map=sparsity_scheme_map,
             sparsity_ignore_list=sparsity_ignore_list,
             config=config,
+            packed_modules_mapping=packed_modules_mapping,
         )
 
     @classmethod
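
The optional-dependency probe added above follows the usual try/except pattern; a self-contained sketch is shown below (the require_vllm helper is illustrative, not part of sglang).

    try:
        import vllm  # noqa: F401

        VLLM_AVAILABLE = True
    except ImportError:
        VLLM_AVAILABLE = False

    def require_vllm(feature):
        # Hypothetical helper: fail with a clear message when a vllm-backed path is requested.
        if not VLLM_AVAILABLE:
            raise ImportError(f"vllm is not installed; it is required for {feature}.")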

sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -1,22 +1,16 @@
-# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import enum
 import logging
 from enum import Enum
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import Callable, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
-if TYPE_CHECKING:
-    from sglang.srt.layers.moe.fused_moe_triton import (
-        FusedMoE,
-        FusedMoEMethodBase,
-        FusedMoeWeightScaleSupported,
-    )
-
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
 
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
+if not _is_cuda:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
@@ -58,8 +51,6 @@ __all__ = [
 
 class CompressedTensorsMoEMethod:
     def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
         return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,27 +83,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations"
         )
 
-        if not (
-            self.weight_quant.strategy == QuantizationStrategy.TENSOR
-            and self.input_quant.strategy == QuantizationStrategy.TENSOR
-        ):
-            raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
-                f"{self.weight_quant}, {self.input_quant}"
-            )
-
         self.static_input_scales = not self.input_quant.dynamic
 
     def create_weights(
@@ -154,27 +130,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        # per-tensor quantization
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    1,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+                requires_grad=False,
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
+        else:
+            raise ValueError(
+                f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
+            )
 
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
+        extra_weight_attrs.update({"quant_method": weight_quant_method})
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
+            assert (
+                self.input_quant.strategy == QuantizationStrategy.TENSOR
+            ), "Only per-tensor quantization is supported for static input scales"
             w13_input_scale = torch.nn.Parameter(
                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
             )
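
The scale shapes allocated above differ by strategy; the standalone sketch below shows just the shape logic (parameter names follow the diff, but this is not the sglang implementation).

    import torch

    def alloc_moe_weight_scales(strategy, num_experts, intermediate_size_per_partition, hidden_size):
        if strategy == "tensor":
            # one scale per (expert, w1/w3 shard) and one scale per expert for w2
            return torch.ones(num_experts, 2), torch.ones(num_experts)
        if strategy == "channel":
            # one scale per output channel
            return (
                torch.ones(num_experts, 2 * intermediate_size_per_partition, 1),
                torch.ones(num_experts, hidden_size, 1),
            )
        raise ValueError(f"Unsupported weight quantization strategy: {strategy}")

    w13_scale, w2_scale = alloc_moe_weight_scales("channel", 8, 64, 256)
    print(w13_scale.shape, w2_scale.shape)  # torch.Size([8, 128, 1]) torch.Size([8, 256, 1])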
@@ -241,31 +240,29 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 w2_input_scale, requires_grad=False
             )
-
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start : start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id],
-                )
-
-                if _is_cuda:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                else:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                start += shard_size
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
+                    )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )
 
     def apply(
         self,
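
A toy float emulation of the per-tensor requantization kept above: each expert's w1 and w3 shards start with their own scale, the kernel wants a single w13 scale, so the max scale is taken and both shards are dequantized then requantized against it. This is not real FP8, just the arithmetic.

    FP8_MAX = 448.0  # e4m3 max magnitude

    def requant_with_max_scale(shards, scales):
        # shards: per-shard "quantized" values; scales: their per-shard scales
        new_scale = max(scales)
        out = []
        for q, s in zip(shards, scales):
            dequant = [v * s for v in q]  # per_tensor_dequantize
            out.append([max(-FP8_MAX, min(FP8_MAX, v / new_scale)) for v in dequant])
        return out, new_scale

    shards = [[100.0, -50.0], [200.0, 25.0]]  # fake w1 / w3 shard values
    scales = [0.01, 0.02]
    print(requant_with_max_scale(shards, scales))
    # ([[50.0, -25.0], [200.0, 25.0]], 0.02)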
@@ -285,6 +282,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
         inplace: bool = True,
         no_combine: bool = False,
+        apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -299,6 +298,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
@@ -310,10 +310,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy
+            == QuantizationStrategy.CHANNEL,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
         )
 
 
@@ -322,11 +325,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -586,7 +584,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
-        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
             layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -594,7 +592,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             self.num_bits,
         )
         replace_tensor("w13_weight_packed", marlin_w13_qweight)
-        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -637,15 +635,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.topk import select_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
-        if not VLLM_AVAILABLE:
-            raise ImportError(
-                "vllm is not installed, to use fused_marlin_moe, please install vllm"
-            )
         if expert_map is not None:
             raise NotImplementedError(
                 "Expert Parallelism is not supported for " "fused Marlin MoE method."
@@ -662,6 +656,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return torch.ops.vllm.fused_marlin_moe(

sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

@@ -16,8 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
-    Fp8LinearOp,
-    maybe_create_device_identity,
+    apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -30,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -99,8 +97,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         weight_loader: Callable,
         **kwargs,
     ):
-        maybe_create_device_identity()
-
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
 
@@ -152,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        return self.fp8_linear.apply(
+        return apply_fp8_linear(
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
+            use_per_token_if_dynamic=True,
+            compressed_tensor_quant=True,
         )
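
The use_per_token_if_dynamic=True flag passed to apply_fp8_linear above requests one dynamic scale per token (row); a toy float emulation of that quantization step is shown below (not the sglang kernel).

    import torch

    FP8_MAX = 448.0  # e4m3 max magnitude

    def per_token_dynamic_quant(x):
        # one scale per row, chosen so the row's max magnitude maps to FP8_MAX
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
        q = (x / scale).clamp(-FP8_MAX, FP8_MAX)
        return q, scale

    x = torch.randn(4, 16)
    q, scale = per_token_dynamic_quant(x)
    print(q.shape, scale.shape)  # torch.Size([4, 16]) torch.Size([4, 1])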