sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_backend.py
@@ -1,14 +1,17 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import triton
+import triton.language as tl
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import get_bool_env_var, get_device_core_count
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -16,6 +19,71 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    num_seq,
+    num_group,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_NUM_SEQ: tl.constexpr,
+):
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        token_grid = num_seq * num_group * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+@dataclass
+class ForwardMetadata:
+    attn_logits: torch.Tensor
+    attn_lse: torch.Tensor
+    max_extend_len: int
+    num_kv_splits: torch.Tensor
+    kv_indptr: torch.Tensor
+    kv_indices: torch.Tensor
+    qo_indptr: torch.Tensor
+    custom_mask: torch.Tensor
+    mask_indptr: torch.Tensor
+
+
 class TritonAttnBackend(AttentionBackend):
     def __init__(
         self,
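The new `get_num_kv_splits_triton` kernel chooses a per-sequence KV split count as the maximum of two candidates: one bounded by the spread of sequence lengths in the batch (`kv_chunk_size_1`), and one that scales `device_core_count` by `log2(max_seq_len / 64)` so longer sequences expose more parallelism (`kv_chunk_size_2`). A rough host-side restatement in plain PyTorch follows; it is illustrative only (the function name is ours, and `token_grid` is taken as an input rather than derived from the head counts as the kernel does):

import math

import torch


def num_kv_splits_reference(
    seq_lens: torch.Tensor,   # int tensor, one entry per sequence, all > 0
    max_kv_splits: int,
    token_grid: int,          # number of scheduled blocks (see kernel)
    device_core_count: int,
) -> torch.Tensor:
    # Candidate 1: if lengths are nearly uniform (within ~25%), one chunk
    # spans the whole batch max; otherwise split by the max/min spread.
    max_len, min_len = int(seq_lens.max()), int(seq_lens.min())
    if max_len * 8 < min_len * 10:
        min_len = max_len
    splits_1 = min(math.ceil(max_len / min_len), max_kv_splits)
    chunk_1 = math.ceil(max_len / splits_1)

    # Candidate 2: grow the split budget with log2(max_len / 64) times the
    # core count, so longer sequences can fill the whole device.
    ext_cores = int(device_core_count * max(math.log2(max_len / 64.0), 1.0))
    splits_2 = min(math.ceil(ext_cores / token_grid), max_kv_splits)
    chunk_2 = math.ceil(max_len / splits_2)

    # Per-sequence result: whichever chunking yields more splits wins.
    s1 = (seq_lens + chunk_1 - 1) // chunk_1
    s2 = (seq_lens + chunk_2 - 1) // chunk_2
    return torch.maximum(s1, s2)

Both candidates are clamped to `max_kv_splits`, which is why the metadata buffers in the hunks below can stay allocated at that fixed extent while the live split count varies per token.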
@@ -63,15 +131,55 @@ class TritonAttnBackend(AttentionBackend):
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
 
-        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
-        self.forward_metadata = None
+        self.forward_metadata: ForwardMetadata = None
 
         self.max_context_len = model_runner.model_config.context_len
 
         self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+
+    def get_num_kv_splits(
+        self,
+        num_kv_splits: torch.Tensor,
+        seq_lens: torch.Tensor,
+    ):
+        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
+        num_group = num_token // num_seq
+
+        assert (
+            num_group * num_seq == num_token
+        ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
+
+        if self.static_kv_splits or self.device_core_count <= 0:
+            num_kv_splits.fill_(self.max_kv_splits)
+            return
+
+        if num_seq < 256:
+            SCHEDULE_SEQ = 256
+        else:
+            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)
+
+        get_num_kv_splits_triton[(1,)](
+            num_kv_splits,
+            seq_lens,
+            num_seq,
+            num_group,
+            self.num_head,
+            self.num_kv_head,
+            self.max_kv_splits,
+            self.device_core_count,
+            MAX_NUM_SEQ=SCHEDULE_SEQ,
+        )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
@@ -84,7 +192,7 @@ class TritonAttnBackend(AttentionBackend):
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
-                kv_indices = torch.zeros(
+                kv_indices = torch.empty(
                     forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                 )
                 create_flashinfer_kv_indices_triton[(bs,)](
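A recurring change in this file is `torch.zeros` to `torch.empty` for `kv_indices`: the buffer is fully populated by `create_flashinfer_kv_indices_triton` immediately after allocation, so the zero-fill was an extra memset kernel with no effect. The pattern in miniature (the `arange` copy stands in for the Triton fill):

import torch

n = 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
# empty() only reserves memory; skipping the memset is safe only because
# every element is written before it is ever read.
kv_indices = torch.empty(n, dtype=torch.int32, device=device)
kv_indices.copy_(torch.arange(n, dtype=torch.int32, device=device))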
@@ -100,16 +208,19 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1
 
-            attn_logits = torch.zeros(
-                (
-                    bs,
-                    self.num_head,
-                    self.num_kv_splits,
-                    self.v_head_dim + 1,
-                ),
+            attn_logits = torch.empty(
+                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            attn_lse = torch.empty(
+                (bs, self.num_head, self.max_kv_splits),
                 dtype=torch.float32,
                 device=self.device,
             )
+            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
+
+            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
 
             qo_indptr = None
             custom_mask = None
@@ -127,7 +238,7 @@ class TritonAttnBackend(AttentionBackend):
             # Different with flashinfer kv_indptr and kv_indices construction
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.zeros(
+            kv_indices = torch.empty(
                 kv_indptr[-1], dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
@@ -148,7 +259,9 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
             mask_indptr = mask_indptr[: bs + 1]
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -159,14 +272,19 @@ class TritonAttnBackend(AttentionBackend):
                 )
             )
             mask_indptr = None
+            # TODO(FIXME): This will trigger an invalid Eagle tree when using
+            # `max(spec_info.accept_length_cpu)`.
+            # It might have been forgotten to update somewhere.
             max_extend_len = torch.max(spec_info.accept_length).item()
+            num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.zeros(
+            kv_indices = torch.empty(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
@@ -187,11 +305,15 @@ class TritonAttnBackend(AttentionBackend):
             custom_mask = None
             mask_indptr = None
             attn_logits = None
+            attn_lse = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            num_kv_splits = None
 
-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -203,10 +325,18 @@ class TritonAttnBackend(AttentionBackend):
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
         self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
+            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
             dtype=torch.float32,
             device=self.device,
         )
+        self.cuda_graph_attn_lse = torch.zeros(
+            (max_bs, self.num_head, self.max_kv_splits),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+        )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_bs * self.max_context_len),
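CUDA graphs replay with fixed shapes and captured pointers, so these buffers are allocated once at worst case (`max_bs` rows, `self.max_kv_splits` splits), and `cuda_graph_num_kv_splits` is pre-filled with the safe maximum; the replay path further below overwrites a prefix of the same tensor in place. A toy sketch of that pattern, with illustrative sizes and names:

import torch

MAX_BS, MAX_SPLITS = 160, 16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Allocated once, before capture, at worst-case size; the default of
# MAX_SPLITS covers any entry a later per-step update does not touch.
num_kv_splits = torch.full((MAX_BS,), MAX_SPLITS, dtype=torch.int32, device=device)


def replay_init(per_seq_splits: torch.Tensor) -> None:
    # An in-place prefix write keeps the storage (and thus the pointers
    # captured in the graph) stable across replays.
    num_kv_splits[: per_seq_splits.shape[0]].copy_(per_seq_splits)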
@@ -254,7 +384,9 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
             attn_logits = self.cuda_graph_attn_logits
+            attn_lse = self.cuda_graph_attn_lse
             max_extend_len = None
+            num_kv_splits = self.cuda_graph_num_kv_splits
             qo_indptr = None
             custom_mask = None
             mask_indptr = None
@@ -285,15 +417,19 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         else:
             raise ValueError(
                 f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
             )
 
-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -317,6 +453,7 @@ class TritonAttnBackend(AttentionBackend):
             # Update kv_indptr, kv_indices
             kv_indptr = self.kv_indptr
             kv_indices = self.cuda_graph_kv_indices
+            num_kv_splits = self.cuda_graph_num_kv_splits
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
@@ -329,9 +466,12 @@ class TritonAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
+                num_token = bs
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+                num_token = spec_info.kv_indptr.shape[0] - 1
+            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)
@@ -388,16 +528,6 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        (
-            _,
-            max_extend_len,
-            kv_indptr,
-            kv_indices,
-            qo_indptr,
-            custom_mask,
-            mask_indptr,
-        ) = self.forward_metadata
-
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -405,12 +535,12 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            custom_mask,
-            mask_indptr,
-            max_extend_len,
+            self.forward_metadata.qo_indptr,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.custom_mask,
+            self.forward_metadata.mask_indptr,
+            self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
         )
@@ -435,8 +565,6 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
-
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v
@@ -447,10 +575,12 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            kv_indptr,
-            kv_indices,
-            attn_logits,
-            self.num_kv_splits,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.attn_logits,
+            self.forward_metadata.attn_lse,
+            self.forward_metadata.num_kv_splits,
+            self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )
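The decode call now passes per-split partial outputs (`attn_logits`, now `v_head_dim` wide rather than `v_head_dim + 1`, since the running normalizer moved into the separate `attn_lse` tensor) plus the per-token split counts and the `max_kv_splits` bound. The final reduction across splits is the usual flash-decoding merge; a hedged single-head sketch, assuming each partial is already normalized by its local softmax:

import torch


def combine_splits(
    partials: torch.Tensor,  # [max_splits, v_head_dim]; locally softmax-weighted V sums
    lse: torch.Tensor,       # [max_splits]; log-sum-exp of scores within each split
    num_splits: int,
) -> torch.Tensor:
    p, z = partials[:num_splits], lse[:num_splits]
    g = torch.logsumexp(z, dim=0)        # global normalizer over live splits
    w = torch.exp(z - g).unsqueeze(-1)   # per-split reweighting, sums to 1
    return (w * p).sum(dim=0)            # [v_head_dim]

Because the weights `exp(lse_i - g)` sum to one over the live splits, the merge reproduces the exact global softmax no matter how many splits each token was assigned.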
@@ -493,6 +623,9 @@ class TritonMultiStepDraftBackend:
             )
         )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -531,7 +664,7 @@ class TritonMultiStepDraftBackend:
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.zeros(
+        kv_indices = torch.empty(
             (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,