sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import TYPE_CHECKING, Optional, Union
4
+ from typing import TYPE_CHECKING, List, Optional
5
5
 
6
6
  import torch
7
7
  import triton
@@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
12
12
  from sglang.srt.layers.dp_attention import get_attention_tp_size
13
13
  from sglang.srt.layers.radix_attention import AttentionType
14
14
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
15
+ from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
15
16
  from sglang.srt.utils import (
16
17
  get_bool_env_var,
17
18
  get_device_core_count,
@@ -63,13 +64,19 @@ class TritonAttnBackend(AttentionBackend):
63
64
  decode_attention_fwd,
64
65
  )
65
66
  from sglang.srt.layers.attention.triton_ops.extend_attention import (
67
+ build_unified_kv_indices,
66
68
  extend_attention_fwd,
69
+ extend_attention_fwd_unified,
67
70
  )
68
71
 
69
72
  super().__init__()
70
73
 
71
74
  self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
72
75
  self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
76
+ self.extend_attention_fwd_unified = torch.compiler.disable(
77
+ extend_attention_fwd_unified
78
+ )
79
+ self.build_unified_kv_indices = torch.compiler.disable(build_unified_kv_indices)
73
80
 
74
81
  # Parse args
75
82
  self.skip_prefill = skip_prefill
@@ -85,7 +92,7 @@ class TritonAttnBackend(AttentionBackend):
85
92
  self.num_kv_head = model_runner.model_config.get_num_kv_heads(
86
93
  get_attention_tp_size()
87
94
  )
88
- if model_runner.is_hybrid_gdn:
95
+ if model_runner.hybrid_gdn_config is not None:
89
96
  # For hybrid linear models, layer_id = 0 may not be full attention
90
97
  self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
91
98
  else:
@@ -162,6 +169,8 @@ class TritonAttnBackend(AttentionBackend):
162
169
  # Initialize forward metadata
163
170
  self.forward_metadata: ForwardMetadata = None
164
171
 
172
+ self.cuda_graph_custom_mask = None
173
+
165
174
  def get_num_kv_splits(
166
175
  self,
167
176
  num_kv_splits: torch.Tensor,
@@ -362,7 +371,7 @@ class TritonAttnBackend(AttentionBackend):
362
371
  )
363
372
  kv_indptr = kv_indptr[: bs + 1]
364
373
  kv_indices = torch.empty(
365
- forward_batch.extend_prefix_lens.sum().item(),
374
+ sum(forward_batch.extend_prefix_lens_cpu),
366
375
  dtype=torch.int64,
367
376
  device=self.device,
368
377
  )
@@ -421,6 +430,7 @@ class TritonAttnBackend(AttentionBackend):
421
430
  max_bs: int,
422
431
  max_num_tokens: int,
423
432
  kv_indices_buf: Optional[torch.Tensor] = None,
433
+ cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
424
434
  ):
