sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
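
The hunk above is one of the new fused MoE Triton tuning tables (entries 49-61 in the file list): each JSON file maps a benchmarked token count M ("1" through "4096") to kernel launch parameters. For orientation only, a minimal sketch of how such a table is typically consumed, using a hypothetical helper name rather than sglang's actual lookup code:

import json

def pick_moe_kernel_config(path: str, num_tokens: int) -> dict:
    # Keys are the benchmarked M values ("1", "2", ..., "4096").
    with open(path) as f:
        configs = {int(m): cfg for m, cfg in json.load(f).items()}
    # Fall back to the benchmarked batch size nearest to the requested one.
    nearest_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest_m]
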
@@ -147,6 +147,7 @@ class FusedMoE(torch.nn.Module):

         self.layer_id = layer_id
         self.top_k = top_k
+        self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map_cpu = None
@@ -209,13 +210,13 @@ class FusedMoE(torch.nn.Module):
         self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
             "enable_flashinfer_mxfp4_moe", False
         )
+        # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
             and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
-        self.hidden_size = hidden_size
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
@@ -795,13 +796,6 @@ class FusedMoE(torch.nn.Module):

     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         origin_hidden_states_dim = hidden_states.shape[-1]
-        if self.hidden_size != origin_hidden_states_dim:
-            hidden_states = torch.nn.functional.pad(
-                hidden_states,
-                (0, self.hidden_size - origin_hidden_states_dim),
-                mode="constant",
-                value=0.0,
-            )
         assert self.quant_method is not None

         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
@@ -846,10 +840,14 @@ class FusedMoE(torch.nn.Module):
             )
             sm.tag(final_hidden_states)

+        final_hidden_states = final_hidden_states[
+            ..., :origin_hidden_states_dim
+        ].contiguous()
+
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

-        return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
+        return final_hidden_states

     @classmethod
     def make_expert_params_mapping(
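
For context on the FusedMoE hunk above: the mxfp4/flashinfer path rounds the weight hidden size up to a multiple of 256, and this change stops padding the activations inside forward while slicing the MoE output back to the original width before the tensor-parallel all-reduce. A minimal, self-contained sketch of that pad/slice round-trip (illustrative helpers, not sglang API):

import torch
import torch.nn.functional as F

def pad_last_dim_to_multiple(x: torch.Tensor, multiple: int = 256) -> torch.Tensor:
    # Zero-pad the last dimension up to the next multiple (no-op if already aligned).
    pad = (-x.shape[-1]) % multiple
    return F.pad(x, (0, pad), mode="constant", value=0.0) if pad else x

def unpad_last_dim(x: torch.Tensor, original_dim: int) -> torch.Tensor:
    # Drop the padded lanes before any cross-rank reduction or downstream use.
    return x[..., :original_dim].contiguous()
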
@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    is_npu,
+    load_json_config,
+)
+
+_is_npu = is_npu()

 try:
     from deep_ep import Buffer, Config

-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
+    if not _is_npu:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )

     use_deepep = True
 except ImportError:
@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.deepep_ll


+class AscendDeepEPLLOutput(NamedTuple):
+    """AscendDeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    seg_indptr: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
 assert isinstance(DeepEPNormalOutput, DispatchOutput)
 assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(AscendDeepEPLLOutput, DispatchOutput)


 class DeepEPDispatchMode(IntEnum):
@@ -150,19 +175,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError

-        total_num_sms = torch.cuda.get_device_properties(
-            device="cuda"
-        ).multi_processor_count
-        if (
-            (deepep_mode != DeepEPMode.LOW_LATENCY)
-            and not global_server_args_dict["enable_two_batch_overlap"]
-            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
-        ):
-            logger.warning(
-                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
-                f"This may result in highly suboptimal performance. "
-                f"Consider using --deepep-config to change the behavior."
-            )
+        if not _is_npu:
+            total_num_sms = torch.cuda.get_device_properties(
+                device="cuda"
+            ).multi_processor_count
+            if (
+                (deepep_mode != DeepEPMode.LOW_LATENCY)
+                and not global_server_args_dict["enable_two_batch_overlap"]
+                and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+            ):
+                logger.warning(
+                    f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                    f"This may result in highly suboptimal performance. "
+                    f"Consider using --deepep-config to change the behavior."
+                )

         cls._buffer = Buffer(
             group,
@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )

-        return DeepEPLLOutput(
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            masked_m,
-            expected_m,
-        )
+        if _is_npu:
+            deepep_output = AscendDeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                self.handle[1],
+                expected_m,
+            )
+        else:
+            deepep_output = DeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                expected_m,
+            )
+        return deepep_output

     def _dispatch_core(
         self,
@@ -245,10 +245,11 @@ class TopK(CustomOp):

         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
+            router_logits = router_logits.to(torch.float32)
             return torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.top_k,
-                bias=self.correction_bias,
+                bias=self.correction_bias.to(torch.float32),
                 k_group=self.topk_group,
                 group_count=self.num_expert_group,
                 group_select_mode=1,
@@ -440,7 +441,9 @@ def grouped_topk_cpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
+    assert not apply_routed_scaling_factor_on_output
     assert expert_location_dispatch_info is None
     return torch.ops.sgl_kernel.grouped_topk_cpu(
         hidden_states,
@@ -17,57 +17,173 @@ import torch
 import triton
 import triton.language as tl

+FMIX32_C1 = 0x85EBCA6B
+FMIX32_C2 = 0xC2B2AE35
+POS_C1 = 0x27D4EB2D
+POS_C2 = 0x165667B1
+
+
+@triton.jit
+def _rotl32(x, r: tl.constexpr):
+    return (x << r) | (x >> (32 - r))
+
+
+@triton.jit
+def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
+    c1 = tl.full((), C1, tl.uint32)
+    c2 = tl.full((), C2, tl.uint32)
+    x ^= x >> 16
+    x = x * c1
+    x ^= x >> 13
+    x = x * c2
+    x ^= x >> 16
+    return x
+

 @triton.jit
-def hash_kernel(
-    input_ptr,
-    output_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    PRIME: tl.constexpr,
-    XCONST: tl.constexpr,
+def hash_tiles32_kernel_blocked(
+    in_ptr,
+    out_ptr,
+    n_u32,
+    seed1,
+    seed2,
+    FM_C1: tl.constexpr,
+    FM_C2: tl.constexpr,
+    POS_A: tl.constexpr,
+    POS_B: tl.constexpr,
+    TILE: tl.constexpr,
+    BLOCK: tl.constexpr,
+    USE_CG: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
+    base = pid * TILE
+
+    s1 = tl.full((), seed1, tl.uint32)
+    s2 = tl.full((), seed2, tl.uint32)
+    posA = tl.full((), POS_A, tl.uint32)
+    posB = tl.full((), POS_B, tl.uint32)
+
+    h1 = tl.zeros((), dtype=tl.uint32)
+    h2 = tl.zeros((), dtype=tl.uint32)
+
+    for off in tl.static_range(0, TILE, BLOCK):
+        idx = base + off + tl.arange(0, BLOCK)
+        m = idx < n_u32

-    data = tl.load(input_ptr + offsets, mask=mask, other=0).to(tl.int64)
-    mixed = data ^ (offsets.to(tl.int64) + XCONST)
-    hash_val = mixed * PRIME
-    hash_val = hash_val ^ (hash_val >> 16)
-    hash_val = hash_val * (PRIME ^ XCONST)
-    hash_val = hash_val ^ (hash_val >> 13)
+        if USE_CG:
+            v = tl.load(in_ptr + idx, mask=m, other=0, cache_modifier=".cg")
+        else:
+            v = tl.load(in_ptr + idx, mask=m, other=0)
+        v = v.to(tl.uint32)
+
+        iu = idx.to(tl.uint32)
+        p1 = (iu * posA + s1) ^ _rotl32(iu, 15)
+        p2 = (iu * posB + s2) ^ _rotl32(iu, 13)
+
+        k1 = _fmix32(v ^ p1, C1=FM_C1, C2=FM_C2)
+        k2 = _fmix32(v ^ p2, C1=FM_C1, C2=FM_C2)
+
+        zero32 = tl.zeros_like(k1)
+        k1 = tl.where(m, k1, zero32)
+        k2 = tl.where(m, k2, zero32)
+
+        h1 += tl.sum(k1, axis=0).to(tl.uint32)
+        h2 += tl.sum(k2, axis=0).to(tl.uint32)
+
+    nbytes = tl.full((), n_u32 * 4, tl.uint32)
+    h1 ^= nbytes
+    h2 ^= nbytes
+    h1 = _fmix32(h1, C1=FM_C1, C2=FM_C2)
+    h2 = (
+        _fmix32(h2, C1=FMIX32_C1, C2=FMIX32_C2)
+        if False
+        else _fmix32(h2, C1=FM_C1, C2=FM_C2)
+    )
+
+    out = (h1.to(tl.uint64) << 32) | h2.to(tl.uint64)
+    tl.store(out_ptr + pid, out)
+
+
+@triton.jit
+def add_tree_reduce_u64_kernel(in_ptr, out_ptr, n_elems, CHUNK: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    start = pid * CHUNK
+    h = tl.zeros((), dtype=tl.uint64)
+    for i in tl.static_range(0, CHUNK):
+        idx = start + i
+        m = idx < n_elems
+        v = tl.load(in_ptr + idx, mask=m, other=0).to(tl.uint64)
+        h += v
+    tl.store(out_ptr + pid, h)

-    tl.store(output_ptr + offsets, hash_val, mask=mask)

+def _as_uint32_words(t: torch.Tensor) -> torch.Tensor:
+    assert t.is_cuda, "Use .cuda() first"
+    tb = t.contiguous().view(torch.uint8)
+    nbytes = tb.numel()
+    pad = (4 - (nbytes & 3)) & 3
+    if pad:
+        tb_p = torch.empty(nbytes + pad, dtype=torch.uint8, device=tb.device)
+        tb_p[:nbytes].copy_(tb)
+        tb_p[nbytes:].zero_()
+        tb = tb_p
+    return tb.view(torch.uint32)

-PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
-PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1

+def _final_splitmix64(x: int) -> int:
+    mask = (1 << 64) - 1
+    x &= mask
+    x ^= x >> 30
+    x = (x * 0xBF58476D1CE4E5B9) & mask
+    x ^= x >> 27
+    x = (x * 0x94D049BB133111EB) & mask
+    x ^= x >> 31
+    return x

-def gpu_tensor_hash(tensor: torch.Tensor) -> int:
-    assert tensor.is_cuda
-    tensor = tensor.contiguous().view(torch.int32)
-    n = tensor.numel()
-    BLOCK_SIZE = 1024
-    grid = (triton.cdiv(n, BLOCK_SIZE),)

-    intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
+@torch.inference_mode()
+def gpu_tensor_hash(
+    tensor: torch.Tensor,
+    *,
+    seed: int = 0x243F6A88,
+    tile_words: int = 8192,
+    block_words: int = 256,
+    reduce_chunk: int = 1024,
+    num_warps: int = 4,
+    num_stages: int = 4,
+    use_cg: bool = True,
+) -> int:
+    assert tensor.is_cuda, "Use .cuda() first"
+    u32 = _as_uint32_words(tensor)
+    n = u32.numel()
+    if n == 0:
+        return 0

-    # Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-    # Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
-    with torch.cuda.device(tensor.device):
-        hash_kernel[grid](
-            tensor,
-            intermediate_hashes,
-            n,
-            BLOCK_SIZE=BLOCK_SIZE,
-            PRIME=PRIME_1,
-            XCONST=PRIME_2,
-        )
+    grid1 = (triton.cdiv(n, tile_words),)
+    partials = torch.empty(grid1[0], dtype=torch.uint64, device=u32.device)
+    hash_tiles32_kernel_blocked[grid1](
+        u32,
+        partials,
+        n,
+        seed1=seed & 0xFFFFFFFF,
+        seed2=((seed * 0x9E3779B1) ^ 0xDEADBEEF) & 0xFFFFFFFF,
+        FM_C1=FMIX32_C1,
+        FM_C2=FMIX32_C2,
+        POS_A=POS_C1,
+        POS_B=POS_C2,
+        TILE=tile_words,
+        BLOCK=block_words,
+        USE_CG=use_cg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )

-    # TODO: threads can't be synced on triton kernel
-    final_hash = intermediate_hashes.sum().item()
+    cur = partials
+    while cur.numel() > 1:
+        n_elems = cur.numel()
+        grid2 = (triton.cdiv(n_elems, reduce_chunk),)
+        nxt = torch.empty(grid2[0], dtype=torch.uint64, device=cur.device)
+        add_tree_reduce_u64_kernel[grid2](cur, nxt, n_elems, CHUNK=reduce_chunk)
+        cur = nxt

-    return final_hash
+    return _final_splitmix64(int(cur.item()))
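
For reference, a short usage sketch of the reworked hasher above (illustrative, assuming a CUDA device is available and the function remains exported from sglang.srt.layers.multimodal): it fingerprints a tensor's raw bytes with per-tile 32-bit mixing, tree-reduces the 64-bit partial sums, and finalizes with splitmix64, so identical bytes give identical digests; it is a content fingerprint, not a cryptographic hash.

import torch

from sglang.srt.layers.multimodal import gpu_tensor_hash

image_feats = torch.randn(3, 224, 224, device="cuda")
digest = gpu_tensor_hash(image_feats)
# Same bytes -> same 64-bit digest on the same build.
assert digest == gpu_tensor_hash(image_feats.clone())
assert 0 <= digest < 2**64
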
@@ -3,7 +3,7 @@ from __future__ import annotations

 import builtins
 import inspect
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, Optional, Type

 import torch

@@ -26,8 +26,9 @@ try:
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

     VLLM_AVAILABLE = True
-except ImportError:
+except ImportError as e:
     VLLM_AVAILABLE = False
+    VLLM_IMPORT_ERROR = e

     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
@@ -54,13 +55,7 @@ if is_mxfp_supported:
     from sglang.srt.layers.quantization.fp4 import MxFp4Config

 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import (
-    GPTQConfig,
-    GPTQLinearMethod,
-    GPTQMarlinConfig,
-    GPTQMarlinLinearMethod,
-    GPTQMarlinMoEMethod,
-)
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
@@ -69,7 +64,6 @@ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -85,6 +79,10 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "awq": AWQConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
@@ -110,19 +108,15 @@ elif is_mxfp_supported and is_hip():
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
-    "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "awq_marlin": AWQMarlinConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "gptq": GPTQConfig,
 }

 QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
@@ -137,29 +131,13 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Please install vllm by `pip install vllm==0.9.0.1`"
+            f"Please install vllm by `pip install vllm==0.9.0.1`\n"
+            f"Import error: {VLLM_IMPORT_ERROR}"
         )

     return QUANTIZATION_METHODS[quantization]


-def gptq_get_quant_method(self, layer, prefix):
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
-
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
-    return None
-
-
 original_isinstance = builtins.isinstance

@@ -237,10 +215,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):

 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
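
A brief usage sketch of the registry change above (illustrative only): AWQ/GPTQ methods now resolve through sglang's base registry rather than the vllm-backed one, and methods that still require vllm raise a ValueError that also carries the original import error.

from sglang.srt.layers.quantization import get_quantization_config

gptq_cls = get_quantization_config("gptq_marlin")  # no longer needs vllm installed

try:
    get_quantization_config("gguf")  # still listed under VLLM_QUANTIZATION_METHODS
except ValueError as err:
    # Raised only when vllm is missing; the message now includes "Import error: ...".
    print(err)
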