sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/wave_backend.py ADDED
@@ -0,0 +1,627 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import get_bool_env_var, get_device_core_count
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    num_seq,
+    num_group,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_NUM_SEQ: tl.constexpr,
+):
+    # TODO: this method is tunable; we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grow with seqlen gradually
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        token_grid = num_seq * num_group * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+@dataclass
+class ForwardMetadata:
+    attn_logits: torch.Tensor
+    attn_lse: torch.Tensor
+    max_extend_len: int
+    num_kv_splits: torch.Tensor
+    kv_indptr: torch.Tensor
+    kv_indices: torch.Tensor
+    qo_indptr: torch.Tensor
+    custom_mask: torch.Tensor
+    mask_indptr: torch.Tensor
+
+
+class WaveAttnBackend(AttentionBackend):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
+        # Lazy import to avoid initializing the CUDA context
+        from sglang.srt.layers.attention.wave_ops.decode_attention import (
+            decode_attention_fwd,
+        )
+        from sglang.srt.layers.attention.wave_ops.extend_attention import (
+            extend_attention_wave,
+        )
+
+        super().__init__()
+
+        # Set a unique cache dir for each process to avoid cache write races
+        import wave_lang.kernel.wave.cache as cache
+
+        base_cache_dir = cache.CACHE_BASE_DIR
+        new_dir = base_cache_dir / f"worker_{model_runner.tp_rank}"
+        logger.info(f"Setting Wave cache dir: {new_dir}")
+        cache.CACHE_BASE_DIR = new_dir
+
+        self.decode_attention_fwd = decode_attention_fwd
+        self.extend_attention_fwd = extend_attention_wave
+
+        self.skip_prefill = skip_prefill
+
+        max_bs = model_runner.req_to_token_pool.size
+
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        if not self.skip_prefill:
+            self.qo_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+
+            self.mask_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+            )
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+
+        self.forward_metadata: ForwardMetadata = None
+
+        self.max_context_len = model_runner.model_config.context_len
+
+        self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+
+    def get_num_kv_splits(
+        self,
+        num_kv_splits: torch.Tensor,
+        seq_lens: torch.Tensor,
+    ):
+        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
+        num_group = num_token // num_seq
+
+        assert (
+            num_group * num_seq == num_token
+        ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
+
+        if self.static_kv_splits or self.device_core_count <= 0:
+            num_kv_splits.fill_(self.max_kv_splits)
+            return
+
+        if num_seq < 256:
+            SCHEDULE_SEQ = 256
+        else:
+            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)
+
+        get_num_kv_splits_triton[(1,)](
+            num_kv_splits,
+            seq_lens,
+            num_seq,
+            num_group,
+            self.num_head,
+            self.num_kv_head,
+            self.max_kv_splits,
+            self.device_core_count,
+            MAX_NUM_SEQ=SCHEDULE_SEQ,
+        )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init auxiliary variables for the wave attention backend."""
+
+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+        spec_info = forward_batch.spec_info
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.empty(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            from sglang.srt.layers.attention.wave_ops.decode_attention import (
+                decode_attention_intermediate_arrays_shapes,
+            )
+
+            attn_logits_shape, attn_logits_max_shape = (
+                decode_attention_intermediate_arrays_shapes(
+                    bs, self.v_head_dim, self.num_head, self.max_kv_splits
+                )
+            )
+            attn_logits = torch.empty(
+                attn_logits_shape,
+                dtype=torch.float32,
+                device=self.device,
+            )
+            attn_lse = torch.empty(
+                attn_logits_max_shape,
+                dtype=torch.float32,
+                device=self.device,
+            )
+            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
+
+            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
+
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
+            max_extend_len = None
+        elif forward_batch.forward_mode.is_target_verify():
+            bs = len(forward_batch.req_pool_indices)
+            qo_indptr = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # Different from the flashinfer kv_indptr and kv_indices construction
+            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                kv_indptr[-1], dtype=torch.int32, device=self.device
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            custom_mask = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (
+                forward_batch.seq_lens + self.num_draft_tokens
+            )
+            mask_indptr = self.mask_indptr
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+            mask_indptr = mask_indptr[: bs + 1]
+            max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
+            attn_logits = None
+            attn_lse = None
+        elif forward_batch.forward_mode.is_draft_extend():
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    self.req_to_token,
+                )
+            )
+            mask_indptr = None
+            # TODO(FIXME): using `max(spec_info.accept_length_cpu)` here triggers
+            # an invalid Eagle tree; an update was probably missed somewhere.
+            max_extend_len = torch.max(spec_info.accept_length).item()
+            num_kv_splits = None
+            attn_logits = None
+            attn_lse = None
+        else:
+            kv_indptr[1 : bs + 1] = torch.cumsum(
+                forward_batch.extend_prefix_lens, dim=0
+            )
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.extend_prefix_lens.sum().item(),
+                dtype=torch.int32,
+                device=self.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.extend_prefix_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+            mask_indptr = None
+            attn_logits = None
+            attn_lse = None
+            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            num_kv_splits = None
+
+        self.forward_metadata = ForwardMetadata(
+            attn_logits,
+            attn_lse,
+            max_extend_len,
+            num_kv_splits,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+            mask_indptr,
+        )
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
+    ):
+        from sglang.srt.layers.attention.wave_ops.decode_attention import (
+            decode_attention_intermediate_arrays_shapes,
+        )
+
+        attn_logits_shape, attn_logits_max_shape = (
+            decode_attention_intermediate_arrays_shapes(
+                max_bs, self.v_head_dim, self.num_head, self.max_kv_splits
+            )
+        )
+        self.cuda_graph_attn_logits = torch.zeros(
+            attn_logits_shape,
+            dtype=torch.float32,
+            device=self.device,
+        )
+        self.cuda_graph_attn_lse = torch.zeros(
+            attn_logits_max_shape,
+            dtype=torch.float32,
+            device=self.device,
+        )
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+        )
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        assert encoder_lens is None, "Not supported"
+
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+
+            attn_logits = self.cuda_graph_attn_logits
+            attn_lse = self.cuda_graph_attn_lse
+            max_extend_len = None
+            num_kv_splits = self.cuda_graph_num_kv_splits
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
+        elif forward_mode.is_target_verify():
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            custom_mask = self.cuda_graph_custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+            max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
+            attn_logits = None
+            attn_lse = None
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
+            )
+
+        self.forward_metadata = ForwardMetadata(
+            attn_logits,
+            attn_lse,
+            max_extend_len,
+            num_kv_splits,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+            mask_indptr,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        # NOTE: encoder_lens is expected to be zeros or None
+        if forward_mode.is_decode_or_idle():
+            # Update kv_indptr, kv_indices
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            num_kv_splits = self.cuda_graph_num_kv_splits
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
+                    seq_lens[:bs],
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+                num_token = bs
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+                num_token = spec_info.kv_indptr.shape[0] - 1
+            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
+        elif forward_mode.is_target_verify():
+            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
+            )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        max_extend_len = self.forward_metadata.max_extend_len
+        computed_max_ext_seq_len = torch.max(forward_batch.extend_seq_lens)
+        if computed_max_ext_seq_len != max_extend_len:
+            assert len(forward_batch.extend_seq_lens) == 1
+            forward_batch.extend_seq_lens[0] = max_extend_len
+            forward_batch.seq_lens = max_extend_len
+
+        self.extend_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            k.contiguous(),
+            v.contiguous(),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            self.forward_metadata.qo_indptr,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.custom_mask,
+            self.forward_metadata.mask_indptr,
+            self.forward_metadata.max_extend_len,
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            is_causal=True,
+            layer_scaling=layer.scaling,
+            logit_cap=layer.logit_cap,
+        )
+        return o
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        # During torch.compile, a bug in rotary_emb causes the output to have a
+        # 3D tensor shape. This reshapes the output back to the expected 2D shape.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        self.decode_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.attn_logits,
+            self.forward_metadata.attn_lse,
+            self.forward_metadata.num_kv_splits,
+            self.max_kv_splits,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
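
For readers skimming the hunk above, two pieces of the new `WaveAttnBackend` are worth a plain-Python gloss.

First, `get_num_kv_splits_triton` decides, per request, how many chunks the split-KV decode kernel divides each request's KV cache into. The sketch below re-derives that heuristic in ordinary PyTorch for the common `num_group == 1` case. It is an illustrative reading of the kernel, not code from the package, and the function name is hypothetical.

```python
import math

import torch


def num_kv_splits_sketch(
    seq_lens: torch.Tensor,  # [num_seq] per-request KV lengths (>= 1)
    num_head: int,
    num_kv_head: int,
    max_kv_splits: int,
    device_core_count: int,
) -> torch.Tensor:
    # Illustrative re-derivation of get_num_kv_splits_triton; not the kernel.
    num_seq = seq_lens.numel()
    max_len, min_len = int(seq_lens.max()), int(seq_lens.min())
    # Treat a nearly uniform batch (min within ~80% of max) as uniform.
    if max_len * 8 < min_len * 10:
        min_len = max_len
    # Rule 1 (skew): longer requests earn proportionally more splits.
    splits_1 = min(math.ceil(max_len / min_len), max_kv_splits)
    chunk_1 = math.ceil(max_len / splits_1)
    # Rule 2 (occupancy): grow the core budget roughly log2 with seqlen.
    ext_cores = int(device_core_count * max(math.log2(max_len / 64.0), 1.0))
    num_kv_group = num_head // num_kv_head
    if num_kv_group == 1:
        token_grid = num_seq * num_head
    else:
        block_h = min(16, num_kv_group)
        token_grid = num_seq * math.ceil(num_head / block_h)
    splits_2 = min(math.ceil(ext_cores / token_grid), max_kv_splits)
    chunk_2 = math.ceil(max_len / splits_2)
    # Each request takes whichever rule demands more splits for its length.
    return torch.maximum(
        (seq_lens + chunk_1 - 1) // chunk_1,
        (seq_lens + chunk_2 - 1) // chunk_2,
    ).to(torch.int32)
```

The effect is that a 256-token request batched with an 8192-token one receives far fewer splits, so short requests largely skip the cost of the split-reduction pass, while small batches on large GPUs still fan out enough work to fill the device.

Second, the `kv_indptr`/`kv_indices` pair that `create_flashinfer_kv_indices_triton` fills throughout the backend is a CSR-style flattening of the paged KV cache: `kv_indptr` is the cumulative sum of sequence lengths, and `kv_indices` concatenates each request's KV slot ids taken from its `req_to_token` row. Below is an illustrative CPU sketch of that layout (again not the package's kernel), assuming `req_to_token[r, p]` holds the KV-pool slot of position `p` of request `r`, which is how the backend indexes it above.

```python
import torch


def build_kv_indices_sketch(
    req_to_token: torch.Tensor,      # [num_reqs, max_context_len] slot table
    req_pool_indices: torch.Tensor,  # [bs] row of req_to_token per request
    seq_lens: torch.Tensor,          # [bs] current KV length per request
):
    # CSR-style flattening: request i's slots live in one contiguous range.
    kv_indptr = torch.zeros(seq_lens.numel() + 1, dtype=torch.int64)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(seq_lens.numel()):
        row, length = int(req_pool_indices[i]), int(seq_lens[i])
        start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
        # Request i's KV slots occupy kv_indices[start:end].
        kv_indices[start:end] = req_to_token[row, :length]
    return kv_indptr, kv_indices
```

The decode path then walks `kv_indices[kv_indptr[i] : kv_indptr[i + 1]]` in `num_kv_splits[i]` chunks, writing partial results into the `attn_logits`/`attn_lse` buffers before a final merge.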