425
435
  self.cuda_graph_attn_logits = torch.zeros(
426
436
  (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
@@ -432,9 +442,17 @@ class TritonAttnBackend(AttentionBackend):
432
442
  dtype=torch.float32,
433
443
  device=self.device,
434
444
  )
435
- self.cuda_graph_num_kv_splits = torch.full(
436
- (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
437
- )
445
+
446
+ if cuda_graph_num_kv_splits_buf is None:
447
+ self.cuda_graph_num_kv_splits = torch.full(
448
+ (max_num_tokens,),
449
+ self.max_kv_splits,
450
+ dtype=torch.int32,
451
+ device=self.device,
452
+ )
453
+ else:
454
+ self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
455
+
438
456
  if kv_indices_buf is None:
439
457
  self.cuda_graph_kv_indices = torch.zeros(
440
458
  (max_num_tokens * self.max_context_len),
@@ -681,9 +699,7 @@ class TritonAttnBackend(AttentionBackend):
681
699
  )
682
700
 
683
701
  else:
684
- kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
685
- kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
686
- num_token = spec_info.kv_indptr.shape[0] - 1
702
+ assert False, "Multi-step cuda graph init is not done here."
687
703
  self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
688
704
 
689
705
  elif forward_mode.is_target_verify():
@@ -755,6 +771,19 @@ class TritonAttnBackend(AttentionBackend):
755
771
  def get_cuda_graph_seq_len_fill_value(self):
756
772
  return 1
757
773
 
774
+ def get_verify_buffers_to_fill_after_draft(self):
775
+ """
776
+ Return buffers for verify attention kernels that needs to be filled after draft.
777
+
778
+ Typically, these are tree mask and position buffers.
779
+ """
780
+ return [self.cuda_graph_custom_mask, None]
781
+
782
+ def update_verify_buffers_to_fill_after_draft(
783
+ self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
784
+ ):
785
+ pass
786
+
758
787
  def forward_extend(
759
788
  self,
760
789
  q: torch.Tensor,
@@ -771,6 +800,7 @@ class TritonAttnBackend(AttentionBackend):
771
800
  else:
772
801
  o = torch.empty_like(q)
773
802
 
803
+ # Save KV cache first (must do this before unified kernel)
774
804
  if save_kv_cache:
775
805
  forward_batch.token_to_kv_pool.set_kv_buffer(
776
806
  layer, forward_batch.out_cache_loc, k, v
@@ -779,9 +809,16 @@ class TritonAttnBackend(AttentionBackend):
779
809
  logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
780
810
 
781
811
  causal = True
782
- if layer.attn_type == AttentionType.ENCODER_ONLY:
812
+ if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
783
813
  causal = False
784
814
 
815
+ # Deterministic mode: use unified 1-stage kernel
816
+ if self.enable_deterministic:
817
+ return self._forward_extend_unified(
818
+ q, o, layer, forward_batch, causal, logits_soft_cap, sinks
819
+ )
820
+
821
+ # Normal mode: use original 2-stage kernel
785
822
  if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
786
823
  sliding_window_size = (
787
824
  layer.sliding_window_size
@@ -818,6 +855,127 @@ class TritonAttnBackend(AttentionBackend):
818
855
  )
819
856
  return o
820
857
 
858
+ def _forward_extend_unified(
859
+ self,
860
+ q: torch.Tensor,
861
+ o: torch.Tensor,
862
+ layer: RadixAttention,
863
+ forward_batch: ForwardBatch,
864
+ causal: bool,
865
+ logits_soft_cap: float,
866
+ sinks: Optional[torch.Tensor],
867
+ ):
868
+ """
869
+ Unified 1-stage extend attention for deterministic inference.
870
+ Both prefix and extend KV are accessed through unified kv_indices.
871
+ """
872
+ bs = forward_batch.batch_size
873
+
874
+ # Determine sliding window settings
875
+ if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
876
+ sliding_window_size = layer.sliding_window_size
877
+ # Note: for unified kernel, we use full kv_indptr (not window)
878
+ prefix_kv_indptr = self.forward_metadata.window_kv_indptr
879
+ prefix_kv_indices = self.forward_metadata.window_kv_indices
880
+ # Compute window start positions (absolute position of first key in window)
881
+ # window_start_pos = seq_len - window_len
882
+ window_kv_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]
883
+ # Handle TARGET_VERIFY mode where extend_prefix_lens might not be set
884
+ if forward_batch.extend_prefix_lens is not None:
885
+ window_start_pos = (
886
+ forward_batch.extend_prefix_lens[:bs] - window_kv_lens
887
+ )
888
+ else:
889
+ # Infer from spec_info: prefix_len = seq_len - draft_token_num
890
+ if forward_batch.spec_info is not None and hasattr(
891
+ forward_batch.spec_info, "draft_token_num"
892
+ ):
893
+ extend_prefix_lens = (
894
+ forward_batch.seq_lens[:bs]
895
+ - forward_batch.spec_info.draft_token_num
896
+ )
897
+ window_start_pos = extend_prefix_lens - window_kv_lens
898
+ else:
899
+ window_start_pos = None
900
+ else:
901
+ sliding_window_size = -1
902
+ prefix_kv_indptr = self.forward_metadata.kv_indptr
903
+ prefix_kv_indices = self.forward_metadata.kv_indices
904
+ window_start_pos = None
905
+
906
+ # Build unified kv_indices using fused Triton kernel
907
+ extend_kv_indices = forward_batch.out_cache_loc
908
+
909
+ # Handle cases where extend_seq_lens or extend_start_loc might not be set
910
+ # In speculative decoding, we can infer these from spec_info or compute them
911
+ if forward_batch.extend_seq_lens is None:
912
+ # TARGET_VERIFY mode: infer extend_seq_lens from spec_info
913
+ if forward_batch.spec_info is not None and hasattr(
914
+ forward_batch.spec_info, "draft_token_num"
915
+ ):
916
+ draft_token_num = forward_batch.spec_info.draft_token_num
917
+ extend_seq_lens = torch.full(
918
+ (bs,), draft_token_num, dtype=torch.int32, device=self.device
919
+ )
920
+ else:
921
+ raise RuntimeError(
922
+ "extend_seq_lens is None but cannot infer from spec_info. "
923
+ "This should not happen in TARGET_VERIFY mode."
924
+ )
925
+ else:
926
+ extend_seq_lens = forward_batch.extend_seq_lens
927
+
928
+ # Check extend_start_loc separately - it might be None even when extend_seq_lens is set
929
+ if forward_batch.extend_start_loc is None:
930
+ # Compute extend_start_loc from extend_seq_lens
931
+ # extend_start_loc[i] = sum(extend_seq_lens[0:i])
932
+ extend_start_loc = torch.cat(
933
+ [
934
+ torch.zeros(1, dtype=torch.int32, device=self.device),
935
+ torch.cumsum(extend_seq_lens[:-1], dim=0),
936
+ ]
937
+ )
938
+ else:
939
+ extend_start_loc = forward_batch.extend_start_loc
940
+
941
+ unified_kv_indptr, unified_kv_indices, prefix_lens = (
942
+ self.build_unified_kv_indices(
943
+ prefix_kv_indptr,
944
+ prefix_kv_indices,
945
+ extend_start_loc,
946
+ extend_seq_lens,
947
+ extend_kv_indices,
948
+ bs,
949
+ )
950
+ )
951
+
952
+ # Convert prefix_lens to int32 for the kernel
953
+ prefix_lens = prefix_lens.to(torch.int32)
954
+
955
+ # Call unified kernel
956
+ self.extend_attention_fwd_unified(
957
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
958
+ o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
959
+ forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
960
+ forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
961
+ self.forward_metadata.qo_indptr,
962
+ unified_kv_indptr,
963
+ unified_kv_indices,
964
+ prefix_lens,
965
+ self.forward_metadata.max_extend_len,
966
+ custom_mask=self.forward_metadata.custom_mask,
967
+ mask_indptr=self.forward_metadata.mask_indptr,
968
+ sm_scale=layer.scaling,
969
+ logit_cap=logits_soft_cap,
970
+ is_causal=causal,
971
+ sliding_window_size=sliding_window_size,
972
+ sinks=sinks,
973
+ window_start_pos=window_start_pos,
974
+ xai_temperature_len=layer.xai_temperature_len,
975
+ )
976
+
977
+ return o
978
+
821
979
  def forward_decode(
822
980
  self,
823
981
  q: torch.Tensor,
@@ -883,11 +1041,8 @@ class TritonMultiStepDraftBackend:
883
1041
  topk: int,
884
1042
  speculative_num_steps: int,
885
1043
  ):
886
- from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
887
-
888
1044
  self.topk = topk
889
1045
  self.speculative_num_steps = speculative_num_steps
890
- self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
891
1046
  max_bs = model_runner.req_to_token_pool.size * self.topk
892
1047
  self.kv_indptr = torch.zeros(
893
1048
  (
@@ -897,8 +1052,8 @@ class TritonMultiStepDraftBackend:
897
1052
  dtype=torch.int32,
898
1053
  device=model_runner.device,
899
1054
  )
900
- self.attn_backends = []
901
- for i in range(self.speculative_num_steps):
1055
+ self.attn_backends: List[TritonAttnBackend] = []
1056
+ for i in range(self.speculative_num_steps - 1):
902
1057
  self.attn_backends.append(
903
1058
  TritonAttnBackend(
904
1059
  model_runner,
@@ -916,13 +1071,19 @@ class TritonMultiStepDraftBackend:
916
1071
  self.page_size = model_runner.server_args.page_size
917
1072
 
918
1073
  def common_template(
919
- self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
1074
+ self,
1075
+ forward_batch: ForwardBatch,
1076
+ kv_indices_buffer: Optional[torch.Tensor],
1077
+ call_fn: int,
920
1078
  ):
1079
+ if kv_indices_buffer is None:
1080
+ kv_indices_buffer = self.cuda_graph_kv_indices
1081
+
921
1082
  num_seqs = forward_batch.batch_size
922
1083
  bs = self.topk * num_seqs
923
1084
  seq_lens_sum = forward_batch.seq_lens_sum
924
1085
 
925
- self.generate_draft_decode_kv_indices[
1086
+ generate_draft_decode_kv_indices[
926
1087
  (self.speculative_num_steps, num_seqs, self.topk)
927
1088
  ](
928
1089
  forward_batch.req_pool_indices,
@@ -940,7 +1101,10 @@ class TritonMultiStepDraftBackend:
940
1101
  self.page_size,
941
1102
  )
942
1103
 
943
- for i in range(self.speculative_num_steps):
1104
+ if call_fn is None:
1105
+ return
1106
+
1107
+ for i in range(self.speculative_num_steps - 1):
944
1108
  forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
945
1109
  forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
946
1110
  : seq_lens_sum * self.topk + bs * (i + 1)
@@ -974,9 +1138,19 @@ class TritonMultiStepDraftBackend:
974
1138
  dtype=torch.int64,
975
1139
  device=self.device,
976
1140
  )
977
- for i in range(self.speculative_num_steps):
1141
+ self.cuda_graph_num_kv_splits = torch.full(
1142
+ (max_num_tokens,),
1143
+ self.attn_backends[0].max_kv_splits,
1144
+ dtype=torch.int32,
1145
+ device=self.device,
1146
+ )
1147
+
1148
+ for i in range(self.speculative_num_steps - 1):
978
1149
  self.attn_backends[i].init_cuda_graph_state(
979
- max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
1150
+ max_bs,
1151
+ max_num_tokens,
1152
+ kv_indices_buf=self.cuda_graph_kv_indices[i],
1153
+ cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
980
1154
  )
981
1155
 
982
1156
  def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -991,24 +1165,24 @@ class TritonMultiStepDraftBackend:
991
1165
  spec_info=forward_batch.spec_info,
992
1166
  )
993
1167
 
994
- self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
1168
+ self.common_template(forward_batch, None, call_fn)
995
1169
 
996
1170
  def init_forward_metadata_replay_cuda_graph(
997
1171
  self, forward_batch: ForwardBatch, bs: int
998
1172
  ):
999
- def call_fn(i, forward_batch):
1000
- self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
1001
- bs,
1002
- forward_batch.req_pool_indices,
1003
- forward_batch.seq_lens,
1004
- seq_lens_sum=-1,
1005
- encoder_lens=None,
1006
- forward_mode=ForwardMode.DECODE,
1007
- spec_info=forward_batch.spec_info,
1008
- seq_lens_cpu=None,
1009
- )
1010
-
1011
- self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
1173
+ self.common_template(forward_batch, None, None)
1174
+
1175
+ # NOTE: Multi-step's attention backends use the slice of
1176
+ # - kv_indptr buffer (cuda graph and non-cuda graph)
1177
+ # - kv_indices buffer (cuda graph only)
1178
+ # So we don't need to assign the KV indices inside the attention backend.
1179
+
1180
+ # Compute num_kv_splits only once
1181
+ num_token = forward_batch.batch_size * self.topk
1182
+ self.attn_backends[-1].get_num_kv_splits(
1183
+ self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
1184
+ forward_batch.seq_lens[:bs],
1185
+ )
1012
1186
 
1013
1187
 
1014
1188
  @triton.jit
@@ -2,7 +2,7 @@ import torch
2
2
  import triton
3
3
  import triton.language as tl
4
4
 
5
- from sglang.srt.managers.schedule_batch import global_server_args_dict
5
+ from sglang.srt.server_args import get_global_server_args
6
6
  from sglang.srt.utils import is_cuda, is_hip
7
7
 
8
8
  _is_cuda = is_cuda()
@@ -11,7 +11,7 @@ if _is_cuda:
11
11
 
12
12
  _is_hip = is_hip()
13
13
 
14
- if global_server_args_dict.get("attention_reduce_in_fp32", False):
14
+ if get_global_server_args().triton_attention_reduce_in_fp32:
15
15
  REDUCE_TRITON_TYPE = tl.float32
16
16
  REDUCE_TORCH_TYPE = torch.float32
17
17
  else: