sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256) hide show
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -11,6 +11,8 @@ import torch
11
11
  from compressed_tensors import CompressionFormat
12
12
  from compressed_tensors.quantization import QuantizationStrategy
13
13
 
14
+ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
15
+ from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
14
16
  from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
15
17
  from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
16
18
  from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
@@ -30,8 +32,10 @@ from sglang.srt.utils import (
30
32
 
31
33
  if TYPE_CHECKING:
32
34
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
33
- from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
34
- from sglang.srt.layers.moe.topk import TopKOutput
35
+ from sglang.srt.layers.moe.token_dispatcher import (
36
+ CombineInput,
37
+ StandardDispatchOutput,
38
+ )
35
39
  from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
36
40
  CompressedTensorsConfig,
37
41
  )
@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
293
297
  )
294
298
  torch.cuda.empty_cache()
295
299
 
300
+ def create_moe_runner(
301
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
302
+ ):
303
+ self.moe_runner_config = moe_runner_config
304
+ self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
305
+
296
306
  def apply(
297
307
  self,
298
308
  layer: torch.nn.Module,
299
- x: torch.Tensor,
300
- topk_output: TopKOutput,
301
- moe_runner_config: MoeRunnerConfig,
302
- ) -> torch.Tensor:
303
- from sglang.srt.layers.moe.fused_moe_triton import fused_experts
309
+ dispatch_output: StandardDispatchOutput,
310
+ ) -> CombineInput:
311
+
312
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
313
+
314
+ x = dispatch_output.hidden_states
315
+ topk_output = dispatch_output.topk_output
316
+
317
+ moe_runner_config = self.moe_runner_config
304
318
 
305
319
  if (
306
320
  _use_aiter
@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
308
322
  and moe_runner_config.apply_router_weight_on_input
309
323
  ):
310
324
  topk_weights, topk_ids, _ = topk_output
311
- return rocm_fused_experts_tkw1(
325
+ output = rocm_fused_experts_tkw1(
312
326
  hidden_states=x,
313
327
  w1=layer.w13_weight,
314
328
  w2=layer.w2_weight,
@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
324
338
  a1_scale=layer.w13_input_scale,
325
339
  a2_scale=layer.w2_input_scale,
326
340
  )
341
+ return StandardCombineInput(hidden_states=output)
327
342
  else:
328
- return fused_experts(
329
- x,
330
- layer.w13_weight,
331
- layer.w2_weight,
332
- topk_output=topk_output,
333
- moe_runner_config=moe_runner_config,
343
+ quant_info = TritonMoeQuantInfo(
344
+ w13_weight=layer.w13_weight,
345
+ w2_weight=layer.w2_weight,
334
346
  use_fp8_w8a8=True,
335
347
  per_channel_quant=self.weight_quant.strategy
336
348
  == QuantizationStrategy.CHANNEL,
337
- w1_scale=layer.w13_weight_scale,
349
+ w13_scale=layer.w13_weight_scale,
338
350
  w2_scale=layer.w2_weight_scale,
339
- a1_scale=layer.w13_input_scale,
351
+ a13_scale=layer.w13_input_scale,
340
352
  a2_scale=layer.w2_input_scale,
341
353
  )
354
+ return self.runner.run(dispatch_output, quant_info)
342
355
 
343
356
 
344
357
  class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
380
393
  params_dtype == torch.float16
381
394
  ), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
382
395
 
383
- intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
384
-
385
396
  # Will transpose the loaded weight along the
386
397
  # intermediate and hidden dim sizes. Will
387
398
  # shard for TP along the transposed dims
@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
415
426
  # In the case where we have actorder/g_idx,
416
427
  # we do not partition the w2 scales
417
428
  load_full_w2 = self.actorder and self.group_size != -1
418
- w2_scales_size = (
419
- intermediate_size_full if load_full_w2 else intermediate_size_per_partition
420
- )
421
429
 
422
- self.is_k_full = (not self.actorder) or (
423
- intermediate_size_per_partition == intermediate_size_full
424
- )
430
+ if load_full_w2:
431
+ w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
432
+ else:
433
+ w2_scales_size = intermediate_size_per_partition
434
+
435
+ self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1
425
436
 
426
437
  if self.strategy == "channel":
427
438
  num_groups_w2 = num_groups_w13 = 1
@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
640
651
  )
641
652
  replace_tensor("w2_weight_scale", marlin_w2_scales)
642
653
 
