sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -74,6 +74,8 @@ def _fwd_kernel(
74
74
  BLOCK_M: tl.constexpr,
75
75
  BLOCK_N: tl.constexpr,
76
76
  USE_CUSTOM_MASK: tl.constexpr,
77
+ SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
78
+ STORE_TRANSPOSE: tl.constexpr,
77
79
  ):
78
80
  cur_seq = tl.program_id(0)
79
81
  cur_head = tl.program_id(1)
@@ -159,7 +161,7 @@ def _fwd_kernel(
159
161
  if logit_cap > 0:
160
162
  qk = logit_cap * tanh(qk / logit_cap)
161
163
 
162
- if USE_CUSTOM_MASK:
164
+ if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
163
165
  custom_mask = tl.load(
164
166
  mask_ptr
165
167
  + cur_seq_mask_start_idx
@@ -272,9 +274,18 @@ def _fwd_kernel(
272
274
  + cur_head * stride_oh
273
275
  + offs_dv[None, :]
274
276
  )
275
- tl.store(
276
- O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
277
- )
277
+ if STORE_TRANSPOSE:
278
+ tl.store(
279
+ O_Extend + offs_o.T,
280
+ (acc / deno[:, None]).T,
281
+ mask=(mask_m[:, None] & mask_dv[None, :]).T,
282
+ )
283
+ else:
284
+ tl.store(
285
+ O_Extend + offs_o,
286
+ acc / deno[:, None],
287
+ mask=mask_m[:, None] & mask_dv[None, :],
288
+ )
278
289
 
279
290
 
280
291
  def extend_attention_fwd(
@@ -292,6 +303,7 @@ def extend_attention_fwd(
292
303
  max_len_extend,
293
304
  sm_scale=None,
294
305
  logit_cap=0.0,
306
+ skip_prefix_custom_mask=True,
295
307
  ):
296
308
  """
297
309
  q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -345,6 +357,8 @@ def extend_attention_fwd(
345
357
  kv_group_num = q_extend.shape[1] // k_extend.shape[1]
346
358
 
347
359
  USE_CUSTOM_MASK = custom_mask is not None
360
+ # Skip custom mask for prefix part
361
+ SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
348
362
 
349
363
  grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
350
364
  num_stages = 1
@@ -388,6 +402,8 @@ def extend_attention_fwd(
388
402
  Lq=Lq,
389
403
  Lv=Lv,
390
404
  USE_CUSTOM_MASK=USE_CUSTOM_MASK,
405
+ SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
406
+ STORE_TRANSPOSE=is_hip_,
391
407
  num_warps=num_warps,
392
408
  num_stages=num_stages,
393
409
  **extra_kargs,
@@ -0,0 +1,439 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ """
15
+ Memory-efficient attention for decoding.
16
+ It supports page size = 1.
17
+ """
18
+
19
+ # Adapted from
20
+ # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
21
+ # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
22
+
23
+ import triton
24
+ import triton.language as tl
25
+
26
+ from sglang.srt.layers.attention.triton_ops.decode_attention import (
27
+ _decode_softmax_reducev_fwd,
28
+ )
29
+
30
+
31
+ def is_hip():
32
+ return triton.runtime.driver.active.get_current_target().backend == "hip"
33
+
34
+
35
+ is_hip_ = is_hip()
36
+
37
+
38
+ @triton.jit
39
+ def tanh(x):
40
+ # Tanh is just a scaled sigmoid
41
+ return 2 * tl.sigmoid(2 * x) - 1
42
+
43
+
44
+ @triton.jit
45
+ def _fwd_grouped_kernel_stage1_rope(
46
+ Q, # Holds [Q_NOPE; Q_PE], b x h x (d+r)
47
+ K_Buffer, # Holds [KV; K_PE], b*s x (c+r)
48
+ V_buffer, # Holds [KV], b*s x (c)
49
+ cos_sin_cache, # max_seq_len x (rotary_dim * 2)
50
+ positions, # sequence positions
51
+ sm_scale,
52
+ kv_indptr,
53
+ kv_indices,
54
+ Att_Out, # b x h x NUM_KV_SPLITS x (kv_lora_rank + 1)
55
+ k_pe_t_out,
56
+ stride_qb,
57
+ stride_qh,
58
+ stride_buf_kbs,
59
+ stride_buf_vbs,
60
+ stride_mid_ob,
61
+ stride_mid_oh,
62
+ stride_mid_os,
63
+ stride_kpe_tokens_out_b,
64
+ stride_cos_sin_cache_s,
65
+ stride_positions_b,
66
+ rotary_dim: tl.constexpr,
67
+ kv_lora_rank: tl.constexpr,
68
+ qk_rope_head_dim: tl.constexpr,
69
+ kv_group_num: tl.constexpr,
70
+ q_head_num: tl.constexpr,
71
+ BLOCK_C: tl.constexpr,
72
+ BLOCK_R: tl.constexpr,
73
+ BLOCK_N: tl.constexpr,
74
+ BLOCK_H: tl.constexpr,
75
+ NUM_KV_SPLITS: tl.constexpr,
76
+ logit_cap: tl.constexpr,
77
+ USE_ROPE: tl.constexpr,
78
+ IS_NEOX_STYLE: tl.constexpr,
79
+ ):
80
+
81
+ cur_batch = tl.program_id(0)
82
+ cur_head_id = tl.program_id(1)
83
+ split_kv_id = tl.program_id(2)
84
+
85
+ if BLOCK_H < kv_group_num:
86
+ VALID_BLOCK_H: tl.constexpr = BLOCK_H
87
+ else:
88
+ VALID_BLOCK_H: tl.constexpr = kv_group_num
89
+ cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
90
+ mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
91
+ mask_h = mask_h & (cur_head < q_head_num)
92
+
93
+ offs_c = tl.arange(0, BLOCK_C)
94
+ offs_qk_r = tl.arange(kv_lora_rank, kv_lora_rank + BLOCK_R) # to get the k_pe
95
+
96
+ off_q_pe = (
97
+ cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_qk_r[None, :]
98
+ )
99
+ offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_c[None, :]
100
+
101
+ mask_c = offs_c < kv_lora_rank
102
+ mask_qk_r = offs_qk_r < (kv_lora_rank + qk_rope_head_dim)
103
+
104
+ cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
105
+ cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
106
+
107
+ q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_c[None, :]), other=0.0)
108
+ q_pe = tl.load(
109
+ Q + off_q_pe, mask=(mask_h[:, None]) & (mask_qk_r[None, :]), other=0.0
110
+ )
111
+
112
+ kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
113
+ split_kv_start = kv_len_per_split * split_kv_id
114
+ split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
115
+
116
+ # apply rotary embedding for q_pe, and k_pe (last token per batch of K_PE)
117
+ LAST_SPLIT = split_kv_end == cur_batch_seq_len
118
+ k_pe_last_token = tl.zeros([BLOCK_R], dtype=q.dtype)
119
+
120
+ if USE_ROPE:
121
+ if IS_NEOX_STYLE:
122
+ # [BLOCK_ROTARY // 2, BLOCK_ROTARY // 2 + 1, BLOCK_ROTARY // 2 + 2, ..., 0, 1, 2, ..., BLOCK_ROTARY // 2 - 1, pass:]
123
+ offs_qk_rot_r = kv_lora_rank + (
124
+ (tl.arange(0, BLOCK_R) + (rotary_dim // 2)) % rotary_dim
125
+ )
126
+ # Which elements to flip
127
+ mask_rotate = tl.arange(0, BLOCK_R) < (rotary_dim // 2)
128
+ # [0 , 1, 2, ..., rotary_dim // 2 - 1, 0 , 1, 2, ..., rotary_dim // 2 - 1]
129
+ offs_rotary = tl.arange(0, BLOCK_R) % (rotary_dim // 2)
130
+ else:
131
+ # [1, 0, 3, 2, 5, 4, ..., BLOCK_R, BLOCK_R - 1]
132
+ offs_qk_rot_r = (
133
+ kv_lora_rank
134
+ + (((tl.arange(0, BLOCK_R) + 1) % 2) * 2)
135
+ - 1
136
+ + tl.arange(0, BLOCK_R)
137
+ )
138
+ mask_rotate = tl.arange(0, BLOCK_R) % 2 < 1
139
+ # [0, 0, 1, 1, ..., rotary_dim // 2 - 1, rotary_dim // 2 - 1]
140
+ offs_rotary = tl.arange(0, BLOCK_R) // 2
141
+
142
+ if qk_rope_head_dim > rotary_dim:
143
+ offs_qk_rot_r = tl.where(
144
+ tl.arange(0, BLOCK_R) < rotary_dim, offs_qk_rot_r, tl.arange(0, BLOCK_R)
145
+ )
146
+ offs_rotary = tl.where(
147
+ tl.arange(0, BLOCK_R) < rotary_dim, offs_rotary, tl.arange(0, BLOCK_R)
148
+ )
149
+
150
+ mask_rotary = tl.arange(0, BLOCK_R) < rotary_dim
151
+
152
+ pos = tl.load(positions + cur_batch * stride_positions_b)
153
+ cos = tl.load(
154
+ cos_sin_cache + pos * stride_cos_sin_cache_s + offs_rotary,
155
+ mask=mask_rotary,
156
+ other=1.0,
157
+ )
158
+ sin = tl.load(
159
+ cos_sin_cache
160
+ + pos * stride_cos_sin_cache_s
161
+ + offs_rotary
162
+ + rotary_dim // 2,
163
+ mask_rotary,
164
+ other=0.0,
165
+ )
166
+
167
+ off_q_pe_rot = (
168
+ cur_batch * stride_qb
169
+ + cur_head[:, None] * stride_qh
170
+ + offs_qk_rot_r[None, :]
171
+ )
172
+ mask_qk_rot_r = offs_qk_rot_r < (kv_lora_rank + qk_rope_head_dim)
173
+
174
+ # 0, 2, 4,.... 1, 3, 5...
175
+ q_pe_rot = tl.load(
176
+ Q + off_q_pe_rot,
177
+ mask=(mask_h[:, None]) & (mask_qk_rot_r[None, :]),
178
+ other=0.0,
179
+ )
180
+ q_pe_rot = tl.where(mask_rotate[None, :], -q_pe_rot, q_pe_rot)
181
+
182
+ q_pe = q_pe * cos + q_pe_rot * sin
183
+
184
+ # we only apply to the last token in the K_PE
185
+ if LAST_SPLIT:
186
+ # debug assert
187
+ if (cur_batch == 0 and cur_head == 0) and split_kv_id < NUM_KV_SPLITS - 1:
188
+ tl.device_assert(False, "Only last split should compute k_pe")
189
+
190
+ kv_loc = tl.load(
191
+ kv_indices + cur_batch_kv_start_idx + cur_batch_seq_len - 1
192
+ )
193
+ offs_buf_k_pe_last_token = kv_loc * stride_buf_kbs + offs_qk_r
194
+ offs_buf_k_pe_rot_last_token = kv_loc * stride_buf_kbs + offs_qk_rot_r
195
+ k_pe_last_token = tl.load(K_Buffer + offs_buf_k_pe_last_token)
196
+
197
+ k_pe_rot_last_token = tl.load(K_Buffer + offs_buf_k_pe_rot_last_token)
198
+ k_pe_rot_last_token = tl.where(
199
+ mask_rotate, -k_pe_rot_last_token, k_pe_rot_last_token
200
+ )
201
+
202
+ k_pe_last_token = k_pe_last_token * cos + k_pe_rot_last_token * sin
203
+
204
+ e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
205
+ e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
206
+ acc = tl.zeros([BLOCK_H, BLOCK_C], dtype=tl.float32)
207
+
208
+ if split_kv_end > split_kv_start:
209
+ for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
210
+ offs_n = start_n + tl.arange(0, BLOCK_N)
211
+ kv_loc = tl.load(
212
+ kv_indices + cur_batch_kv_start_idx + offs_n,
213
+ mask=offs_n < split_kv_end,
214
+ other=0,
215
+ )
216
+
217
+ offs_buf_kv = kv_loc[None, :] * stride_buf_kbs + offs_c[:, None]
218
+ offs_buf_k_pe = kv_loc[None, :] * stride_buf_kbs + offs_qk_r[:, None]
219
+
220
+ k_pe = tl.load(
221
+ K_Buffer + offs_buf_k_pe,
222
+ mask=(offs_n[None, :] < split_kv_end) & (mask_qk_r[:, None]),
223
+ other=0.0,
224
+ ) # positional embedding part of keys
225
+
226
+ if (USE_ROPE and LAST_SPLIT) and start_n >= cur_batch_seq_len - BLOCK_N:
227
+ k_pe = tl.where(
228
+ offs_n[None, :] != (split_kv_end - 1),
229
+ k_pe,
230
+ k_pe_last_token[:, None],
231
+ )
232
+
233
+ # (16, 64) x (64, 32)
234
+ # dot product of rope parts
235
+ qk = tl.dot(q_pe, k_pe.to(q_pe.dtype))
236
+
237
+ kv = tl.load(
238
+ K_Buffer + offs_buf_kv,
239
+ mask=(offs_n[None, :] < split_kv_end) & (mask_c[:, None]),
240
+ other=0.0,
241
+ ) # the shared latent tensor for keys and values
242
+
243
+ # (16, 512) x (512, 32)
244
+ # dot product of nope parts
245
+ qk += tl.dot(q, kv)
246
+
247
+ qk *= sm_scale
248
+
249
+ if logit_cap > 0:
250
+ qk = logit_cap * tanh(qk / logit_cap)
251
+
252
+ qk = tl.where(
253
+ mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
254
+ )
255
+
256
+ offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_c[None, :]
257
+ v = tl.load(
258
+ V_buffer + offs_buf_v,
259
+ mask=(offs_n[:, None] < split_kv_end) & (mask_c[None, :]),
260
+ other=0.0,
261
+ )
262
+
263
+ n_e_max = tl.maximum(tl.max(qk, 1), e_max)
264
+ re_scale = tl.exp(e_max - n_e_max)
265
+ p = tl.exp(qk - n_e_max[:, None])
266
+ acc *= re_scale[:, None]
267
+ # (16, 32) x (32, 512)
268
+ acc += tl.dot(p.to(v.dtype), v)
269
+
270
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
271
+ e_max = n_e_max
272
+
273
+ offs_mid_o = (
274
+ cur_batch * stride_mid_ob
275
+ + cur_head[:, None] * stride_mid_oh
276
+ + split_kv_id * stride_mid_os
277
+ + offs_c[None, :]
278
+ )
279
+
280
+ if USE_ROPE:
281
+ if LAST_SPLIT:
282
+ k_pe_last_token_ptrs = (
283
+ k_pe_t_out
284
+ + cur_batch * stride_kpe_tokens_out_b
285
+ + tl.arange(0, BLOCK_R)
286
+ )
287
+ tl.store(k_pe_last_token_ptrs, k_pe_last_token, mask=mask_qk_r)
288
+
289
+ tl.store(
290
+ Att_Out + offs_mid_o,
291
+ acc / e_sum[:, None],
292
+ mask=(mask_h[:, None]) & (mask_c[None, :]),
293
+ )
294
+
295
+ offs_mid_o_1 = (
296
+ cur_batch * stride_mid_ob
297
+ + cur_head * stride_mid_oh
298
+ + split_kv_id * stride_mid_os
299
+ + kv_lora_rank
300
+ )
301
+
302
+ tl.store(
303
+ Att_Out + offs_mid_o_1,
304
+ e_max + tl.log(e_sum),
305
+ mask=mask_h,
306
+ )
307
+
308
+
309
+ # TODO rope offset
310
+ def _decode_grouped_att_m_fwd_rope(
311
+ q,
312
+ k_buffer,
313
+ v_buffer,
314
+ att_out,
315
+ k_pe_tokens_out,
316
+ kv_lora_rank, # c
317
+ cos_sin_cache,
318
+ positions,
319
+ rotary_dim,
320
+ kv_indptr,
321
+ kv_indices,
322
+ num_kv_splits,
323
+ sm_scale,
324
+ logit_cap,
325
+ use_rope,
326
+ is_neox_style=True,
327
+ ):
328
+ if use_rope:
329
+ assert (
330
+ k_pe_tokens_out is not None
331
+ ), "We must output the k_pe tokens with rope applied if rope fusion enabled."
332
+
333
+ BLOCK = 32
334
+
335
+ # # [TODO] work around shmem limit on MI3xx
336
+ # if is_hip_ and kv_lora_rank >= 576:
337
+ # BLOCK = 16
338
+
339
+ qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
340
+ batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
341
+ kv_group_num = q.shape[1] // k_buffer.shape[1]
342
+
343
+ BLOCK_C = triton.next_power_of_2(kv_lora_rank)
344
+ BLOCK_R = triton.next_power_of_2(qk_rope_head_dim)
345
+
346
+ BLOCK_H = 16
347
+ NUM_KV_SPLITS = num_kv_splits
348
+ grid = (
349
+ batch,
350
+ triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
351
+ NUM_KV_SPLITS,
352
+ )
353
+
354
+ extra_kargs = {}
355
+ num_stages = 2
356
+ if is_hip_:
357
+ # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
358
+ # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
359
+ extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
360
+ num_stages = 1
361
+
362
+ _fwd_grouped_kernel_stage1_rope[grid](
363
+ q,
364
+ k_buffer,
365
+ v_buffer,
366
+ cos_sin_cache,
367
+ positions,
368
+ sm_scale,
369
+ kv_indptr,
370
+ kv_indices,
371
+ att_out,
372
+ k_pe_tokens_out,
373
+ q.stride(0),
374
+ q.stride(1),
375
+ k_buffer.stride(0),
376
+ v_buffer.stride(0),
377
+ att_out.stride(0),
378
+ att_out.stride(1),
379
+ att_out.stride(2),
380
+ k_pe_tokens_out.stride(0) if use_rope else 0,
381
+ cos_sin_cache.stride(0) if use_rope else 0,
382
+ positions.stride(0) if use_rope else 0,
383
+ rotary_dim,
384
+ kv_lora_rank,
385
+ qk_rope_head_dim,
386
+ kv_group_num=kv_group_num,
387
+ q_head_num=head_num,
388
+ BLOCK_C=BLOCK_C,
389
+ BLOCK_R=BLOCK_R,
390
+ BLOCK_N=BLOCK,
391
+ BLOCK_H=BLOCK_H,
392
+ NUM_KV_SPLITS=NUM_KV_SPLITS,
393
+ logit_cap=logit_cap,
394
+ USE_ROPE=use_rope,
395
+ IS_NEOX_STYLE=is_neox_style,
396
+ num_warps=4,
397
+ num_stages=num_stages,
398
+ **extra_kargs
399
+ )
400
+
401
+
402
+ def decode_attention_fwd_grouped_rope(
403
+ q,
404
+ k_buffer,
405
+ v_buffer,
406
+ o,
407
+ kv_indptr,
408
+ kv_indices,
409
+ k_pe_tokens,
410
+ kv_lora_rank,
411
+ rotary_dim,
412
+ cos_sin_cache,
413
+ positions,
414
+ attn_logits,
415
+ num_kv_splits,
416
+ sm_scale,
417
+ logit_cap=0.0,
418
+ use_rope=False,
419
+ is_neox_style=False,
420
+ ):
421
+ _decode_grouped_att_m_fwd_rope(
422
+ q,
423
+ k_buffer,
424
+ v_buffer,
425
+ attn_logits,
426
+ k_pe_tokens,
427
+ kv_lora_rank,
428
+ cos_sin_cache,
429
+ positions,
430
+ rotary_dim,
431
+ kv_indptr,
432
+ kv_indices,
433
+ num_kv_splits,
434
+ sm_scale,
435
+ logit_cap,
436
+ use_rope,
437
+ is_neox_style,
438
+ )
439
+ _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
@@ -0,0 +1,39 @@
1
+ import triton
2
+ import triton.language as tl
3
+
4
+
5
+ @triton.jit
6
+ def create_flashinfer_kv_indices_triton(
7
+ req_to_token_ptr, # [max_batch, max_context_len]
8
+ req_pool_indices_ptr,
9
+ page_kernel_lens_ptr,
10
+ kv_indptr,
11
+ kv_start_idx,
12
+ kv_indices_ptr,
13
+ req_to_token_ptr_stride: tl.constexpr,
14
+ ):
15
+ BLOCK_SIZE: tl.constexpr = 512
16
+ pid = tl.program_id(axis=0)
17
+
18
+ req_pool_index = tl.load(req_pool_indices_ptr + pid)
19
+ kv_indices_offset = tl.load(kv_indptr + pid)
20
+
21
+ kv_start = 0
22
+ kv_end = 0
23
+ if kv_start_idx:
24
+ kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
25
+ kv_end = kv_start
26
+ kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
27
+
28
+ num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
29
+ for i in range(num_loop):
30
+ offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
31
+ mask = offset < kv_end - kv_start
32
+ data = tl.load(
33
+ req_to_token_ptr
34
+ + req_pool_index * req_to_token_ptr_stride
35
+ + kv_start
36
+ + offset,
37
+ mask=mask,
38
+ )
39
+ tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)