sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -11,24 +11,25 @@ from sgl_kernel import (
11
11
  )
12
12
 
13
13
  from sglang.srt.layers.moe.ep_moe.kernels import (
14
+ deepep_ll_get_cutlass_w4a8_moe_mm_data,
15
+ deepep_permute_triton_kernel,
16
+ deepep_post_reorder_triton_kernel,
17
+ deepep_run_moe_deep_preprocess,
14
18
  post_reorder_triton_kernel_for_cutlass_moe,
15
19
  pre_reorder_triton_kernel_for_cutlass_moe,
16
- run_cutlass_moe_ep_preproess,
20
+ run_moe_ep_preproess,
21
+ silu_and_mul_masked_post_per_tensor_quant_fwd,
17
22
  )
18
23
 
19
24
 
20
25
  def cutlass_w4a8_moe(
21
- start_expert_id: int,
22
- end_expert_id: int,
23
- total_num_experts: int,
24
26
  a: torch.Tensor,
25
27
  w1_q: torch.Tensor,
26
28
  w2_q: torch.Tensor,
27
29
  w1_scale: torch.Tensor,
28
30
  w2_scale: torch.Tensor,
29
31
  topk_weights: torch.Tensor,
30
- topk_ids_: torch.Tensor,
31
- local_topk_ids: torch.Tensor,
32
+ topk_ids: torch.Tensor,
32
33
  a_strides1: torch.Tensor,
33
34
  b_strides1: torch.Tensor,
34
35
  c_strides1: torch.Tensor,
@@ -64,6 +65,7 @@ def cutlass_w4a8_moe(
64
65
  - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
65
66
  Shape: [num_experts, N // 512, K * 4]
66
67
  - topk_weights (torch.Tensor): The weights of each token->expert mapping.
68
+ - topk_ids (torch.Tensor): The ids of each token->expert mapping.
67
69
  - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
68
70
  - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
69
71
  - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
@@ -83,7 +85,7 @@ def cutlass_w4a8_moe(
83
85
  Returns:
84
86
  - torch.Tensor: The fp8 output tensor after applying the MoE layer.
85
87
  """
86
- assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
88
+ assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
87
89
  assert w1_q.dtype == torch.int8
88
90
  assert w2_q.dtype == torch.int8
89
91
  assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
@@ -96,20 +98,21 @@ def cutlass_w4a8_moe(
96
98
  assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
97
99
  assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
98
100
  assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
99
- num_experts = w1_q.size(0)
101
+ num_local_experts = w1_q.size(0)
100
102
  m = a.size(0)
101
103
  k = w1_q.size(2) * 2 # w1_q is transposed and packed
102
104
  n = w2_q.size(2) * 2 # w2_q is transposed and packed
103
- topk = topk_ids_.size(1)
105
+ topk = topk_ids.size(1)
104
106
 
105
107
  if apply_router_weight_on_input:
106
108
  assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
107
109
 
108
110
  device = a.device
111
+ topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
109
112
 
110
- _, src2dst, _ = run_cutlass_moe_ep_preproess(
111
- local_topk_ids,
112
- num_experts,
113
+ _, src2dst, _ = run_moe_ep_preproess(
114
+ topk_ids,
115
+ num_local_experts,
113
116
  )
114
117
 
115
118
  gateup_input = torch.empty(
@@ -122,9 +125,9 @@ def cutlass_w4a8_moe(
122
125
  a,
123
126
  gateup_input,
124
127
  src2dst,
125
- local_topk_ids,
128
+ topk_ids,
126
129
  a1_scale,
127
- total_num_experts,
130
+ num_local_experts,
128
131
  topk,
129
132
  k,
130
133
  BLOCK_SIZE=512,
@@ -133,16 +136,16 @@ def cutlass_w4a8_moe(
133
136
  # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
134
137
  # they are kept to allow for a quick switch of the permutation logic
135
138
  # from the current triton kernel implementation to the cutlass-based one if needed.
136
- a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
137
- c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
139
+ a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
140
+ c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
138
141
  get_cutlass_w4a8_moe_mm_data(
139
- local_topk_ids,
142
+ topk_ids,
140
143
  expert_offsets,
141
144
  problem_sizes1,
142
145
  problem_sizes2,
143
146
  a_map,
144
147
  c_map,
145
- num_experts,
148
+ num_local_experts,
146
149
  n,
147
150
  k,
148
151
  )
@@ -195,12 +198,339 @@ def cutlass_w4a8_moe(
195
198
  c2,
196
199
  output,
197
200
  src2dst,
198
- local_topk_ids,
201
+ topk_ids,
199
202
  topk_weights,
200
- num_experts,
201
203
  topk,
204
+ num_local_experts,
205
+ k,
206
+ BLOCK_SIZE=512,
207
+ )
208
+ return output
209
+
210
+
211
def cutlass_w4a8_moe_deepep_normal(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    grouped gemm. Token permutation before the first gemm and the weighted
    re-ordering after the second gemm use the deepep_* helpers.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
        Shape: [num_experts, N * 2, K // 2]
        (the weights are passed transposed and int4-packed)
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
        Shape: [num_experts, K, N // 2]
        (the weights are passed transposed and int4-packed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts, K // 512, N * 8]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts, N // 512, K * 4]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids_ (torch.Tensor): The expert index of each token->expert mapping.
        Must have the same shape as topk_weights. Entries equal to -1 are
        remapped to num_experts before the grouped-gemm metadata is built
        (presumably they mark slots with no local expert - TODO confirm).
    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
    - expert_offsets (torch.Tensor): Buffer handed to
        get_cutlass_w4a8_moe_mm_data; expert_offsets[:-1] is then passed to
        both grouped gemms (presumably filled in place - TODO confirm).
    - problem_sizes1 (torch.Tensor): Per-expert problem sizes buffer of the
        first grouped gemm (presumably filled in place - TODO confirm).
    - problem_sizes2 (torch.Tensor): Per-expert problem sizes buffer of the
        second grouped gemm (presumably filled in place - TODO confirm).
    - a1_scale (Optional[torch.Tensor]): The fp32 scale to quantize a.
        Declared optional for signature compatibility, but required here.
        Shape: scalar or [1, K]
    - a2_scale (Optional[torch.Tensor]): The fp32 scale to quantize the
        intermediate result between the gemms. Declared optional for
        signature compatibility, but required here.
        Shape: scalar or [1, N]

    Returns:
    - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
        Shape: [src2dst.shape[0] // topk, K]

    Raises:
    - ValueError: If a1_scale or a2_scale is None.
    - AssertionError: If the dtype / shape checks on the inputs fail.
    """
    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.int8
    assert w2_q.dtype == torch.int8
    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"

    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"

    # Both scales are dereferenced unconditionally below; fail with a clear
    # message instead of an AttributeError on None.
    if a1_scale is None or a2_scale is None:
        raise ValueError(
            "a1_scale and a2_scale are required for w4a8 activation quantization"
        )

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(2) * 2  # w1_q is transposed and packed
    n = w2_q.size(2) * 2  # w2_q is transposed and packed
    topk = topk_ids_.size(1)
    device = a.device

    # Convert each scale once; it is needed for quantization and for the gemm.
    a1_scale_f32 = a1_scale.float()
    a2_scale_f32 = a2_scale.float()

    reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess(
        topk_ids_, num_experts
    )
    num_total_tokens = reorder_topk_ids.numel()

    # Permute the input rows according to src2dst, then quantize them to fp8.
    gateup_input_pre_reorder = torch.empty(
        (int(num_total_tokens), a.shape[1]),
        device=device,
        dtype=a.dtype,
    )
    deepep_permute_triton_kernel[(a.shape[0],)](
        a,
        gateup_input_pre_reorder,
        src2dst,
        topk_ids_.to(torch.int64),
        None,
        topk,
        a.shape[1],
        BLOCK_SIZE=512,
    )
    gateup_input = torch.empty(
        gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(
        gateup_input_pre_reorder, gateup_input, a1_scale_f32, True
    )
    # Drop the unquantized copy before the gemm buffers are allocated.
    del gateup_input_pre_reorder

    # Remap the -1 sentinel to num_experts (one past the last valid index).
    local_topk_ids = (
        torch.where(topk_ids_ == -1, num_experts, topk_ids_)
        .to(torch.int32)
        .contiguous()
    )

    # NOTE(review): a_map / c_map are allocated only because the helper takes
    # them as arguments; this function never reads them afterwards.
    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
    get_cutlass_w4a8_moe_mm_data(
        local_topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        num_experts,
        n,
        k,
    )

    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
    # c2 is zero-initialised (unlike c1); presumably rows the second gemm does
    # not write must read as zero in the final reorder - TODO confirm.
    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)

    # First grouped gemm (gate/up projection).
    # NOTE(review): 128 is passed positionally; presumably the weight
    # quantization group size - TODO confirm against cutlass_w4a8_moe_mm.
    cutlass_w4a8_moe_mm(
        c1,
        gateup_input,
        w1_q,
        a1_scale_f32,
        w1_scale,
        expert_offsets[:-1],
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides13,
        128,
        topk,
    )

    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
    silu_and_mul(c1, intermediate)

    intermediate_q = torch.empty(
        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale_f32, True)

    # Second grouped gemm (down projection).
    cutlass_w4a8_moe_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale_f32,
        w2_scale,
        expert_offsets[:-1],
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
        128,
        topk,
    )

    # Scatter the per-expert rows back into token order, using topk_weights.
    num_tokens = src2dst.shape[0] // topk
    output = torch.empty(
        (num_tokens, c2.shape[1]),
        device=c2.device,
        dtype=torch.bfloat16,
    )
    deepep_post_reorder_triton_kernel[(num_tokens,)](
        c2,
        output,
        src2dst,
        topk_ids_,
        topk_weights,
        topk,
        c2.shape[1],
        BLOCK_SIZE=512,
    )

    return output
401
+
402
+
403
def cutlass_w4a8_moe_deepep_ll(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_ids_: torch.Tensor,
    masked_m: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q. The input is already
    grouped per expert, so no token permutation or top-k weighting is done
    here. The matrix multiplications are implemented with CUTLASS
    grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
        Shape: [num_experts, N * 2, K // 2]
        (the weights are passed transposed and int4-packed)
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
        Shape: [num_experts, K, N // 2]
        (the weights are passed transposed and int4-packed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts, K // 512, N * 8]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts, N // 512, K * 4]
    - topk_ids_ (torch.Tensor): Only its second dimension (topk) is read; the
        value is forwarded to the grouped gemms.
    - masked_m (torch.Tensor): Passed to the problem-size helper and to the
        masked silu/quant kernel (presumably the number of valid rows per
        expert - TODO confirm).
    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
    - expert_offsets (torch.Tensor): expert_offsets[:-1] is passed to both
        grouped gemms; this function does not write it itself.
    - problem_sizes1 (torch.Tensor): Problem sizes buffer of the first grouped
        gemm; replaced by the value returned from
        deepep_ll_get_cutlass_w4a8_moe_mm_data.
    - problem_sizes2 (torch.Tensor): Problem sizes buffer of the second
        grouped gemm; replaced the same way.
    - a1_scale (Optional[torch.Tensor]): The fp32 scale to quantize a.
        Declared optional for signature compatibility, but required here.
        Shape: scalar or [1, K]
    - a2_scale (Optional[torch.Tensor]): The fp32 scale to quantize the
        intermediate result between the gemms. Declared optional for
        signature compatibility, but required here.
        Shape: scalar or [1, N]

    Returns:
    - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
        Shape: [num_experts, a.size(1), K]. It is allocated with torch.empty,
        so rows the second gemm does not write hold uninitialized memory.

    Raises:
    - ValueError: If a1_scale or a2_scale is None.
    - AssertionError: If the dtype / shape checks on the inputs fail.
    """
    assert w1_q.dtype == torch.int8
    assert w2_q.dtype == torch.int8
    assert a.shape[2] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"

    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"

    # Both scales are dereferenced unconditionally below; fail with a clear
    # message instead of an AttributeError on None.
    if a1_scale is None or a2_scale is None:
        raise ValueError(
            "a1_scale and a2_scale are required for w4a8 activation quantization"
        )

    num_experts = w1_q.size(0)
    m = a.size(1)
    k = w1_q.size(2) * 2  # w1_q is transposed and packed
    n = w2_q.size(2) * 2  # w2_q is transposed and packed
    topk = topk_ids_.size(1)
    device = a.device

    # Convert the input scale once; it feeds both quantization and the gemm.
    a1_scale_f32 = a1_scale.float()

    problem_sizes1, problem_sizes2 = deepep_ll_get_cutlass_w4a8_moe_mm_data(
        masked_m,
        problem_sizes1,
        problem_sizes2,
        num_experts,
        n,
        k,
    )

    gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn, device=device)
    sgl_per_tensor_quant_fp8(a, gateup_input, a1_scale_f32, True)
    c1 = torch.empty((num_experts, m, n * 2), device=device, dtype=torch.bfloat16)
    c2 = torch.empty((num_experts, m, k), device=device, dtype=torch.bfloat16)

    # First grouped gemm (gate/up projection).
    # NOTE(review): 128 is passed positionally; presumably the weight
    # quantization group size - TODO confirm against cutlass_w4a8_moe_mm.
    cutlass_w4a8_moe_mm(
        c1,
        gateup_input,
        w1_q,
        a1_scale_f32,
        w1_scale,
        expert_offsets[:-1],
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides13,
        128,
        topk,
    )

    intermediate_q = torch.empty(
        (num_experts, m, n), device=device, dtype=torch.float8_e4m3fn
    )
    # The fused kernel receives a2_scale as given (not cast to float32),
    # exactly as before.
    silu_and_mul_masked_post_per_tensor_quant_fwd(
        c1, intermediate_q, masked_m, a2_scale
    )

    # Second grouped gemm (down projection).
    cutlass_w4a8_moe_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale.float(),
        w2_scale,
        expert_offsets[:-1],
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
        128,
        topk,
    )

    return c2