654
+ def create_moe_runner(
655
+ self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
656
+ ):
657
+ self.moe_runner_config = moe_runner_config
658
+
643
659
  def apply(
644
660
  self,
645
661
  layer: torch.nn.Module,
646
- x: torch.Tensor,
647
- topk_output: TopKOutput,
648
- moe_runner_config: MoeRunnerConfig,
649
- ) -> torch.Tensor:
662
+ dispatch_output: StandardDispatchOutput,
663
+ ) -> CombineInput:
664
+
665
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
650
666
 
651
667
  assert (
652
- moe_runner_config.activation == "silu"
668
+ self.moe_runner_config.activation == "silu"
653
669
  ), "Only SiLU activation is supported."
654
670
 
671
+ x = dispatch_output.hidden_states
672
+ topk_output = dispatch_output.topk_output
673
+
655
674
  topk_weights, topk_ids, router_logits = topk_output
656
675
 
657
- return torch.ops.vllm.fused_marlin_moe(
676
+ output = torch.ops.vllm.fused_marlin_moe(
658
677
  x,
659
678
  layer.w13_weight_packed,
660
679
  layer.w2_weight_packed,
@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
670
689
  num_bits=self.num_bits,
671
690
  is_k_full=self.is_k_full,
672
691
  )
692
+ return StandardCombineInput(hidden_states=output)
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
21
21
  normalize_e4m3fn_to_e4m3fnuz,
22
22
  )
23
23
  from sglang.srt.layers.quantization.utils import requantize_with_max_scale
24
+ from sglang.srt.utils import get_bool_env_var, is_hip
24
25
 
25
26
  __all__ = ["CompressedTensorsW8A8Fp8"]
26
27
 
28
+ _is_hip = is_hip()
29
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
30
+ if _use_aiter:
31
+ from aiter.ops.shuffle import shuffle_weight
32
+
27
33
 
28
34
  class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
29
35
 
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
76
82
  else:
77
83
  weight_scale = layer.weight_scale.data
78
84
 
79
- layer.weight = Parameter(weight.t(), requires_grad=False)
85
+ if _use_aiter:
86
+ layer.weight = Parameter(
87
+ shuffle_weight(weight, (16, 16)), requires_grad=False
88
+ )
89
+ else:
90
+ layer.weight = Parameter(weight.t(), requires_grad=False)
91
+
80
92
  # required by torch.compile to be torch.nn.Parameter
81
93
  layer.weight_scale = Parameter(weight_scale, requires_grad=False)
82
94
 
@@ -1,26 +1,22 @@
1
1
  import logging
2
2
  import os
3
3
  from contextlib import contextmanager
4
- from dataclasses import dataclass
5
4
  from enum import IntEnum, auto
6
- from typing import Callable, Dict, List, Optional, Tuple
5
+ from typing import Dict, List, Tuple
7
6
 
8
- from tqdm.contrib.concurrent import thread_map
7
+ import torch
8
+ from tqdm import tqdm
9
9
 
10
10
  from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
11
- DEEPGEMM_BLACKWELL,
12
11
  ENABLE_JIT_DEEPGEMM,
13
12
  )
14
13
  from sglang.srt.server_args import ServerArgs
15
- from sglang.srt.utils import get_bool_env_var, get_int_env_var
14
+ from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
16
15
 
17
16
  logger = logging.getLogger(__name__)
18
17
 
19
- if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
20
- from deep_gemm import get_num_sms
21
- from deep_gemm.jit import build
22
- from deep_gemm.jit_kernels.gemm import get_best_configs
23
- from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
18
+ if ENABLE_JIT_DEEPGEMM:
19
+ import deep_gemm
24
20
 
25
21
 
26
22
  _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
@@ -40,19 +36,7 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
40
36
  # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
41
37
  # NVRTC may have performance loss with some cases.
42
38
  # And NVCC JIT speed is also 9x faster in the ref commit
43
- _USE_NVRTC_DEFAULT = "0"
44
- if ENABLE_JIT_DEEPGEMM:
45
- try:
46
- from deep_gemm.jit.compiler import get_nvcc_compiler
47
-
48
- get_nvcc_compiler()
49
- except:
50
- logger.warning(
51
- "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
52
- "and may have performance loss with some cases."
53
- )
54
- _USE_NVRTC_DEFAULT = "1"
55
- os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
39
+ os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
56
40
 
57
41
 
58
42
  def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
@@ -75,7 +59,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
75
59
  # Default each rank will try compile all Ms to
76
60
  # load all symbols at the launch stages.
77
61
  # Avoid loading symbols at the serving stages.
