sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. sglang/check_env.py +1 -0
  2. sglang/srt/constrained/outlines_backend.py +4 -1
  3. sglang/srt/function_call_parser.py +96 -69
  4. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  5. sglang/srt/layers/attention/flashinfer_backend.py +34 -41
  6. sglang/srt/layers/attention/triton_backend.py +64 -16
  7. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  9. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -5
  10. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  11. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  13. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  22. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/fp8_kernel.py +43 -10
  24. sglang/srt/lora/backend/__init__.py +25 -5
  25. sglang/srt/lora/backend/base_backend.py +31 -9
  26. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  27. sglang/srt/lora/backend/triton_backend.py +34 -4
  28. sglang/srt/lora/layers.py +293 -0
  29. sglang/srt/lora/lora.py +101 -326
  30. sglang/srt/lora/lora_manager.py +101 -269
  31. sglang/srt/lora/mem_pool.py +174 -0
  32. sglang/srt/lora/triton_ops/__init__.py +7 -1
  33. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  34. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  35. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  36. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  37. sglang/srt/lora/utils.py +141 -0
  38. sglang/srt/model_executor/cuda_graph_runner.py +4 -0
  39. sglang/srt/models/llama.py +8 -3
  40. sglang/srt/speculative/build_eagle_tree.py +482 -102
  41. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  42. sglang/srt/speculative/eagle_utils.py +134 -61
  43. sglang/srt/speculative/eagle_worker.py +1 -0
  44. sglang/version.py +1 -1
  45. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
  46. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +49 -32
  47. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
  48. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
  49. {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/check_env.py CHANGED
@@ -19,6 +19,7 @@ def is_cuda_v2():
  # List of packages to check versions
  PACKAGE_LIST = [
  "sglang",
+ "sgl_kernel",
  "flashinfer",
  "triton",
  "transformers",
sglang/srt/constrained/outlines_backend.py CHANGED
@@ -35,7 +35,10 @@ is_hip_ = is_hip()
  if is_hip_:
  from outlines_core.fsm.json_schema import build_regex_from_schema
  else:
- from outlines.fsm.json_schema import build_regex_from_schema
+ try:
+ from outlines.fsm.json_schema import build_regex_from_schema
+ except ImportError:
+ from outlines_core.fsm.json_schema import build_regex_from_schema


  logger = logging.getLogger(__name__)
sglang/srt/function_call_parser.py CHANGED
@@ -1,4 +1,5 @@
  import json
+ import logging
  import re
  from abc import ABC, abstractmethod
  from json import JSONDecodeError, JSONDecoder
@@ -8,6 +9,8 @@ import partial_json_parser
  from partial_json_parser.core.options import Allow
  from pydantic import BaseModel, Field

+ logger = logging.getLogger(__name__)
+
  TOOLS_TAG_LIST = [
  "<|plugin|>",
  "<function=",
@@ -88,17 +91,43 @@ class BaseFormatDetector:
  self.bot_token = ""
  self.eot_token = ""

- def parse_base_json(self, action: Dict, tools: List[Function]):
- name, parameters = action["name"], json.dumps(
- action.get("parameters", action.get("arguments", {})),
- ensure_ascii=False,
- )
- tool_index = [tool.function.name for tool in tools].index(name)
- tool_call_item = ToolCallItem(
- tool_index=tool_index, name=name, parameters=parameters
- )
- calls = [tool_call_item]
- return calls
+ def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
+ tool_indices = {
+ tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
+ }
+ if not isinstance(action, list):
+ name = action.get("name")
+ if not name or name not in tool_indices:
+ logger.warning(f"Model attempted to call undefined function: {name}")
+ return []
+
+ return [
+ ToolCallItem(
+ tool_index=tool_indices[name],
+ name=name,
+ parameters=json.dumps(
+ action.get("parameters") or action.get("arguments", {}),
+ ensure_ascii=False,
+ ),
+ )
+ ]
+
+ results = []
+ for act in action:
+ name = act.get("name")
+ if name and name in tool_indices:
+ results.append(
+ ToolCallItem(
+ tool_index=tool_indices[name],
+ name=name,
+ parameters=json.dumps(
+ act.get("parameters") or act.get("arguments", {}),
+ ensure_ascii=False,
+ ),
+ )
+ )
+
+ return results

  def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
  """
@@ -112,9 +141,7 @@ class BaseFormatDetector:
  self, new_text: str, tools: List[Function]
  ) -> StreamingParseResult:
  """
- Streaming incremental parsing, referencing the logic of Llama32Detector.
- We partially parse JSON within <tool_call>...</tool_call>, and handle
- incremental argument output.
+ Streaming incremental parsing with tool validation.
  """
  # Append new text to buffer
  self._buffer += new_text
@@ -125,17 +152,19 @@ class BaseFormatDetector:
  new_text = new_text.replace(self.eot_token, "")
  return StreamingParseResult(normal_text=new_text)

- # bit mask flags for partial JSON parsing. If the name hasn't been
- # sent yet, don't allow sending
- # an incomplete string since OpenAI only ever (as far as I have
- # seen) allows sending the entire tool/ function name at once.
+ # Build tool indices if not already built
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = {
+ tool.function.name: i
+ for i, tool in enumerate(tools)
+ if tool.function and tool.function.name
+ }
+
  flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
  try:
  tool_call_arr = []
  is_complete = []
  try:
- # depending on the prompt format the Llama model may or may not
- # prefix the output with the <|python_tag|> token
  start_idx = (
  len(self.bot_token)
  if current_text.startswith(self.bot_token)
@@ -149,8 +178,18 @@ class BaseFormatDetector:
  _is_complete_json(current_text[start_idx : start_idx + end_idx])
  )
  start_idx += end_idx + len("; ")
- # depending on the prompt Llama can use
- # either arguments or parameters
+
+ # Validate tool name if present
+ if "name" in obj and obj["name"] not in self._tool_indices:
+ # Invalid tool name - reset state
+ self._buffer = ""
+ self.current_tool_id = -1
+ self.current_tool_name_sent = False
+ if self.streamed_args_for_tool:
+ self.streamed_args_for_tool.pop()
+ return StreamingParseResult()
+
+ # Handle parameters/arguments consistency
  if "parameters" in obj:
  assert (
  "arguments" not in obj
@@ -159,29 +198,17 @@ class BaseFormatDetector:
  tool_call_arr.append(obj)

  except partial_json_parser.core.exceptions.MalformedJSON:
- # not enough tokens to parse into JSON yet
  return StreamingParseResult()

- # select as the current tool call the one we're on the state at
- current_tool_call: Dict = (
- tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
- )
-
- # case -- if no tokens have been streamed for the tool, e.g.
- # only the array brackets, stream nothing
  if len(tool_call_arr) == 0:
  return StreamingParseResult()

- # case: we are starting a new tool in the array
- # -> array has > 0 length AND length has moved past cursor
- elif (
- len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
- ):
+ current_tool_call: Dict = (
+ tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+ )

- # if we're moving on to a new call, first make sure we
- # haven't missed anything in the previous one that was
- # auto-generated due to JSON completions, but wasn't
- # streamed to the client yet.
+ # Handle new tool in array
+ if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
  if self.current_tool_id >= 0:
  cur_arguments = current_tool_call.get("arguments")
  if cur_arguments:
@@ -190,7 +217,6 @@ class BaseFormatDetector:
  argument_diff = cur_args_json[sent:]

  res = StreamingParseResult(
- normal_text=None,
  calls=[
  ToolCallItem(
  tool_index=self.current_tool_id,
@@ -206,23 +232,20 @@ class BaseFormatDetector:
  res = StreamingParseResult()
  else:
  res = StreamingParseResult()
- # re-set stuff pertaining to progress in the current tool
+
  self.current_tool_id = len(tool_call_arr) - 1
  self.current_tool_name_sent = False
  self.streamed_args_for_tool.append("")
- print("starting on new tool %d", self.current_tool_id)
  return res

- # if the current tool name hasn't been sent, send if available
- # - otherwise send nothing
+ # Handle tool name
  elif not self.current_tool_name_sent:
  function_name = current_tool_call.get("name")
- if function_name:
+ if function_name and function_name in self._tool_indices:
  res = StreamingParseResult(
- normal_text=None,
  calls=[
  ToolCallItem(
- tool_index=self.current_tool_id,
+ tool_index=self._tool_indices[function_name],
  name=function_name,
  parameters="",
  )
@@ -232,8 +255,7 @@ class BaseFormatDetector:
  else:
  res = StreamingParseResult()

- # now we know we're on the same tool call and we're streaming
- # arguments
+ # Handle streaming arguments
  else:
  cur_arguments = current_tool_call.get("arguments")
  res = StreamingParseResult()
@@ -250,13 +272,12 @@ class BaseFormatDetector:
  argument_diff = cur_args_json[sent:]
  self._buffer = ""
  self.prev_tool_call_arr[self.current_tool_id].clear()
- self.current_tool_name_sent: bool = False
+ self.current_tool_name_sent = False
  self.streamed_args_for_tool[self.current_tool_id] = ""

  elif prev_arguments:
  prev_args_json = json.dumps(prev_arguments)
  if cur_args_json != prev_args_json:
-
  prefix = _find_common_prefix(prev_args_json, cur_args_json)
  argument_diff = prefix[sent:]

@@ -279,8 +300,7 @@ class BaseFormatDetector:
  return res

  except Exception as e:
- print(e)
- # Skipping chunk as a result of tool streaming extraction error
+ logger.error(f"Error in parse_streaming_increment: {e}")
  return StreamingParseResult()


@@ -372,31 +392,38 @@ class Llama32Detector(BaseFormatDetector):
  Detector for Llama 3.2 models.
  Assumes function call format:
  <|python_tag|>{"name":"xxx", "arguments":{...}}
- Does not require a closing tag "</python_tag|>",
- relies on json.loads(...) success to determine if JSON is complete.
  """

  def __init__(self):
- """
- Initializes the detector with necessary state variables.
- """
  super().__init__()
  self.bot_token = "<|python_tag|>"

  def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
- """
- One-time parsing: Detects and parses tool calls in the provided text.
-
- :param text: The complete text to parse.
- :param tools: List of available tools.
- :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
- """
-
+ """Parse function calls from text, handling multiple JSON objects."""
  if "<|python_tag|>" not in text:
  return []
- _, action = text.split("<|python_tag|>")
- action = json.loads(action)
- return self.parse_base_json(action, tools)
+
+ _, action_text = text.split("<|python_tag|>")
+
+ # Split by semicolon and process each part
+ json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
+
+ all_actions = []
+ for part in json_parts:
+ try:
+ # Parse each individual JSON object
+ action = json.loads(part)
+ all_actions.append(action)
+ except json.JSONDecodeError as e:
+ logger.warning(f"Failed to parse JSON part: {part}")
+ logger.warning(f"JSON parse error: {str(e)}")
+ continue
+
+ # Only process if we found valid JSON objects
+ if all_actions:
+ return self.parse_base_json(all_actions, tools)
+
+ return []


  class MultiFormatParser:
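Note: with this change Llama32Detector.detect_and_parse accepts several ";"-separated JSON objects after the <|python_tag|> marker, drops fragments that fail to parse, and parse_base_json ignores calls to tools that are not defined. A rough standalone sketch of that parsing step (plain Python; the helper name and the naive semicolon split are illustrative, not sglang API):

    import json

    def split_python_tag_calls(text: str) -> list:
        # Everything after <|python_tag|> may hold several ";"-separated JSON objects.
        _, action_text = text.split("<|python_tag|>", 1)
        actions = []
        for part in (p.strip() for p in action_text.split(";")):
            if not part:
                continue
            try:
                actions.append(json.loads(part))
            except json.JSONDecodeError:
                continue  # malformed fragments are skipped instead of aborting the parse
        return actions

    calls = split_python_tag_calls(
        '<|python_tag|>{"name": "a", "arguments": {}}; {"name": "b", "arguments": {"x": 1}}'
    )
    # -> [{'name': 'a', 'arguments': {}}, {'name': 'b', 'arguments': {'x': 1}}]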
sglang/srt/layers/attention/double_sparsity_backend.py CHANGED
@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
  def __init__(self, model_runner: ModelRunner):
  # Lazy import to avoid the initialization of cuda context
  from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+ extend_attention_fwd,
  flash_decode_attention_fwd,
  flash_decode_sparse_attention_fwd,
  )
- from sglang.srt.layers.attention.triton_ops.extend_attention import (
- extend_attention_fwd,
- )

  super().__init__()

sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend):
  ):
  super().__init__()

+ self.is_multimodal = model_runner.model_config.is_multimodal
+
  # Parse constants
  self.decode_use_tensor_cores = should_use_tensor_core(
  kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend):
  for _ in range(self.num_wrappers)
  ]

- # Create wrappers
- # NOTE: we do not use ragged attention when there are multiple wrappers
- self.prefill_wrapper_ragged = (
- BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
- if self.num_wrappers == 1
- else None
+ self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+ self.workspace_buffer, "NHD"
  )

  # Two wrappers: one for sliding window attention and one for full attention.
@@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend):
  else:
  prefix_lens = forward_batch.extend_prefix_lens

- # Some heuristics to check whether to use ragged forward
- if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
- use_ragged = True
- extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
- else:
+ if self.is_multimodal:
  use_ragged = False
  extend_no_prefix = False
+ else:
+ use_ragged = True
+ extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

  self.indices_updater_prefill.update(
  forward_batch.req_pool_indices,
@@ -409,9 +406,9 @@ class FlashInferAttnBackend(AttentionBackend):
  )
  else:
  o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
- k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
- v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
+ k.view(-1, layer.tp_k_head_num, layer.head_dim),
+ v.view(-1, layer.tp_v_head_num, layer.head_dim),
  causal=True,
  sm_scale=layer.scaling,
  logits_soft_cap=logits_soft_cap,
@@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode:
  kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
  bs = kv_indptr.shape[0] - 1

- wrapper.end_forward()
  wrapper.begin_forward(
  kv_indptr,
  kv_indices,
@@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode:
  1,
  data_type=self.data_type,
  q_data_type=self.q_data_type,
+ non_blocking=True,
  )


@@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill:

  # extend part
  if use_ragged:
- wrapper_ragged.end_forward()
  wrapper_ragged.begin_forward(
  qo_indptr,
  qo_indptr,
@@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill:
  )

  # cached part
- wrapper_paged.end_forward()
  wrapper_paged.begin_forward(
  qo_indptr,
  kv_indptr,
@@ -883,6 +878,7 @@ class FlashInferIndicesUpdaterPrefill:
  1,
  q_data_type=self.q_data_type,
  custom_mask=custom_mask,
+ non_blocking=True,
  )


@@ -924,38 +920,50 @@ class FlashInferMultiStepDraftBackend:
  self.max_context_len = self.attn_backends[0].max_context_len
  # Cached variables for generate_draft_decode_kv_indices
  self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
- self.kv_indptr_stride = self.kv_indptr.shape[1]

- def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+ def common_template(
+ self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+ ):
  num_seqs = forward_batch.batch_size
  bs = self.topk * num_seqs
  seq_lens_sum = forward_batch.seq_lens_sum
+
  self.generate_draft_decode_kv_indices[
  (self.speculative_num_steps, num_seqs, self.topk)
  ](
  forward_batch.req_pool_indices,
  forward_batch.req_to_token_pool.req_to_token,
  forward_batch.seq_lens,
- self.cuda_graph_kv_indices,
+ kv_indices_buffer,
  self.kv_indptr,
  forward_batch.positions,
  num_seqs,
  self.topk,
  self.pool_len,
- self.kv_indptr_stride,
+ kv_indices_buffer.shape[1],
  self.kv_indptr.shape[1],
  triton.next_power_of_2(num_seqs),
  triton.next_power_of_2(self.speculative_num_steps),
  triton.next_power_of_2(bs),
  )
+
  for i in range(self.speculative_num_steps):
  forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
- forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+ forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
  : seq_lens_sum * self.topk + bs * (i + 1)
  ]
  call_fn(i, forward_batch)

  def init_forward_metadata(self, forward_batch: ForwardBatch):
+ kv_indices = torch.zeros(
+ (
+ self.speculative_num_steps,
+ forward_batch.batch_size * self.topk * self.max_context_len,
+ ),
+ dtype=torch.int32,
+ device="cuda",
+ )
+
  def call_fn(i, forward_batch):
  forward_batch.spec_info.kv_indptr = (
  forward_batch.spec_info.kv_indptr.clone()
@@ -965,7 +973,7 @@ class FlashInferMultiStepDraftBackend:
  )
  self.attn_backends[i].init_forward_metadata(forward_batch)

- self.common_template(forward_batch, call_fn)
+ self.common_template(forward_batch, kv_indices, call_fn)

  def init_cuda_graph_state(self, max_bs: int):
  self.cuda_graph_kv_indices = torch.zeros(
@@ -973,7 +981,6 @@ class FlashInferMultiStepDraftBackend:
  dtype=torch.int32,
  device="cuda",
  )
- self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
  for i in range(self.speculative_num_steps):
  self.attn_backends[i].init_cuda_graph_state(
  max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -995,7 +1002,7 @@ class FlashInferMultiStepDraftBackend:
  ][0]
  decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)

- self.common_template(forward_batch, call_fn)
+ self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

  def init_forward_metadata_replay_cuda_graph(self, forward_batch):
  def call_fn(i, forward_batch):
@@ -1009,7 +1016,7 @@ class FlashInferMultiStepDraftBackend:
  spec_info=forward_batch.spec_info,
  )

- self.common_template(forward_batch, call_fn)
+ self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


  @triton.jit
@@ -1070,21 +1077,6 @@ def should_use_tensor_core(
  if env_override is not None:
  return env_override.lower() == "true"

- # Try to use _grouped_size_compiled_for_decode_kernels if available
- # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
- try:
- from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
- if not _grouped_size_compiled_for_decode_kernels(
- num_attention_heads,
- num_kv_heads,
- ):
- return True
- else:
- return False
- except (ImportError, AttributeError):
- pass
-
  # Calculate GQA group size
  gqa_group_size = num_attention_heads // num_kv_heads

@@ -1114,6 +1106,7 @@ def fast_decode_plan(
  sm_scale: Optional[float] = None,
  rope_scale: Optional[float] = None,
  rope_theta: Optional[float] = None,
+ **kwargs,
  ) -> None:
  """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
  batch_size = len(last_page_len)
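Note: with the flashinfer <= 0.1.6 fallback removed, should_use_tensor_core decides from the environment override and the GQA group size alone. A minimal illustration of that kind of check (the threshold below is a hypothetical placeholder, not the value sglang actually uses):

    def use_tensor_cores(num_attention_heads: int, num_kv_heads: int, threshold: int = 4) -> bool:
        # Grouped-query attention: how many query heads share one KV head.
        gqa_group_size = num_attention_heads // num_kv_heads
        return gqa_group_size >= threshold

    print(use_tensor_cores(32, 8))   # group size 4 -> True under this assumed threshold
    print(use_tensor_cores(32, 32))  # MHA, group size 1 -> False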
sglang/srt/layers/attention/triton_backend.py CHANGED
@@ -37,6 +37,9 @@ class TritonAttnBackend(AttentionBackend):
  (max_bs + 1,), dtype=torch.int32, device=model_runner.device
  )
  self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.qo_indptr = torch.zeros(
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+ )

  self.num_head = (
  model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -54,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Init auxiliary variables for triton attention backend."""

+ bs = forward_batch.batch_size
+ kv_indptr = self.kv_indptr
+
  if forward_batch.forward_mode.is_decode():
  attn_logits = torch.empty(
  (
@@ -68,31 +74,62 @@ class TritonAttnBackend(AttentionBackend):

  max_extend_len = None

- kv_indptr = self.kv_indptr
- bs = len(forward_batch.req_pool_indices)
  kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
  kv_indices = torch.empty(
- forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+ forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
  )
  create_flashinfer_kv_indices_triton[(bs,)](
- forward_batch.req_to_token_pool.req_to_token,
+ self.req_to_token,
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
  kv_indptr,
  None,
  kv_indices,
- forward_batch.req_to_token_pool.req_to_token.stride(0),
+ self.req_to_token.stride(0),
  )

+ qo_indptr = None
+ custom_mask = None
+ mask_offsets = None
  else:
+ kv_indptr[1 : bs + 1] = torch.cumsum(
+ forward_batch.extend_prefix_lens, dim=0
+ )
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = torch.empty(
+ forward_batch.extend_prefix_lens.sum().item(),
+ dtype=torch.int32,
+ device=self.device,
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ forward_batch.req_pool_indices,
+ forward_batch.extend_prefix_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.stride(0),
+ )
+
+ qo_indptr = self.qo_indptr
+ qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+ qo_indptr = qo_indptr[: bs + 1]
+ custom_mask = None
+ mask_offsets = None
+
  attn_logits = None
  max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

- kv_indptr = None
- kv_indices = None
-
- self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
+ self.forward_metadata = (
+ attn_logits,
+ max_extend_len,
+ kv_indptr,
+ kv_indices,
+ qo_indptr,
+ custom_mask,
+ mask_offsets,
+ )

  def init_cuda_graph_state(self, max_bs: int):
  self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
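Note: both kv_indptr and the new qo_indptr follow the usual CSR-style indptr layout: entry i marks where request i's tokens start in the flattened index array, built with a cumulative sum over per-request lengths. A small torch sketch of that construction (the tensor values are made up):

    import torch

    extend_seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # tokens per request
    qo_indptr = torch.zeros(len(extend_seq_lens) + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)
    print(qo_indptr)  # tensor([0, 3, 8, 10], dtype=torch.int32)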
@@ -144,6 +181,9 @@ class TritonAttnBackend(AttentionBackend):
  None,
  kv_indptr,
  kv_indices,
+ None,
+ None,
+ None,
  )

  def init_forward_metadata_replay_cuda_graph(
@@ -197,7 +237,15 @@ class TritonAttnBackend(AttentionBackend):
  layer, forward_batch.out_cache_loc, k, v
  )

- _, max_extend_len, _, _ = self.forward_metadata
+ (
+ _,
+ max_extend_len,
+ kv_indptr,
+ kv_indices,
+ qo_indptr,
+ custom_mask,
+ mask_offsets,
+ ) = self.forward_metadata
  self.extend_attention_fwd(
  q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
  k.contiguous(),
@@ -205,11 +253,11 @@ class TritonAttnBackend(AttentionBackend):
  o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
  forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
  forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
- forward_batch.req_to_token_pool.req_to_token,
- forward_batch.req_pool_indices,
- forward_batch.seq_lens,
- forward_batch.extend_seq_lens,
- forward_batch.extend_start_loc,
+ qo_indptr,
+ kv_indptr,
+ kv_indices,
+ custom_mask,
+ mask_offsets,
  max_extend_len,
  layer.scaling,
  layer.logit_cap,
@@ -235,7 +283,7 @@ class TritonAttnBackend(AttentionBackend):
  else:
  o = torch.empty_like(q)

- attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
+ attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

  if save_kv_cache:
  forward_batch.token_to_kv_pool.set_kv_buffer(