sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,14 @@ from typing import Callable, Optional
  import torch
  import torch.nn.functional as F

- from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+ from sglang.srt.managers.expert_distribution import (
+ ExpertDistributionRecorder,
+ get_global_expert_distribution_recorder,
+ )
+ from sglang.srt.managers.expert_location_dispatch import (
+ ExpertLocationDispatchInfo,
+ topk_ids_logical_to_physical,
+ )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

@@ -32,9 +39,6 @@ if _is_cuda or _is_hip:
  from sgl_kernel import topk_softmax


- expert_distribution_recorder = ExpertDistributionRecorder()
-
-
  def fused_topk_native(
  hidden_states: torch.Tensor,
  gating_output: torch.Tensor,
@@ -61,6 +65,7 @@ def fused_topk(
  gating_output: torch.Tensor,
  topk: int,
  renormalize: bool,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -84,7 +89,7 @@ def fused_topk(

  if renormalize:
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
+ topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
  return topk_weights, topk_ids


@@ -99,6 +104,8 @@ def grouped_topk(
  topk_group: int = 0,
  n_share_experts_fusion: int = 0,
  routed_scaling_factor: Optional[float] = None,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -138,7 +145,10 @@ def grouped_topk(
  )
  topk_weights = topk_weights / topk_weights_sum

- return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+ topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+ topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+ _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+ return topk_weights, topk_ids


  def biased_grouped_topk_impl(
@@ -151,6 +161,8 @@ def biased_grouped_topk_impl(
  topk_group: int = 0,
  n_share_experts_fusion: int = 0,
  routed_scaling_factor: Optional[float] = None,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -197,13 +209,26 @@ def biased_grouped_topk_impl(
  )
  topk_weights = topk_weights / topk_weights_sum

- return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+ topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+ topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+ _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+ return topk_weights, topk_ids


  def is_power_of_two(n):
  return n > 0 and math.log2(n).is_integer()


+ def _mask_topk_ids_padded_region(
+ topk_ids: torch.Tensor,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ ):
+ if num_token_non_padded is None:
+ return
+ indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
+ topk_ids[indices >= num_token_non_padded, :] = -1
+
+
  def biased_grouped_topk(
  hidden_states: torch.Tensor,
  gating_output: torch.Tensor,
@@ -215,6 +240,8 @@ def biased_grouped_topk(
  compiled: bool = True,
  n_share_experts_fusion: int = 0,
  routed_scaling_factor: Optional[float] = None,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  assert (
  routed_scaling_factor is not None
@@ -226,7 +253,7 @@ def biased_grouped_topk(
  <= 32 # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
  and is_power_of_two(correction_bias.shape[0])
  ):
- return moe_fused_gate(
+ topk_weights, topk_ids = moe_fused_gate(
  gating_output,
  correction_bias,
  num_expert_group,
@@ -235,6 +262,15 @@ def biased_grouped_topk(
  n_share_experts_fusion,
  routed_scaling_factor,
  )
+ # TODO merge into kernel for this branch
+ topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+ # TODO will fuse this into kernel, thus use slow manual operation now
+ if num_token_non_padded is None:
+ return topk_weights, topk_ids
+ torch.compile(
+ _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
+ )(topk_ids, num_token_non_padded)
+ return topk_weights, topk_ids
  else:
  biased_grouped_topk_fn = (
  torch.compile(
@@ -253,6 +289,8 @@ def biased_grouped_topk(
  topk_group,
  n_share_experts_fusion=n_share_experts_fusion,
  routed_scaling_factor=routed_scaling_factor,
+ num_token_non_padded=num_token_non_padded,
+ expert_location_dispatch_info=expert_location_dispatch_info,
  )


@@ -268,9 +306,11 @@ def select_experts(
  correction_bias: Optional[torch.Tensor] = None,
  torch_native: bool = False,
  routed_scaling_factor: Optional[float] = None,
+ num_token_non_padded: Optional[torch.Tensor] = None,
+ expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
  ):
  n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
- # DeekSeek V2/V3/R1 serices models uses grouped_top_k
+ # DeepSeek V2/V3/R1 series models use grouped_top_k
  if use_grouped_topk:
  assert topk_group is not None
  assert num_expert_group is not None
@@ -284,6 +324,8 @@ def select_experts(
  topk_group=topk_group,
  n_share_experts_fusion=n_share_experts_fusion,
  routed_scaling_factor=routed_scaling_factor,
+ num_token_non_padded=num_token_non_padded,
+ expert_location_dispatch_info=expert_location_dispatch_info,
  )
  else:
  topk_weights, topk_ids = biased_grouped_topk(
@@ -296,8 +338,14 @@ def select_experts(
  topk_group=topk_group,
  n_share_experts_fusion=n_share_experts_fusion,
  routed_scaling_factor=routed_scaling_factor,
+ num_token_non_padded=num_token_non_padded,
+ expert_location_dispatch_info=expert_location_dispatch_info,
  )
  elif torch_native and custom_routing_function is None:
+ assert (
+ num_token_non_padded is None
+ ), "num_token_non_padded is not yet supported in fused_topk_native"
+ assert expert_location_dispatch_info is None
  topk_weights, topk_ids = fused_topk_native(
  hidden_states=hidden_states,
  gating_output=router_logits,
@@ -305,13 +353,22 @@ def select_experts(
  renormalize=renormalize,
  )
  elif custom_routing_function is None:
+ assert (
+ num_token_non_padded is None
+ ), "num_token_non_padded is not yet supported in fused_topk"
+ # Qwen3MOE uses fused_topk
  topk_weights, topk_ids = fused_topk(
  hidden_states=hidden_states,
  gating_output=router_logits,
  topk=top_k,
  renormalize=renormalize,
+ expert_location_dispatch_info=expert_location_dispatch_info,
  )
  else:
+ assert (
+ num_token_non_padded is None
+ ), "num_token_non_padded is not yet supported in custom_routing_function"
+ assert expert_location_dispatch_info is None
  topk_weights, topk_ids = custom_routing_function(
  hidden_states=hidden_states,
  gating_output=router_logits,
@@ -319,6 +376,6 @@ def select_experts(
  renormalize=renormalize,
  )

- expert_distribution_recorder.record_new_token(topk_ids)
+ get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

  return topk_weights, topk_ids
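
The hunks above appear to belong to sglang/srt/layers/moe/topk.py (entry 63 in the file list). They thread two new optional arguments through the top-k routing path: expert_location_dispatch_info, which remaps logical expert ids to physical ones via topk_ids_logical_to_physical, and num_token_non_padded, which marks CUDA-graph padding tokens so their routed expert ids become -1. The following standalone sketch, not part of the package, restates the padded-region masking on a toy tensor; the helper name mask_padded_topk_ids is hypothetical.

import torch

def mask_padded_topk_ids(topk_ids: torch.Tensor, num_token_non_padded: torch.Tensor) -> torch.Tensor:
    # Rows at or beyond num_token_non_padded are padding; an expert id of -1
    # tells downstream MoE dispatch to skip those tokens.
    indices = torch.arange(topk_ids.shape[0], device=topk_ids.device)
    topk_ids[indices >= num_token_non_padded, :] = -1
    return topk_ids

topk_ids = torch.tensor([[0, 3], [1, 2], [5, 7], [4, 6]])  # 4 tokens, top-2 experts each
mask_padded_topk_ids(topk_ids, torch.tensor(2))            # only the first 2 tokens are real
# topk_ids is now [[0, 3], [1, 2], [-1, -1], [-1, -1]]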
@@ -0,0 +1,70 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Logits processing."""
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def hash_kernel(
+ input_ptr,
+ output_ptr,
+ n_elements,
+ BLOCK_SIZE: tl.constexpr,
+ PRIME: tl.constexpr,
+ XCONST: tl.constexpr,
+ ):
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+
+ data = tl.load(input_ptr + offsets, mask=mask, other=0)
+ mixed = data ^ (offsets + XCONST)
+ hash_val = mixed * PRIME
+ hash_val = hash_val ^ (hash_val >> 16)
+ hash_val = hash_val * (PRIME ^ XCONST)
+ hash_val = hash_val ^ (hash_val >> 13)
+
+ tl.store(output_ptr + offsets, hash_val, mask=mask)
+
+
+ PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
+ PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
+
+
+ def gpu_tensor_hash(tensor: torch.Tensor) -> int:
+ assert tensor.is_cuda
+ tensor = tensor.contiguous().view(torch.int32)
+ n = tensor.numel()
+ BLOCK_SIZE = 1024
+ grid = (triton.cdiv(n, BLOCK_SIZE),)
+
+ intermediate_hashes = torch.empty(n, dtype=torch.int32, device=tensor.device)
+
+ hash_kernel[grid](
+ tensor,
+ intermediate_hashes,
+ n,
+ BLOCK_SIZE=BLOCK_SIZE,
+ PRIME=PRIME_1,
+ XCONST=PRIME_2,
+ )
+
+ # TODO: threads can't be synced on triton kernel
+ final_hash = intermediate_hashes.sum().item()
+
+ return final_hash
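
The new 70-line file above matches sglang/srt/layers/multimodal.py in the file list and adds a Triton-based tensor hash, presumably used to key multimodal feature caches (see the new sglang/srt/mem_cache/multimodal_cache.py). Below is a hedged usage sketch, assuming that module path and a CUDA device with Triton available:

import torch
from sglang.srt.layers.multimodal import gpu_tensor_hash  # path inferred from the file list

if torch.cuda.is_available():
    # Hash a feature tensor once and reuse the integer as a cache key.
    feats = torch.randint(0, 255, (1024,), dtype=torch.int32, device="cuda")
    key = gpu_tensor_hash(feats)
    assert key == gpu_tensor_hash(feats.clone())  # same contents -> same hash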
@@ -25,7 +25,6 @@ try:
  from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
  from vllm.model_executor.layers.quantization.gptq_marlin import (
  GPTQMarlinLinearMethod,
- GPTQMarlinMoEMethod,
  )
  from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
  GPTQMarlin24Config,
@@ -58,12 +57,17 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
  CompressedTensorsConfig,
  )
  from sglang.srt.layers.quantization.fp8 import Fp8Config
- from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+ from sglang.srt.layers.quantization.gptq import (
+ GPTQConfig,
+ GPTQMarlinConfig,
+ GPTQMarlinMoEMethod,
+ )
  from sglang.srt.layers.quantization.modelopt_quant import (
  ModelOptFp4Config,
  ModelOptFp8Config,
  )
  from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+ from sglang.srt.layers.quantization.qoq import QoQConfig
  from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
  from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

@@ -77,6 +81,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
  "w8a8_fp8": W8A8Fp8Config,
  "moe_wna16": MoeWNA16Config,
  "compressed-tensors": CompressedTensorsConfig,
+ "qoq": QoQConfig,
  }

  # VLLM-dependent quantization methods
@@ -109,7 +114,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
  if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
  raise ValueError(
  f"{quantization} quantization requires some operators from vllm. "
- "Pleaes install vllm by `pip install vllm==0.8.4`"
+ "Please install vllm by `pip install vllm==0.8.4`"
  )

  return QUANTIZATION_METHODS[quantization]
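
The quantization __init__.py hunks above switch GPTQMarlinMoEMethod to sglang's own gptq module and register the new QoQ method (sglang/srt/layers/quantization/qoq.py, entry 73 in the file list) as a base quantization method. A short illustrative sketch of how a method name resolves, using only names visible in the diff:

from sglang.srt.layers.quantization import get_quantization_config

qoq_cls = get_quantization_config("qoq")  # resolves to QoQConfig via BASE_QUANTIZATION_METHODS
# Methods listed in VLLM_QUANTIZATION_METHODS instead raise a ValueError asking to
# `pip install vllm==0.8.4` when vllm is not importable.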
@@ -152,7 +152,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
  f"{input_size_per_partition} is not divisible by "
  f"weight quantization block_k = {block_k}."
  )
- # Required by collum parallel or enabling merged weights
+ # Required by column parallel or enabling merged weights
  if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
  output_partition_sizes
  ) > 1:
@@ -285,7 +285,7 @@ class BlockInt8MoEMethod:
  self.quant_config.weight_block_size[1],
  )
  # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
- # Required by collum parallel or enabling merged weights
+ # Required by column parallel or enabling merged weights
  if intermediate_size % block_n != 0:
  raise ValueError(
  f"The output_size of gate's and up's weight = "
@@ -11,30 +11,29 @@ from tqdm.contrib.concurrent import thread_map
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda

+ logger = logging.getLogger(__name__)
  _ENABLE_JIT_DEEPGEMM = False
- if is_cuda():
+
+ try:
  import deep_gemm
  from deep_gemm import get_num_sms
+ from deep_gemm.jit.compiler import get_nvcc_compiler
  from deep_gemm.jit_kernels.gemm import get_best_configs
- from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
- from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
- from deep_gemm.jit_kernels.m_grouped_gemm import (
- template as deep_gemm_grouped_gemm_template,
- )
+ from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
  from deep_gemm.jit_kernels.tuner import jit_tuner

  sm_version = get_device_sm()
  if sm_version == 90:
  if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
  _ENABLE_JIT_DEEPGEMM = True
+ except ImportError:
+ logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")


  def get_enable_jit_deepgemm():
  return _ENABLE_JIT_DEEPGEMM


- logger = logging.getLogger(__name__)
-
  _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
  _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
  "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
@@ -45,10 +44,25 @@ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
  _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

  # Force redirect deep_gemm cache_dir
- os.environ["DG_CACHE_DIR"] = os.getenv(
- "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+ "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
  )

+ # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+ # NVRTC may have performance loss with some cases.
+ # And NVCC JIT speed is also 9x faster in the ref commit
+ _USE_NVRTC_DEFAULT = "0"
+ if _ENABLE_JIT_DEEPGEMM:
+ try:
+ get_nvcc_compiler()
+ except:
+ logger.warning(
+ "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
+ "and may have performance loss with some cases."
+ )
+ _USE_NVRTC_DEFAULT = "1"
+ os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+

  def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
  global _BUILTIN_M_LIST
@@ -103,10 +117,10 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
  def _compile_warning_1():
  if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
  logger.warning(
- "Entering DeepGEMM JIT Pre-Complie session. "
+ "Entering DeepGEMM JIT Pre-Compile session. "
  "And it may takes a long time(Typically 10-20 mins) "
  "if you have not run `sglang.compile_deep_gemm`. "
- "Recommand to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+ "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
  " for pre-compilation to reduce the overhead if you have not run it before. "
  "For example: "
  "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
@@ -115,7 +129,7 @@ def _compile_warning_1():

  def _compile_warning_2():
  logger.warning(
- "Entering DeepGEMM JIT Single Kernel Complie session. "
+ "Entering DeepGEMM JIT Single Kernel Compile session. "
  "And it will makes inference throughput becomes flaky. "
  "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
  " for pre-compilation to solve this issue. "
@@ -130,10 +144,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
  num_groups: int,
  config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
  ) -> None:
- # Auto-tuning with compilation
- global deep_gemm_includes, deep_gemm_grouped_gemm_template
- _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
- _ = jit_tuner.compile_and_tune(
+ num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+ block_k = 128
+ num_tma_threads = 128
+ num_math_threads_per_group = 128
+ kwargs = {
+ "NUM_TMA_THREADS": num_tma_threads,
+ "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+ "BLOCK_K": block_k,
+ "NUM_SMS": num_sms,
+ "SMEM_SIZE": smem_config[0],
+ }
+ _, _ = jit_tuner.compile_and_tune(
  name="m_grouped_gemm_fp8_fp8_bf16_nt",
  keys={
  "N": n,
@@ -146,24 +168,11 @@
  "NUM_STAGES": num_stages,
  "NUM_TMA_MULTICAST": tma_multicast_config[0],
  "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
- "GEMM_TYPE": "GroupedMasked",
+ "GEMM_TYPE": GemmType.GroupedMasked,
  },
  space=(),
- includes=deep_gemm_includes,
- arg_defs=(
- ("lhs", torch.float8_e4m3fn),
- ("lhs_scales", torch.float),
- ("rhs", torch.float8_e4m3fn),
- ("rhs_scales", torch.float),
- ("out", torch.bfloat16),
- ("grouped_layout", torch.int32),
- ("m", int),
- ("stream", torch.cuda.Stream),
- ("num_sms", int),
- ("smem_size", int),
- ),
- template=deep_gemm_grouped_gemm_template,
- args=[],
+ kwargs=kwargs,
+ runtime_cls=FP8GemmRuntime,
  )


@@ -173,9 +182,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
  num_groups: int,
  config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
  ) -> None:
- global deep_gemm_includes, deep_gemm_grouped_gemm_template
- _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
- _ = jit_tuner.compile_and_tune(
+ num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+ block_k = 128
+ num_tma_threads = 128
+ num_math_threads_per_group = 128
+ kwargs = {
+ "NUM_TMA_THREADS": num_tma_threads,
+ "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+ "BLOCK_K": block_k,
+ "NUM_SMS": num_sms,
+ "SMEM_SIZE": smem_config[0],
+ }
+ _, _ = jit_tuner.compile_and_tune(
  name="m_grouped_gemm_fp8_fp8_bf16_nt",
  keys={
  "N": n,
@@ -188,25 +206,11 @@
  "NUM_STAGES": num_stages,
  "NUM_TMA_MULTICAST": tma_multicast_config[0],
  "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
- "GEMM_TYPE": "GroupedContiguous",
+ "GEMM_TYPE": GemmType.GroupedContiguous,
  },
  space=(),
- includes=deep_gemm_includes,
- arg_defs=(
- ("lhs", torch.float8_e4m3fn),
- ("lhs_scales", torch.float),
- ("rhs", torch.float8_e4m3fn),
- ("rhs_scales", torch.float),
- ("out", torch.bfloat16),
- ("grouped_layout", torch.int32),
- ("m", int),
- ("num_groups", int),
- ("stream", torch.cuda.Stream),
- ("num_sms", int),
- ("smem_size", int),
- ),
- template=deep_gemm_grouped_gemm_template,
- args=[],
+ kwargs=kwargs,
+ runtime_cls=FP8GemmRuntime,
  )


@@ -216,9 +220,20 @@ def _compile_gemm_nt_f8f8bf16_one(
  _: int, # _ is a dummy parameter to align with other interfaces
  config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
  ) -> None:
- global deep_gemm_includes, deep_gemm_gemm_template
- _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
- _ = jit_tuner.compile_and_tune(
+ num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+ block_k = 128
+ num_tma_threads = 128
+ num_math_threads_per_group = 128
+ kwargs = {
+ "GEMM_TYPE": GemmType.Normal,
+ "NUM_TMA_THREADS": num_tma_threads,
+ "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+ "NUM_GROUPS": 1,
+ "BLOCK_K": block_k,
+ "NUM_SMS": num_sms,
+ "SMEM_SIZE": smem_config[0],
+ }
+ _, _ = jit_tuner.compile_and_tune(
  name="gemm_fp8_fp8_bf16_nt",
  keys={
  "N": n,
@@ -232,20 +247,8 @@
  "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
  },
  space=(),
- includes=deep_gemm_includes,
- arg_defs=(
- ("lhs", torch.float8_e4m3fn),
- ("lhs_scales", torch.float),
- ("rhs", torch.float8_e4m3fn),
- ("rhs_scales", torch.float),
- ("out", torch.bfloat16),
- ("m", int),
- ("stream", torch.cuda.Stream),
- ("num_sms", int),
- ("smem_size", int),
- ),
- template=deep_gemm_gemm_template,
- args=[],
+ kwargs=kwargs,
+ runtime_cls=FP8GemmRuntime,
  )


@@ -298,7 +301,7 @@ def _maybe_compile_deep_gemm_one_type_all(
  logger.info(
  f"Try DeepGEMM JIT Compiling for "
  f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
- f"{' It only takes a litte time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
+ f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
  )

  # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -373,7 +376,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

  from deep_gemm.jit.runtime import RuntimeCache

- origin_func = RuntimeCache.__getitem__
+ origin_func = RuntimeCache.get

  def __patched_func(self, *args, **kwargs):
  ret = origin_func(self, *args, **kwargs)
@@ -385,6 +388,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
  )
  return ret

- RuntimeCache.__getitem__ = __patched_func
+ RuntimeCache.get = __patched_func
  yield
- RuntimeCache.__getitem__ = origin_func
+ RuntimeCache.get = origin_func
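
The deep_gemm.py hunks above wrap the DeepGEMM import in try/except so a missing package only disables the JIT path, rename the cache-dir override to DG_JIT_CACHE_DIR, add an NVCC-first/NVRTC-fallback toggle, and switch the compile helpers to the newer kwargs/FP8GemmRuntime tuner interface. Below is a hedged sketch of the environment resolution implied by the diff; it restates the visible logic rather than importing the module, and resolve_deep_gemm_env is a hypothetical helper.

import os

def resolve_deep_gemm_env(nvcc_available: bool) -> dict:
    # SGL_DG_CACHE_DIR overrides the JIT cache location, otherwise ~/.cache/deep_gemm.
    cache_dir = os.getenv(
        "SGL_DG_CACHE_DIR",
        os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm"),
    )
    # NVCC is preferred; NVRTC only becomes the default when no NVCC compiler is found,
    # and SGL_DG_USE_NVRTC can force either choice.
    use_nvrtc = os.getenv("SGL_DG_USE_NVRTC", "0" if nvcc_available else "1")
    return {"DG_JIT_CACHE_DIR": cache_dir, "DG_JIT_USE_NVRTC": use_nvrtc}

print(resolve_deep_gemm_env(nvcc_available=True))
# e.g. {'DG_JIT_CACHE_DIR': '/home/user/.cache/deep_gemm', 'DG_JIT_USE_NVRTC': '0'}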