78
- _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
62
+ _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE
79
63
 
80
64
 
81
65
  class DeepGemmKernelType(IntEnum):
@@ -84,185 +68,15 @@ class DeepGemmKernelType(IntEnum):
84
68
  GEMM_NT_F8F8BF16 = auto()
85
69
 
86
70
 
87
- @dataclass
88
- class DeepGemmKernelHelper:
89
- name: str
90
- compile_func: Callable[
91
- [
92
- int,
93
- int,
94
- int,
95
- Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
96
- ],
97
- None,
98
- ]
99
- configure_func: Callable[
100
- [int, int, int, int, int],
101
- Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
102
- ]
103
-
104
-
105
71
  _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
106
72
 
107
73
 
108
- # TODO improve naming
109
- def _compile_warning_1():
110
- if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
111
- logger.warning(
112
- "Entering DeepGEMM JIT Pre-Compile session. "
113
- "It may takes a long time (typically 10-20 mins) "
114
- "if you have not run `sglang.compile_deep_gemm`. "
115
- "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
116
- " for pre-compilation to reduce the overhead if you have not run it before. "
117
- "For example: "
118
- "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
119
- )
120
-
121
-
122
- # TODO improve naming
123
- def _compile_warning_2():
124
- logger.warning(
125
- "Entering DeepGEMM JIT Single Kernel Compile session. "
126
- "And it will makes inference throughput becomes flaky. "
127
- "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
128
- " for pre-compilation to solve this issue. "
129
- "For example: "
130
- "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
131
- )
132
-
133
-
134
- def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
135
- n: int,
136
- k: int,
137
- num_groups: int,
138
- config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
139
- ) -> None:
140
- num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
141
- block_k = 128
142
- num_tma_threads = 128
143
- num_math_threads_per_group = 128
144
-
145
- kwargs = {
146
- "GEMM_TYPE": GemmType.GroupedMasked,
147
- "NUM_TMA_THREADS": num_tma_threads,
148
- "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
149
- "N": n,
150
- "K": k,
151
- "NUM_GROUPS": num_groups,
152
- "BLOCK_M": block_m,
153
- "BLOCK_N": block_n,
154
- "BLOCK_K": block_k,
155
- "SWIZZLE_D_MODE": smem_config[1],
156
- "BLOCK_N_PADDING": smem_config[2],
157
- "NUM_STAGES": num_stages,
158
- "NUM_TMA_MULTICAST": tma_multicast_config[0],
159
- "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
160
- "NUM_SMS": num_sms,
161
- "SMEM_SIZE": smem_config[0],
162
- }
163
-
164
- code = FP8GemmRuntime.generate(kwargs)
165
- _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
166
-
167
-
168
- def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
169
- n: int,
170
- k: int,
171
- num_groups: int,
172
- config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
173
- ) -> None:
174
- num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
175
- block_k = 128
176
- num_tma_threads = 128
177
- num_math_threads_per_group = 128
178
- kwargs = {
179
- "GEMM_TYPE": GemmType.GroupedContiguous,
180
- "NUM_TMA_THREADS": num_tma_threads,
181
- "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
182
- "N": n,
183
- "K": k,
184
- "NUM_GROUPS": 1,
185
- "BLOCK_M": block_m,
186
- "BLOCK_N": block_n,
187
- "BLOCK_K": block_k,
188
- "SWIZZLE_D_MODE": smem_config[1],
189
- "BLOCK_N_PADDING": smem_config[2],
190
- "NUM_STAGES": num_stages,
191
- "NUM_TMA_MULTICAST": tma_multicast_config[0],
192
- "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
193
- "NUM_SMS": num_sms,
194
- "SMEM_SIZE": smem_config[0],
195
- }
196
-
197
- code = FP8GemmRuntime.generate(kwargs)
198
- _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
199
-
200
-
201
- def _compile_gemm_nt_f8f8bf16_one(
202
- n: int,
203
- k: int,
204
- _: int, # _ is a dummy parameter to align with other interfaces
205
- config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
206
- ) -> None:
207
- num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
208
- block_k = 128
209
- num_tma_threads = 128
210
- num_math_threads_per_group = 128
211
- kwargs = {
212
- "GEMM_TYPE": GemmType.Normal,
213
- "NUM_TMA_THREADS": num_tma_threads,
214
- "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
215
- "N": n,
216
- "K": k,
217
- "NUM_GROUPS": 1,
218
- "BLOCK_M": block_m,
219
- "BLOCK_N": block_n,
220
- "BLOCK_K": block_k,
221
- "SWIZZLE_D_MODE": smem_config[1],
222
- "BLOCK_N_PADDING": smem_config[2],
223
- "NUM_STAGES": num_stages,
224
- "NUM_TMA_MULTICAST": tma_multicast_config[0],
225
- "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
226
- "NUM_SMS": num_sms,
227
- "SMEM_SIZE": smem_config[0],
228
- }
229
-
230
- code = FP8GemmRuntime.generate(kwargs)
231
- _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
232
-
233
-
234
- # TODO further refactor warmup-related
235
- _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
236
- DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
237
- name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
238
- compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
239
- configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
240
- m, n, k, num_groups, num_sms, is_grouped_masked=True
241
- ),
242
- ),
243
- DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
244
- name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
245
- compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
246
- configure_func=lambda m, n, k, _, num_sms: get_best_configs(
247
- m, n, k, 1, num_sms, is_grouped_contiguous=True
248
- ),
249
- ),
250
- DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
251
- name="gemm_fp8_fp8_bf16_nt",
252
- compile_func=_compile_gemm_nt_f8f8bf16_one,
253
- configure_func=lambda m, n, k, _, num_sms: get_best_configs(
254
- m, n, k, 1, num_sms
255
- ),
256
- ),
257
- }
258
-
259
-
74
+ # TODO improve code
260
75
  def _maybe_compile_deep_gemm_one_type_all(
261
76
  kernel_type: DeepGemmKernelType,
262
77
  n: int,
263
78
  k: int,
264
79
  num_groups: int,
265
- m_list: Optional[List[int]] = None,
266
80
  ) -> None:
