sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py

@@ -12,22 +12,21 @@
 # limitations under the License.
 # ==============================================================================
 
+from __future__ import annotations
+
 import math
-from typing import Callable, Optional
+from typing import TYPE_CHECKING, Callable, NamedTuple, Optional
 
 import torch
 import torch.nn.functional as F
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.eplb import expert_location_dispatch
-from sglang.srt.eplb.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -56,6 +55,168 @@ if _use_aiter:
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 
+if _is_npu:
+    import torch_npu
+
+
+class TopKOutput(NamedTuple):
+    topk_weights: torch.Tensor
+    topk_ids: torch.Tensor
+    router_logits: torch.Tensor
+
+
+class TopK(CustomOp):
+
+    # TODO(ch-wan): support triton_kernels
+
+    def __init__(
+        self,
+        top_k: int,
+        *,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        renormalize: bool = True,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        correction_bias: Optional[torch.Tensor] = None,
+        routed_scaling_factor: Optional[float] = None,
+    ):
+        # NOTE: scoring_func is not used for now, but we keep it for future use
+        # see https://github.com/sgl-project/sglang/pull/4505 for more details
+        super().__init__()
+        if use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.top_k = top_k
+        self.use_grouped_topk = use_grouped_topk
+        self.renormalize = renormalize
+        self.topk_group = topk_group
+        self.num_expert_group = num_expert_group
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.custom_routing_function = custom_routing_function
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        torch_native = True
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            torch_native=torch_native,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        torch_native = False
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            torch_native=torch_native,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_cpu(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_npu(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        global_num_experts = router_logits.shape[-1]
+
+        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+        if global_num_experts == 256:
+            return torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=self.top_k,
+                bias=self.correction_bias,
+                k_group=self.topk_group,
+                group_count=self.num_expert_group,
+                group_select_mode=1,
+                renorm=0,
+                norm_type=1,
+                routed_scaling_factor=1,
+                eps=float(1e-20),
+            )
+        else:
+            torch_native = True
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
+
 
 def fused_topk_torch_native(
     hidden_states: torch.Tensor,
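Note: the hunk above introduces TopKOutput and a TopK CustomOp that owns the routing configuration. A minimal usage sketch (tensor sizes and the direct call to forward_native are illustrative assumptions, not taken from this diff; CustomOp normally dispatches to the platform-specific forward):

import torch

from sglang.srt.layers.moe.topk import TopK

# Hypothetical sizes: 4 tokens, hidden size 1024, 8 experts, route to 2.
topk = TopK(top_k=2, renormalize=True)
hidden_states = torch.randn(4, 1024)
router_logits = torch.randn(4, 8)

# Call the pure-PyTorch reference path directly for illustration.
topk_output = topk.forward_native(hidden_states, router_logits)
weights, ids, logits = topk_output  # TopKOutput is a NamedTuple

Because TopKOutput is a NamedTuple, it still unpacks like the old (topk_weights, topk_ids) pair while also carrying router_logits for downstream consumers.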
@@ -97,6 +258,19 @@ def fused_topk_cpu(
     return topk_weights, topk_ids
 
 
+def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
+    if not need_apply:
+        return inputs, topk_weights
+
+    # TODO: fuse below processing in fused_experts_cpu kernel
+    inputs = inputs * topk_weights.to(inputs.dtype)
+    topk_weights = torch.ones_like(
+        topk_weights, dtype=torch.float32
+    )  # clear topk_weights as already applied
+
+    return inputs, topk_weights
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
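Note: apply_topk_weights_cpu pre-applies the routing weights to the expert inputs and then resets the weights to ones, so a CPU fused-experts kernel that has not fused this multiply does not apply it a second time. A small sketch, assuming per-token weights that broadcast over the hidden dimension:

import torch

from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

inputs = torch.randn(4, 16)      # [num_tokens, hidden_size], assumed shapes
topk_weights = torch.rand(4, 1)  # one routing weight per token in this sketch

scaled_inputs, cleared_weights = apply_topk_weights_cpu(True, topk_weights, inputs)
assert torch.equal(cleared_weights, torch.ones_like(topk_weights))

# With need_apply=False both tensors pass through unchanged.
same_inputs, same_weights = apply_topk_weights_cpu(False, topk_weights, inputs)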
@@ -325,7 +499,7 @@ def biased_grouped_topk_gpu(
         and is_power_of_two(correction_bias.shape[0])
     ):
         topk_weights, topk_ids = moe_fused_gate(
-            gating_output,
+            gating_output.to(dtype=torch.float32),
             correction_bias,
             num_expert_group,
             topk_group,
@@ -350,7 +524,7 @@ def biased_grouped_topk_gpu(
         topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
         topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
         aiter_biased_grouped_topk(
-            gating_output,
+            gating_output.to(dtype=torch.float32),
             correction_bias,
             topk_weights,
             topk_ids,
@@ -427,8 +601,9 @@ def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
+    *,
+    use_grouped_topk: bool = False,
+    renormalize: bool = False,
     topk_group: Optional[int] = None,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -438,7 +613,7 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-):
+) -> TopKOutput:
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
            router_logits=router_logits,
@@ -513,4 +688,4 @@ def select_experts(
 
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
 
-    return topk_weights, topk_ids
+    return TopKOutput(topk_weights, topk_ids, router_logits)
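Note: select_experts now takes its routing flags as keyword-only arguments and returns a TopKOutput rather than a bare (topk_weights, topk_ids) tuple. A call-shape sketch (tensor sizes are illustrative; torch_native=True is assumed here to pick the pure-PyTorch path so the sketch does not need a GPU kernel):

import torch

from sglang.srt.layers.moe.topk import select_experts

hidden_states = torch.randn(4, 1024)
router_logits = torch.randn(4, 8)

topk_weights, topk_ids, router_logits = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    renormalize=True,   # keyword-only after this change
    torch_native=True,
)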
sglang/srt/layers/quantization/__init__.py

@@ -1,18 +1,14 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+from __future__ import annotations
+
 import builtins
 import inspect
-import re
-from copy import deepcopy
-from typing import Callable, Dict, Optional, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
 
 import torch
 
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq_marlin import (
-        AWQMarlinConfig,
-        AWQMoEMethod,
-    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
     from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
         CompressedTensorsW8A8Fp8MoEMethod,
@@ -22,10 +18,6 @@ try:
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        GPTQMarlinLinearMethod,
-    )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -42,15 +34,14 @@ except ImportError:
         def override_quantization_method(self, *args, **kwargs):
             return None
 
-    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
-        DeepSpeedFPConfig
-    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
-        MarlinConfig
-    ) = QQQConfig = Int8TpuConfig = DummyConfig
+    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
+        ExpertsInt8Config
+    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
+        Int8TpuConfig
+    ) = DummyConfig
 
 
-from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.quantization.awq import AWQConfig
+from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
@@ -59,7 +50,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import (
     GPTQConfig,
+    GPTQLinearMethod,
     GPTQMarlinConfig,
+    GPTQMarlinLinearMethod,
     GPTQMarlinMoEMethod,
 )
 from sglang.srt.layers.quantization.modelopt_quant import (
@@ -67,11 +60,16 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
+from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
@@ -84,6 +82,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
+    "petit_nvfp4": PetitNvFp4Config,
 }
 
 # VLLM-dependent quantization methods
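Note: the new "petit_nvfp4" entry makes the Petit NvFP4 config reachable through the registry lookup shown later in this file. A sketch (assuming QUANTIZATION_METHODS is still assembled from BASE_QUANTIZATION_METHODS and that get_quantization_config behaves as the surrounding context shows):

from sglang.srt.layers.quantization import get_quantization_config
from sglang.srt.layers.quantization.petit import PetitNvFp4Config

assert get_quantization_config("petit_nvfp4") is PetitNvFp4Config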
@@ -122,99 +121,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-# Match dynamic rules with module name (prefix) and override quantize
-# config if module (prefix) matches a rule
-def override_config(config: QuantizationConfig, prefix: str):
-    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
-    if isinstance(weight_bits, int):
-        config.weight_bits = weight_bits
-    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
-    if isinstance(group_size, int):
-        config.group_size = group_size
-    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
-    if isinstance(desc_act, bool):
-        config.desc_act = desc_act
-
-    config.pack_factor = 32 // config.weight_bits  # packed into int32
-    if config.get_name() == "gptq_marlin":
-        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
-        if isinstance(is_sym, bool):
-            config.is_sym = is_sym
-
-        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
-            raise ValueError(
-                "Unsupported quantization config: "
-                f"bits={config.weight_bits}, sym={config.is_sym}"
-            )
-
-        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
-    elif config.get_name() == "gptq":
-        if config.weight_bits not in [2, 3, 4, 8]:
-            raise ValueError(
-                "Currently, only 2/3/4/8-bit weight quantization is "
-                f"supported for GPTQ, but got {config.weight_bits} bits."
-            )
-
-
-def get_dynamic_override(
-    config: QuantizationConfig,
-    layer_name: str,
-    key: Optional[str] = None,
-    default_value: Union[int, bool, None] = None,
-) -> Union[Dict, int, bool, None]:
-    for pattern, pattern_dict in config.dynamic.items():
-        # Negative match: matched modules are excluded from quantized init
-        if pattern.startswith("-:"):
-            if re.match(pattern.removeprefix("-:"), layer_name):
-                return False
-        # Positive match: matched modules have quant properties overrides
-        # base quant config
-        elif re.match(pattern.removeprefix("+:"), layer_name):
-            if key is None:
-                return pattern_dict
-            else:
-                return pattern_dict.get(key, default_value)
-    return default_value
-
-
-def get_linear_quant_method(
-    config: QuantizationConfig,
-    layer: torch.nn.Module,
-    prefix: str,
-    linear_method_cls: type,
-):
-    # Move import here to avoid circular import. This is only used in monkey patching
-    # of vllm's QuantizationConfig.
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
-    cloned_config = deepcopy(config)
-    parallel_lm_head_quantized = (
-        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
-    )
-
-    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
-        # False = skip module, None = no override, else = Positive match
-        if (
-            get_dynamic_override(  # noqa: E712
-                cloned_config, layer_name=prefix  # noqa: E712
-            )
-            == False
-        ):  # noqa: E712
-            if parallel_lm_head_quantized:
-                return UnquantizedEmbeddingMethod()
-            return UnquantizedLinearMethod()
-
-        if prefix:
-            # Dynamic per module/layer rules may override base config
-            override_config(cloned_config, prefix=prefix)
-
-        return linear_method_cls(cloned_config)
-    return None
-
-
 def gptq_get_quant_method(self, layer, prefix):
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
@@ -285,15 +191,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -307,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
             "self": self,
             "layer": layer,
             "x": x,
-            "router_logits": router_logits,
-            "top_k": top_k,
-            "renormalize": renormalize,
-            "use_grouped_topk": use_grouped_topk,
-            "topk_group": topk_group,
-            "num_expert_group": num_expert_group,
-            "custom_routing_function": custom_routing_function,
+            "topk_output": topk_output,
         }
-        if correction_bias is not None:
-            if not has_correction_bias:
-                raise ValueError(
-                    "Please increase the version of your vllm. Try `pip install vllm==0.9.0.1`"
-                )
-            kwargs["e_score_correction_bias"] = correction_bias
         return original_apply(**kwargs)
 
     setattr(class_obj, "apply", new_apply)
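Note: after this change the patched apply() forwards a single topk_output entry instead of the old router_logits / top_k / renormalize / ... routing kwargs. A self-contained stand-in of that adapter (the real patch wraps vllm's FusedMoEMethodBase.apply and its trailing parameters are truncated in this hunk; the names below are illustrative only):

import torch

from sglang.srt.layers.moe.topk import TopKOutput


def original_apply(**kwargs):
    # Stand-in for the wrapped vllm MoE apply(); a real method would run the
    # fused experts using kwargs["topk_output"].
    assert "topk_output" in kwargs and "router_logits" not in kwargs
    return kwargs["x"]


def new_apply(self, layer, x, topk_output: TopKOutput, *, activation="silu",
              apply_router_weight_on_input=False, inplace=True):
    # Repack the consolidated routing result into the keyword dictionary the
    # wrapped method expects, mirroring the patched adapter above.
    kwargs = {"self": self, "layer": layer, "x": x, "topk_output": topk_output}
    return original_apply(**kwargs)


out = new_apply(
    None,
    None,
    torch.randn(4, 1024),
    TopKOutput(torch.rand(4, 2), torch.zeros(4, 2, dtype=torch.int32), torch.randn(4, 8)),
)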
@@ -331,7 +218,6 @@ def monkey_patch_quant_configs():
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-    monkey_patch_moe_apply(AWQMoEMethod)
     monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)