sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Union
+from contextlib import nullcontext
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 import torch
 import triton
@@ -31,13 +32,19 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
 )
+from sglang.srt.layers.quantization.modelopt_quant import (
+    CUTEDSL_MOE_NVFP4_DISPATCH,
+    ModelOptNvFp4FusedMoEMethod,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.offloader import get_offloader
+from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 from sglang.srt.utils import (
     ceil_div,
     dispose_tensor,
     get_bool_env_var,
+    get_int_env_var,
     is_cuda,
     is_hip,
     is_npu,
@@ -453,9 +460,20 @@ class DeepEPMoE(EPMoE):
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            input_global_scale=(
+                self.w13_input_scale_quant
+                if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
+                and self.quant_method.enable_flashinfer_cutedsl_moe
+                and CUTEDSL_MOE_NVFP4_DISPATCH
+                else None
+            ),
         )
 
-    def moe_impl(self, dispatch_output: DispatchOutput):
+    def moe_impl(
+        self,
+        dispatch_output: DispatchOutput,
+        down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
+    ):
         from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
 
         if _use_aiter:
@@ -470,7 +488,9 @@ class DeepEPMoE(EPMoE):
             return self.forward_deepgemm_contiguous(dispatch_output)
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if get_moe_runner_backend().is_flashinfer_cutedsl():
-                return self.forward_flashinfer_cutedsl(dispatch_output)
+                return self.forward_flashinfer_cutedsl(
+                    dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
+                )
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -484,12 +504,14 @@ class DeepEPMoE(EPMoE):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
+        overlap_args: Optional[Dict[str, Any]] = None,
     ):
         return self.deepep_dispatcher.combine(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            overlap_args=overlap_args,
         )
 
     def forward_aiter(
@@ -676,6 +698,7 @@ class DeepEPMoE(EPMoE):
     def forward_flashinfer_cutedsl(
         self,
         dispatch_output: DeepEPLLOutput,
+        down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
     ):
         hidden_states, _, _, masked_m, _ = dispatch_output
         assert self.quant_method is not None
@@ -686,6 +709,7 @@ class DeepEPMoE(EPMoE):
             x=hidden_states,
             masked_m=masked_m,
             moe_runner_config=self.moe_runner_config,
+            down_gemm_overlap_args=down_gemm_overlap_args,
         )
         return output
 
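The hunks above thread a new optional `down_gemm_overlap_args` from `moe_impl` through `forward_flashinfer_cutedsl` into the quant method's CuteDSL kernel call. Below is a minimal sketch of that hand-off using a hypothetical stand-in dataclass; the real `DownGemmOverlapArgs` lives in the new sglang/srt/single_batch_overlap.py and its fields may differ. The sketch only illustrates how one overlap object could map onto the `down_sm_count` / `down_signals` / `down_start_event` parameters added to flashinfer_cutedsl_moe.py further down.

# Editor's sketch, not sglang source: hypothetical stand-in for DownGemmOverlapArgs.
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch


@dataclass
class DownGemmOverlapArgsSketch:  # hypothetical field names
    num_sms: Optional[int] = None                      # -> down_sm_count
    signal: Optional[torch.Tensor] = None               # -> down_signals
    start_event: Optional["torch.cuda.Event"] = None    # -> down_start_event


def overlap_kwargs(args: Optional[DownGemmOverlapArgsSketch]) -> Dict[str, Any]:
    # Mirrors the pattern in the diff: pass nothing when overlap is disabled.
    if args is None:
        return {}
    return dict(
        down_sm_count=args.num_sms,
        down_signals=args.signal,
        down_start_event=args.start_event,
    )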
@@ -789,45 +813,69 @@ class DeepEPMoE(EPMoE):
         if isinstance(hidden_states, tuple):
             per_token_scale = hidden_states[1]
             hidden_states = hidden_states[0]
-        else:
-            # dynamic quant
-            hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
-                hidden_states
-            )
 
         group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
             hidden_states.device
         )
+        if self.w13_weight.dtype != torch.int8:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight.permute(0, 2, 1)],
+                # per_token_scale=[per_token_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight.permute(0, 2, 1)],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+        else:
+            if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
+                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states
+                )
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                scale=[self.w13_weight_scale.to(output_dtype)],
+                per_token_scale=[per_token_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
+                hidden_states
+            )
 
-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[per_token_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-            output_dtype=output_dtype,
-        )[0]
-
-        # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
-
-        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)],
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-            output_dtype=output_dtype,
-        )[0]
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
 
         return hidden_states
 
@@ -836,47 +884,72 @@ class DeepEPMoE(EPMoE):
         assert isinstance(dispatch_output, DeepEPLLOutput)
         hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
 
-        per_token_scale = hidden_states[1]
-        hidden_states = hidden_states[0]
+        if isinstance(hidden_states, tuple):
+            per_token_scale = hidden_states[1]
+            hidden_states = hidden_states[0]
 
         group_list = group_list.to(torch.int64)
 
-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w13_weight],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-            output_dtype=torch.int32,
-        )[0]
-
-        # act_fn: swiglu
-        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=hidden_states,
-            weight_scale=self.w13_weight_scale.to(torch.float32),
-            activation_scale=per_token_scale,
-            bias=None,
-            quant_scale=None,
-            quant_offset=None,
-            group_index=group_list,
-            activate_left=True,
-            quant_mode=1,
-        )
+        if self.w13_weight.dtype != torch.int8:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight.permute(0, 2, 1)],
+                # per_token_scale=[per_token_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight.permute(0, 2, 1)],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+        else:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=torch.int32,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=self.w13_weight_scale.to(torch.float32),
+                activation_scale=per_token_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=group_list,
+                activate_left=True,
+                quant_mode=1,
+            )
 
-        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)],
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-            output_dtype=output_dtype,
-        )[0]
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
 
         return hidden_states
 
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
@@ -20,7 +20,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:
 
 
 def flashinfer_cutedsl_moe_masked(
-    hidden_states: torch.Tensor,
+    hidden_states: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
     input_global_scale: torch.Tensor,
     w1: torch.Tensor,
     w1_blockscale: torch.Tensor,
@@ -30,13 +30,18 @@ def flashinfer_cutedsl_moe_masked(
     w2_blockscale: torch.Tensor,
     w2_alpha,
     masked_m: torch.Tensor,
+    down_sm_count: Optional[int] = None,
+    down_signals: Optional[torch.Tensor] = None,
+    down_start_event: Optional[torch.cuda.Event] = None,
 ):
     """
     Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
     kernels.
 
     Args:
-        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        hidden_states: Either of the following case
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
         input_global_scale (torch.Tensor): (l,)
         w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
         w1_blockscale (torch.Tensor): blockscale factors, e4m3,
@@ -48,13 +53,10 @@ def flashinfer_cutedsl_moe_masked(
         masked_m (torch.Tensor): Masked dimension indices
 
     Notes:
-        - Assumes max(masked_m) <= m.
+        - Assumes max(masked_m) == m.
     """
 
     # === Assertions on dtypes ===
-    assert (
-        input_global_scale.dtype == torch.float32
-    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
     assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
     assert (
         w1_blockscale.dtype == torch.float8_e4m3fn
@@ -75,7 +77,31 @@ def flashinfer_cutedsl_moe_masked(
 
     # === Assertions on shapes ===
     n = w2.shape[-1] * 2  # intermediate dimension
-    num_experts, m, k = hidden_states.shape
+
+    if isinstance(hidden_states, tuple):
+        assert (
+            input_global_scale is None
+        ), "input_global_scale is needed when input needs quant"
+
+        a_q = hidden_states[0].view(torch.uint8)
+        a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
+        m, k_by_2, num_experts = a_q.shape
+        k = k_by_2 * 2
+    else:
+        num_experts, m, k = hidden_states.shape
+
+        assert (
+            input_global_scale.dtype == torch.float32
+        ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
+        assert input_global_scale.shape == (
+            num_experts,
+        ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
+
+        a_q, a_q_sf = scaled_fp4_grouped_quant(
+            hidden_states,
+            input_global_scale,
+            masked_m,
+        )
 
     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
     assert (
@@ -85,10 +111,6 @@ def flashinfer_cutedsl_moe_masked(
         k,
         n // 2,
     ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"
-
-    assert input_global_scale.shape == (
-        num_experts,
-    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
     assert w1_alpha.shape == (
         num_experts,
     ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
99
121
  num_experts,
100
122
  ), f"w2_alpha must be (l,), got {w2_alpha.shape}"
101
123
 
102
- aq, aq_sf = scaled_fp4_grouped_quant(
103
- hidden_states,
104
- input_global_scale,
105
- masked_m,
106
- )
124
+ # TODO(kaixih@nvidia): dtype should be based on inputs.
107
125
  gateup_output = torch.empty(
108
- (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
126
+ (num_experts, m, n * 2), dtype=torch.bfloat16, device=a_q.device
109
127
  )
110
128
  gateup_output = gateup_output.permute(1, 2, 0) # requirement of kernel
111
129
  sf_vec_size = 16
112
- assert aq_sf.dtype == torch.float8_e4m3fn
113
- assert aq.dtype == torch.uint8
130
+ assert a_q_sf.dtype == torch.float8_e4m3fn
131
+ assert a_q.dtype == torch.uint8
114
132
  ab_dtype = "float4_e2m1fn"
115
133
  sf_dtype = "float8_e4m3fn"
116
-
117
- c_dtype = get_cute_dtype(hidden_states)
134
+ c_dtype = "bfloat16"
118
135
 
119
136
  # Gemm1
120
-
121
137
  grouped_gemm_nt_masked(
122
- (aq, aq_sf),
138
+ (a_q, a_q_sf),
123
139
  (w1.permute(1, 2, 0), w1_blockscale),
124
140
  gateup_output,
125
141
  masked_m,
@@ -138,8 +154,11 @@ def flashinfer_cutedsl_moe_masked(
         masked_m,
     )
 
+    if down_start_event is not None:
+        down_start_event.record()
+
     # Gemm2
-    out = torch.empty_like(hidden_states)
+    out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
     out = out.permute(1, 2, 0)  # requirement of kernel
     grouped_gemm_nt_masked(
         (diq, diq_sf),
@@ -152,5 +171,13 @@ def flashinfer_cutedsl_moe_masked(
         sf_vec_size=sf_vec_size,
         alpha=w2_alpha.view(1, 1, num_experts),
         alpha_dtype=get_cute_dtype(w2_alpha),
+        **(
+            dict(
+                sm_count=down_sm_count,
+                dst_signals=down_signals,
+            )
+            if down_sm_count is not None or down_signals is not None
+            else {}
+        ),
     )  # in logical [m, k, l]
     return out.permute(2, 0, 1)
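Net effect of the hunks above: `flashinfer_cutedsl_moe_masked` now accepts `hidden_states` either as an unquantized bf16 tensor (quantized inside the function via `scaled_fp4_grouped_quant` using `input_global_scale`) or as an already fp4-quantized pair of (packed uint8, float8_e4m3fn scales) tensors with `input_global_scale=None`, and the new `down_sm_count` / `down_signals` / `down_start_event` arguments let the caller overlap the down GEMM. A torch-only sketch of the input contract, exactly as unpacked in the diff (the helper name is the editor's, not part of sglang):

# Editor's sketch: validate and unpack the two accepted hidden_states forms.
from typing import Optional, Tuple, Union

import torch


def unpack_hidden_states(
    hidden_states: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    input_global_scale: Optional[torch.Tensor],
):
    if isinstance(hidden_states, tuple):
        # Pre-quantized path: fp4-packed activations + fp8 scale factors, no global scale.
        assert input_global_scale is None
        a_q = hidden_states[0].view(torch.uint8)
        a_q_sf = hidden_states[1].view(torch.float8_e4m3fn)
        m, k_by_2, num_experts = a_q.shape
        return num_experts, m, k_by_2 * 2
    # Unquantized path: bf16 activations plus a float32 per-expert global scale.
    num_experts, m, k = hidden_states.shape
    assert input_global_scale.dtype == torch.float32
    assert input_global_scale.shape == (num_experts,)
    return num_experts, m, k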
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
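The new JSON files under fused_moe_triton/configs map a token count (the top-level key) to a Triton launch configuration tuned for that batch size. A small sketch of how such a table is typically consumed, assuming the nearest-key fallback commonly used for fused MoE Triton configs when the exact M was not tuned:

# Editor's sketch: nearest-key lookup over a tuning table shaped like the file
# above (table truncated here for brevity).
import json

_CONFIGS_JSON = """
{
  "1":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 4},
  "64":   {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}
}
"""
_CONFIGS = {int(m): cfg for m, cfg in json.loads(_CONFIGS_JSON).items()}


def pick_config(m: int) -> dict:
    # Pick the tuned M closest to the requested M.
    return _CONFIGS[min(_CONFIGS, key=lambda tuned_m: abs(tuned_m - m))]


print(pick_config(48))    # closest tuned key is 64
print(pick_config(3000))  # closest tuned key is 4096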