267
81
  global _INITIALIZATION_DICT
268
82
  global _BUILTIN_M_LIST
@@ -275,61 +89,153 @@ def _maybe_compile_deep_gemm_one_type_all(
275
89
  ):
276
90
  _INITIALIZATION_DICT[query_key] = True
277
91
 
278
- kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
279
- _compile_warning_1()
92
+ # TODO maybe improve logs
93
+ if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
94
+ logger.warning(
95
+ "Entering DeepGEMM JIT Pre-Compile session. "
96
+ "It may take a long time (typically 10-20 mins) "
97
+ "if you have not run `sglang.compile_deep_gemm`. "
98
+ "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
99
+ " for pre-compilation to reduce the overhead if you have not run it before. "
100
+ "For example: "
101
+ "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
102
+ )
103
+
280
104
  logger.info(
281
105
  f"Try DeepGEMM JIT Compiling for "
282
- f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
106
+ f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
283
107
  f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
284
108
  )
285
109
 
286
- # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
287
- num_sms = get_num_sms()
288
- collected_configs = set()
289
- for m in m_list if m_list is not None else _BUILTIN_M_LIST:
290
- # Put config into set to get unique configs and reduce cases to be compiled
291
- collected_configs.add(
292
- kernel_helper.configure_func(m, n, k, num_groups, num_sms)
293
- )
294
- compile_func = lambda config: kernel_helper.compile_func(
295
- n, k, num_groups, config
110
+ _compile_deep_gemm_one_type_all(
111
+ kernel_type=kernel_type,
112
+ n=n,
113
+ k=k,
114
+ num_groups=num_groups,
115
+ m_list=_BUILTIN_M_LIST,
296
116
  )
297
- thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
298
117
 
299
118
 
300
- @contextmanager
301
- def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
302
- if _IN_PRECOMPILE_STAGE:
303
- yield
304
- return
119
+ # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
120
+ def _compile_deep_gemm_one_type_all(
121
+ kernel_type: DeepGemmKernelType,
122
+ n: int,
123
+ k: int,
124
+ num_groups: int,
125
+ m_list: List[int],
126
+ ) -> None:
127
+ if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
128
+ m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
129
+ m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))
130
+
131
+ executor = _BaseWarmupExecutor.create(
132
+ kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
133
+ )
305
134
 
