sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
@@ -12,22 +12,21 @@
  # limitations under the License.
  # ==============================================================================

+ from __future__ import annotations
+
  import math
- from typing import Callable, Optional
+ from typing import Callable, NamedTuple, Optional

  import torch
  import torch.nn.functional as F

+ from sglang.srt.custom_op import CustomOp
  from sglang.srt.eplb import expert_location_dispatch
- from sglang.srt.eplb.expert_distribution import (
-     ExpertDistributionRecorder,
-     get_global_expert_distribution_recorder,
- )
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
  from sglang.srt.eplb.expert_location_dispatch import (
      ExpertLocationDispatchInfo,
      topk_ids_logical_to_physical,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.utils import (
      cpu_has_amx_support,
      get_bool_env_var,
@@ -40,10 +39,10 @@ from sglang.srt.utils import (

  _is_cuda = is_cuda()
  _is_hip = is_hip()
- _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
- _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
+ _is_cpu_amx_available = cpu_has_amx_support()
  _is_npu = is_npu()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

  if _is_cuda:
      from sgl_kernel import moe_fused_gate
@@ -55,6 +54,167 @@ if _use_aiter:
          from aiter import biased_grouped_topk as aiter_biased_grouped_topk
      except ImportError:
          raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+ if _is_npu:
+     import torch_npu
+
+
+ class TopKOutput(NamedTuple):
+     topk_weights: torch.Tensor
+     topk_ids: torch.Tensor
+     router_logits: torch.Tensor
+
+
+ class TopK(CustomOp):
+
+     # TODO(ch-wan): support triton_kernels
+
+     def __init__(
+         self,
+         top_k: int,
+         *,
+         use_grouped_topk: bool = False,
+         topk_group: Optional[int] = None,
+         num_expert_group: Optional[int] = None,
+         renormalize: bool = True,
+         num_fused_shared_experts: int = 0,
+         custom_routing_function: Optional[Callable] = None,
+         scoring_func: str = "softmax",
+         correction_bias: Optional[torch.Tensor] = None,
+         routed_scaling_factor: Optional[float] = None,
+     ):
+         # NOTE: scoring_func is not used for now, but we keep it for future use
+         # see https://github.com/sgl-project/sglang/pull/4505 for more details
+         super().__init__()
+         if use_grouped_topk:
+             assert num_expert_group is not None and topk_group is not None
+         self.top_k = top_k
+         self.use_grouped_topk = use_grouped_topk
+         self.renormalize = renormalize
+         self.topk_group = topk_group
+         self.num_expert_group = num_expert_group
+         self.num_fused_shared_experts = num_fused_shared_experts
+         self.custom_routing_function = custom_routing_function
+         self.correction_bias = correction_bias
+         self.routed_scaling_factor = routed_scaling_factor
+
+     def forward_native(
+         self,
+         hidden_states: torch.Tensor,
+         router_logits: torch.Tensor,
+         *,
+         num_token_non_padded: Optional[torch.Tensor] = None,
+         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+     ) -> TopKOutput:
+         torch_native = True
+         return select_experts(
+             hidden_states=hidden_states,
+             router_logits=router_logits,
+             top_k=self.top_k,
+             use_grouped_topk=self.use_grouped_topk,
+             renormalize=self.renormalize,
+             topk_group=self.topk_group,
+             num_expert_group=self.num_expert_group,
+             num_fused_shared_experts=self.num_fused_shared_experts,
+             custom_routing_function=self.custom_routing_function,
+             correction_bias=self.correction_bias,
+             torch_native=torch_native,
+             routed_scaling_factor=self.routed_scaling_factor,
+             num_token_non_padded=num_token_non_padded,
+             expert_location_dispatch_info=expert_location_dispatch_info,
+         )
+
+     def forward_cuda(
+         self,
+         hidden_states: torch.Tensor,
+         router_logits: torch.Tensor,
+         *,
+         num_token_non_padded: Optional[torch.Tensor] = None,
+         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+     ) -> TopKOutput:
+         torch_native = False
+         return select_experts(
+             hidden_states=hidden_states,
+             router_logits=router_logits,
+             top_k=self.top_k,
+             use_grouped_topk=self.use_grouped_topk,
+             renormalize=self.renormalize,
+             topk_group=self.topk_group,
+             num_expert_group=self.num_expert_group,
+             num_fused_shared_experts=self.num_fused_shared_experts,
+             custom_routing_function=self.custom_routing_function,
+             correction_bias=self.correction_bias,
+             torch_native=torch_native,
+             routed_scaling_factor=self.routed_scaling_factor,
+             num_token_non_padded=num_token_non_padded,
+             expert_location_dispatch_info=expert_location_dispatch_info,
+         )
+
+     def forward_cpu(
+         self,
+         hidden_states: torch.Tensor,
+         router_logits: torch.Tensor,
+         *,
+         num_token_non_padded: Optional[torch.Tensor] = None,
+         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+     ) -> TopKOutput:
+         return select_experts(
+             hidden_states=hidden_states,
+             router_logits=router_logits,
+             top_k=self.top_k,
+             use_grouped_topk=self.use_grouped_topk,
+             renormalize=self.renormalize,
+             topk_group=self.topk_group,
+             num_expert_group=self.num_expert_group,
+             num_fused_shared_experts=self.num_fused_shared_experts,
+             custom_routing_function=self.custom_routing_function,
+             correction_bias=self.correction_bias,
+             routed_scaling_factor=self.routed_scaling_factor,
+             num_token_non_padded=num_token_non_padded,
+             expert_location_dispatch_info=expert_location_dispatch_info,
+         )
+
+     def forward_npu(
+         self,
+         hidden_states: torch.Tensor,
+         router_logits: torch.Tensor,
+         *,
+         num_token_non_padded: Optional[torch.Tensor] = None,
+         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+     ) -> TopKOutput:
+         global_num_experts = router_logits.shape[-1]
+
+         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+         if global_num_experts == 256:
+             return torch_npu.npu_moe_gating_top_k(
+                 router_logits,
+                 k=self.top_k,
+                 bias=self.correction_bias,
+                 k_group=self.topk_group,
+                 group_count=self.num_expert_group,
+                 group_select_mode=1,
+                 renorm=0,
+                 norm_type=1,
+                 routed_scaling_factor=1,
+                 eps=float(1e-20),
+             )
+         else:
+             torch_native = True
+             return select_experts(
+                 hidden_states=hidden_states,
+                 router_logits=router_logits,
+                 top_k=self.top_k,
+                 use_grouped_topk=self.use_grouped_topk,
+                 renormalize=self.renormalize,
+                 topk_group=self.topk_group,
+                 num_expert_group=self.num_expert_group,
+                 num_fused_shared_experts=self.num_fused_shared_experts,
+                 custom_routing_function=self.custom_routing_function,
+                 correction_bias=self.correction_bias,
+                 torch_native=torch_native,
+                 routed_scaling_factor=self.routed_scaling_factor,
+                 num_token_non_padded=num_token_non_padded,
+                 expert_location_dispatch_info=expert_location_dispatch_info,
+             )


  def fused_topk_torch_native(
@@ -97,6 +257,19 @@ def fused_topk_cpu(
      return topk_weights, topk_ids


+ def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
+     if not need_apply:
+         return inputs, topk_weights
+
+     # TODO: fuse below processing in fused_experts_cpu kernel
+     inputs = inputs * topk_weights.to(inputs.dtype)
+     topk_weights = torch.ones_like(
+         topk_weights, dtype=torch.float32
+     )  # clear topk_weights as already applied
+
+     return inputs, topk_weights
+
+
  def fused_topk(
      hidden_states: torch.Tensor,
      gating_output: torch.Tensor,
@@ -213,6 +386,7 @@ def grouped_topk_cpu(
      )


+ @torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
  def biased_grouped_topk_impl(
      hidden_states: torch.Tensor,
      gating_output: torch.Tensor,
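Note: the decorator added above compiles biased_grouped_topk_impl once at definition time, with disable=_is_npu keeping eager execution on NPU. A minimal generic sketch of that torch.compile(disable=...) pattern (illustrative condition and function, not sglang code):

import torch

# Illustrative gate; the diff uses _is_npu here.
_skip_compile = not torch.cuda.is_available()


@torch.compile(dynamic=True, disable=_skip_compile)
def scaled_softmax(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Compiled when _skip_compile is False, plain eager execution otherwise.
    return torch.softmax(x * scale, dim=-1)


print(scaled_softmax(torch.randn(2, 4), 0.5).shape)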
@@ -308,7 +482,6 @@ def biased_grouped_topk_gpu(
      renormalize: bool,
      num_expert_group: int = 0,
      topk_group: int = 0,
-     compiled: bool = not _is_npu,
      num_fused_shared_experts: int = 0,
      routed_scaling_factor: Optional[float] = None,
      num_token_non_padded: Optional[torch.Tensor] = None,
@@ -325,7 +498,7 @@
          and is_power_of_two(correction_bias.shape[0])
      ):
          topk_weights, topk_ids = moe_fused_gate(
-             gating_output,
+             gating_output.to(dtype=torch.float32),
              correction_bias,
              num_expert_group,
              topk_group,
@@ -350,7 +523,7 @@
          topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
          topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
          aiter_biased_grouped_topk(
-             gating_output,
+             gating_output.to(dtype=torch.float32),
              correction_bias,
              topk_weights,
              topk_ids,
@@ -361,14 +534,7 @@
          )
          return topk_weights, topk_ids
      else:
-         biased_grouped_topk_fn = (
-             torch.compile(
-                 biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
-             )
-             if compiled
-             else biased_grouped_topk_impl
-         )
-         return biased_grouped_topk_fn(
+         return biased_grouped_topk_impl(
              hidden_states,
              gating_output,
              correction_bias,
@@ -427,8 +593,9 @@
      hidden_states: torch.Tensor,
      router_logits: torch.Tensor,
      top_k: int,
-     use_grouped_topk: bool,
-     renormalize: bool,
+     *,
+     use_grouped_topk: bool = False,
+     renormalize: bool = False,
      topk_group: Optional[int] = None,
      num_expert_group: Optional[int] = None,
      num_fused_shared_experts: int = 0,
@@ -438,7 +605,7 @@
      routed_scaling_factor: Optional[float] = None,
      num_token_non_padded: Optional[torch.Tensor] = None,
      expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
- ):
+ ) -> TopKOutput:
      router_logits, correction_bias = (
          expert_location_dispatch.transform_select_experts_inputs(
              router_logits=router_logits,
@@ -513,4 +680,4 @@

      get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

-     return topk_weights, topk_ids
+     return TopKOutput(topk_weights, topk_ids, router_logits)
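Note: select_experts now returns a TopKOutput named tuple instead of a bare (topk_weights, topk_ids) pair, so router_logits travels with the routing result. A small self-contained sketch of what that return shape looks like to a caller (illustrative stand-in router, not sglang internals):

from typing import NamedTuple

import torch


class TopKOutput(NamedTuple):
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor


def demo_router(hidden: torch.Tensor, num_experts: int = 8, top_k: int = 2) -> TopKOutput:
    # Stand-in for TopK.forward_* / select_experts: route each token to top_k experts.
    router_logits = hidden @ torch.randn(hidden.shape[-1], num_experts)
    weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), k=top_k, dim=-1)
    return TopKOutput(weights, ids.to(torch.int32), router_logits)


out = demo_router(torch.randn(4, 16))
# Fields can be read by name or unpacked positionally (three values now, not two).
topk_weights, topk_ids, router_logits = out
assert out.topk_ids.shape == (4, 2)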
sglang/srt/layers/quantization/__init__.py
@@ -1,18 +1,14 @@
  # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+ from __future__ import annotations
+
  import builtins
  import inspect
- import re
- from copy import deepcopy
- from typing import Callable, Dict, Optional, Type, Union
+ from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union

  import torch

  try:
      from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-     from vllm.model_executor.layers.quantization.awq_marlin import (
-         AWQMarlinConfig,
-         AWQMoEMethod,
-     )
      from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
      from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
          CompressedTensorsW8A8Fp8MoEMethod,
@@ -22,10 +18,6 @@ try:
      from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
      from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
      from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-     from vllm.model_executor.layers.quantization.gptq_marlin import (
-         GPTQMarlinLinearMethod,
-     )
      from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
          GPTQMarlin24Config,
      )
@@ -42,15 +34,14 @@ except ImportError:
          def override_quantization_method(self, *args, **kwargs):
              return None

-     AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
-         DeepSpeedFPConfig
-     ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
-         MarlinConfig
-     ) = QQQConfig = Int8TpuConfig = DummyConfig
+     AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
+         ExpertsInt8Config
+     ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
+         Int8TpuConfig
+     ) = DummyConfig


- from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
- from sglang.srt.layers.quantization.awq import AWQConfig
+ from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
  from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
@@ -59,7 +50,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
  from sglang.srt.layers.quantization.fp8 import Fp8Config
  from sglang.srt.layers.quantization.gptq import (
      GPTQConfig,
+     GPTQLinearMethod,
      GPTQMarlinConfig,
+     GPTQMarlinLinearMethod,
      GPTQMarlinMoEMethod,
  )
  from sglang.srt.layers.quantization.modelopt_quant import (
@@ -67,11 +60,16 @@ from sglang.srt.layers.quantization.modelopt_quant import (
      ModelOptFp8Config,
  )
  from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+ from sglang.srt.layers.quantization.petit import PetitNvFp4Config
  from sglang.srt.layers.quantization.qoq import QoQConfig
+ from sglang.srt.layers.quantization.utils import get_linear_quant_method
  from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
  from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
  from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+
  # Base quantization methods that don't depend on vllm
  BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
      "fp8": Fp8Config,
@@ -84,6 +82,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
      "compressed-tensors": CompressedTensorsConfig,
      "qoq": QoQConfig,
      "w4afp8": W4AFp8Config,
+     "petit_nvfp4": PetitNvFp4Config,
  }

  # VLLM-dependent quantization methods
@@ -122,99 +121,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
      return QUANTIZATION_METHODS[quantization]


- # Match dynamic rules with module name (prefix) and override quantize
- # config if module (prefix) matches a rule
- def override_config(config: QuantizationConfig, prefix: str):
-     weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
-     if isinstance(weight_bits, int):
-         config.weight_bits = weight_bits
-     group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
-     if isinstance(group_size, int):
-         config.group_size = group_size
-     desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
-     if isinstance(desc_act, bool):
-         config.desc_act = desc_act
-
-     config.pack_factor = 32 // config.weight_bits  # packed into int32
-     if config.get_name() == "gptq_marlin":
-         is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
-         if isinstance(is_sym, bool):
-             config.is_sym = is_sym
-
-         if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
-             raise ValueError(
-                 "Unsupported quantization config: "
-                 f"bits={config.weight_bits}, sym={config.is_sym}"
-             )
-
-         config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
-     elif config.get_name() == "gptq":
-         if config.weight_bits not in [2, 3, 4, 8]:
-             raise ValueError(
-                 "Currently, only 2/3/4/8-bit weight quantization is "
-                 f"supported for GPTQ, but got {config.weight_bits} bits."
-             )
-
-
- def get_dynamic_override(
-     config: QuantizationConfig,
-     layer_name: str,
-     key: Optional[str] = None,
-     default_value: Union[int, bool, None] = None,
- ) -> Union[Dict, int, bool, None]:
-     for pattern, pattern_dict in config.dynamic.items():
-         # Negative match: matched modules are excluded from quantized init
-         if pattern.startswith("-:"):
-             if re.match(pattern.removeprefix("-:"), layer_name):
-                 return False
-         # Positive match: matched modules have quant properties overrides
-         # base quant config
-         elif re.match(pattern.removeprefix("+:"), layer_name):
-             if key is None:
-                 return pattern_dict
-             else:
-                 return pattern_dict.get(key, default_value)
-     return default_value
-
-
- def get_linear_quant_method(
-     config: QuantizationConfig,
-     layer: torch.nn.Module,
-     prefix: str,
-     linear_method_cls: type,
- ):
-     # Move import here to avoid circular import. This is only used in monkey patching
-     # of vllm's QuantizationConfig.
-     from sglang.srt.layers.vocab_parallel_embedding import (
-         ParallelLMHead,
-         UnquantizedEmbeddingMethod,
-     )
-
-     cloned_config = deepcopy(config)
-     parallel_lm_head_quantized = (
-         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
-     )
-
-     if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
-         # False = skip module, None = no override, else = Positive match
-         if (
-             get_dynamic_override(  # noqa: E712
-                 cloned_config, layer_name=prefix  # noqa: E712
-             )
-             == False
-         ):  # noqa: E712
-             if parallel_lm_head_quantized:
-                 return UnquantizedEmbeddingMethod()
-             return UnquantizedLinearMethod()
-
-         if prefix:
-             # Dynamic per module/layer rules may override base config
-             override_config(cloned_config, prefix=prefix)
-
-         return linear_method_cls(cloned_config)
-     return None
-
-
  def gptq_get_quant_method(self, layer, prefix):
      from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

@@ -285,15 +191,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
+         topk_output: TopKOutput,
+         *,
          activation: str = "silu",
          apply_router_weight_on_input: bool = False,
          inplace: bool = True,
@@ -307,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
              "self": self,
              "layer": layer,
              "x": x,
-             "router_logits": router_logits,
-             "top_k": top_k,
-             "renormalize": renormalize,
-             "use_grouped_topk": use_grouped_topk,
-             "topk_group": topk_group,
-             "num_expert_group": num_expert_group,
-             "custom_routing_function": custom_routing_function,
+             "topk_output": topk_output,
          }
-         if correction_bias is not None:
-             if not has_correction_bias:
-                 raise ValueError(
-                     "Please increase the version of your vllm. Try `pip install vllm==0.9.0.1`"
-                 )
-             kwargs["e_score_correction_bias"] = correction_bias
          return original_apply(**kwargs)

      setattr(class_obj, "apply", new_apply)
@@ -331,7 +218,6 @@ def monkey_patch_quant_configs():
      setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
      setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

-     monkey_patch_moe_apply(AWQMoEMethod)
      monkey_patch_moe_apply(GPTQMarlinMoEMethod)
      monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
      monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
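Note: override_config, get_dynamic_override, and get_linear_quant_method are removed from this module; the import added near the top now pulls get_linear_quant_method from sglang.srt.layers.quantization.utils instead. For reference, a small self-contained sketch of the "+:" / "-:" dynamic-rule matching those helpers implement (illustrative names and rules, not the sglang API):

import re
from typing import Dict, Optional, Union


def match_dynamic_rule(
    dynamic: Dict[str, Dict],
    layer_name: str,
    key: Optional[str] = None,
    default: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
    # "-:" rules exclude matching layers from quantized init; "+:" rules override
    # individual quantization parameters for matching layers.
    for pattern, overrides in dynamic.items():
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        elif re.match(pattern.removeprefix("+:"), layer_name):
            return overrides if key is None else overrides.get(key, default)
    return default


rules = {"-:.*lm_head.*": {}, "+:.*mlp.*": {"bits": 8}}
assert match_dynamic_rule(rules, "model.lm_head") is False  # excluded from quantization
assert match_dynamic_rule(rules, "model.layers.0.mlp", "bits") == 8  # per-layer override
assert match_dynamic_rule(rules, "model.embed_tokens", "bits", 4) == 4  # falls back to default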