306
- from deep_gemm.jit.runtime import RuntimeCache
135
+ old_compile_mode = deep_gemm.get_compile_mode()
136
+ deep_gemm.set_compile_mode(1)
137
+ # TODO can use multi thread
138
+ for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
139
+ executor.execute(m=m)
140
+ deep_gemm.set_compile_mode(old_compile_mode)
141
+
142
+ # clean up input buffers
143
+ torch.cuda.current_stream().synchronize()
144
+ del executor
145
+ torch.cuda.empty_cache()
146
+
147
+
148
+ class _BaseWarmupExecutor:
149
+ @staticmethod
150
+ def create(kernel_type: DeepGemmKernelType, **kwargs):
151
+ return {
152
+ DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
153
+ DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
154
+ DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
155
+ }[kernel_type](**kwargs)
156
+
157
+ def execute(self, m):
158
+ raise NotImplementedError
159
+
160
+
161
+ def _empty_token_fp8(size):
162
+ *dims, k = size
163
+ return (
164
+ torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
165
+ torch.empty(
166
+ (*dims, ceil_div(k, _BLOCK_SIZE)), device="cuda", dtype=torch.float32
167
+ ),
168
+ )
307
169
 
308
- origin_func = RuntimeCache.get
309
170
 
310
- def __patched_func(self, *args, **kwargs):
311
- ret = origin_func(self, *args, **kwargs)
312
- if ret is None:
313
- kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
314
- if not DEEPGEMM_BLACKWELL:
315
- _compile_warning_2()
316
- logger.warning(
317
- f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
318
- )
319
- return ret
171
+ def _empty_block_fp8(size):
172
+ *dims, n, k = size
173
+ return (
174
+ torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
175
+ torch.empty(
176
+ (*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)),
177
+ device="cuda",
178
+ dtype=torch.float32,
179
+ ),
180
+ )
320
181
 
321
- RuntimeCache.get = __patched_func
322
- yield
323
- RuntimeCache.get = origin_func
182
+
183
+ _BLOCK_SIZE = 128
184
+
185
+
186
+ class _NormalWarmupExecutor(_BaseWarmupExecutor):
187
+ def __init__(self, max_m: int, n: int, k: int, num_groups: int):
188
+ self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
189
+ self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
190
+ self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
191
+
192
+ def execute(self, m):
193
+ deep_gemm.fp8_gemm_nt(
194
+ (self.lhs_q[:m], self.lhs_s[:m]),
195
+ (self.rhs_q, self.rhs_s),
196
+ self.out[:m],
197
+ )
198
+
199
+
200
+ class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
201
+ def __init__(self, max_m: int, n: int, k: int, num_groups: int):
202
+ self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
203
+ self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
204
+ self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
205
+ self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
206
+
207
+ def execute(self, m):
208
+ deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
209
+ (self.lhs_q[:m], self.lhs_s[:m]),
210
+ (self.rhs_q, self.rhs_s),
211
+ self.out[:m],
212
+ m_indices=self.m_indices[:m],
213
+ )
214
+
215
+
216
+ class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
217
+ def __init__(self, max_m: int, n: int, k: int, num_groups: int):
218
+ self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))
219
+ self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
220
+ self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
221
+ self.out = torch.empty(
222
+ (num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
223
+ )
224
+
225
+ def execute(self, m):
226
+ deep_gemm.fp8_m_grouped_gemm_nt_masked(
227
+ (self.lhs_q, self.lhs_s),
228
+ (self.rhs_q, self.rhs_s),
229
+ self.out,
230
+ masked_m=self.masked_m,
231
+ # DeepGEMM uses `expect_m` instead of input shape for `get_best_config`
232
+ expected_m=m,
233
+ )
324
234
 
325
235
 
326
236
  @contextmanager
327
237
  def deep_gemm_execution_hook(
328
238
  m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
329
239
  ):
330
- # not supported yet
331
- if not DEEPGEMM_BLACKWELL:
332
- _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
333
-
334
- with _log_jit_build(m, n, k, kernel_type):
335
- yield
240
+ _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
241
+ yield
@@ -11,9 +11,6 @@ def _compute_enable_deep_gemm():
11
11
  sm_version = get_device_sm()
12
12
  if sm_version < 90:
13
13
  return False
14
- # TODO fix deepgemm cu129 fp8 issue
15
- if torch.version.cuda == "12.9":
16
- return False
17
14
 
18
15
  try:
19
16
  import deep_gemm
@@ -24,14 +21,12 @@ def _compute_enable_deep_gemm():
24
21
  return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
25
22
 
26
23
 
27
- ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
24
+ def _is_blackwell_arch() -> bool:
25
+ major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
26
+ return major == 10
28
27
 
29
- try:
30
- from deep_gemm import fp8_gemm_nt
31
28
 
32
- # They have not given a name to this breaking change
33
- DEEPGEMM_BLACKWELL = True
34
- except ImportError:
35
- DEEPGEMM_BLACKWELL = False
29
+ ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
36
30
 
31
+ DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
37
32
  DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL