sglang 0.4.2.post2__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/check_env.py +1 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/function_call_parser.py +96 -69
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
- sglang/srt/layers/attention/flashinfer_backend.py +34 -41
- sglang/srt/layers/attention/triton_backend.py +64 -16
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -5
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +43 -10
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -0
- sglang/srt/models/llama.py +8 -3
- sglang/srt/speculative/build_eagle_tree.py +482 -102
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +134 -61
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +49 -32
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post2.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/check_env.py
CHANGED

sglang/srt/constrained/outlines_backend.py
CHANGED
@@ -35,7 +35,10 @@ is_hip_ = is_hip()
 if is_hip_:
     from outlines_core.fsm.json_schema import build_regex_from_schema
 else:
-    from outlines.fsm.json_schema import build_regex_from_schema
+    try:
+        from outlines.fsm.json_schema import build_regex_from_schema
+    except ImportError:
+        from outlines_core.fsm.json_schema import build_regex_from_schema


 logger = logging.getLogger(__name__)
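The outlines_backend change above prefers the `outlines` package and falls back to `outlines_core` when it is missing. A minimal standalone sketch of the same guarded-import pattern (it assumes at least one of the two packages is installed):

try:
    from outlines.fsm.json_schema import build_regex_from_schema
except ImportError:
    # Some environments only ship the outlines_core split-off package.
    from outlines_core.fsm.json_schema import build_regex_from_schema

schema = '{"type": "object", "properties": {"city": {"type": "string"}}}'
print(build_regex_from_schema(schema))  # regex enforcing the JSON schema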
sglang/srt/function_call_parser.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import logging
 import re
 from abc import ABC, abstractmethod
 from json import JSONDecodeError, JSONDecoder
@@ -8,6 +9,8 @@ import partial_json_parser
 from partial_json_parser.core.options import Allow
 from pydantic import BaseModel, Field

+logger = logging.getLogger(__name__)
+
 TOOLS_TAG_LIST = [
     "<|plugin|>",
     "<function=",
@@ -88,17 +91,43 @@ class BaseFormatDetector:
         self.bot_token = ""
         self.eot_token = ""

-    def parse_base_json(self, action:
-
-
-
-    )
-
-
-
-
-
-
+    def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
+        tool_indices = {
+            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
+        }
+        if not isinstance(action, list):
+            name = action.get("name")
+            if not name or name not in tool_indices:
+                logger.warning(f"Model attempted to call undefined function: {name}")
+                return []
+
+            return [
+                ToolCallItem(
+                    tool_index=tool_indices[name],
+                    name=name,
+                    parameters=json.dumps(
+                        action.get("parameters") or action.get("arguments", {}),
+                        ensure_ascii=False,
+                    ),
+                )
+            ]
+
+        results = []
+        for act in action:
+            name = act.get("name")
+            if name and name in tool_indices:
+                results.append(
+                    ToolCallItem(
+                        tool_index=tool_indices[name],
+                        name=name,
+                        parameters=json.dumps(
+                            act.get("parameters") or act.get("arguments", {}),
+                            ensure_ascii=False,
+                        ),
+                    )
+                )
+
+        return results

     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
         """
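The rewritten parse_base_json maps declared tool names to indices and drops calls to undefined functions. A simplified sketch of that lookup, using plain dicts instead of the real Function/ToolCallItem types (the tool names below are made up):

import json
from typing import Any, Dict, List


def resolve_tool_calls(action: Any, tool_names: List[str]) -> List[Dict]:
    # Same idea as parse_base_json: only names that were declared survive.
    tool_indices = {name: i for i, name in enumerate(tool_names)}
    actions = action if isinstance(action, list) else [action]
    results = []
    for act in actions:
        name = act.get("name")
        if not name or name not in tool_indices:
            continue  # the real code logs a warning for the non-list case
        results.append(
            {
                "tool_index": tool_indices[name],
                "name": name,
                "parameters": json.dumps(
                    act.get("parameters") or act.get("arguments", {}),
                    ensure_ascii=False,
                ),
            }
        )
    return results


calls = [{"name": "get_weather", "arguments": {"city": "Paris"}}, {"name": "bogus"}]
print(resolve_tool_calls(calls, ["get_weather", "get_time"]))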
@@ -112,9 +141,7 @@ class BaseFormatDetector:
         self, new_text: str, tools: List[Function]
     ) -> StreamingParseResult:
         """
-        Streaming incremental parsing
-        We partially parse JSON within <tool_call>...</tool_call>, and handle
-        incremental argument output.
+        Streaming incremental parsing with tool validation.
         """
         # Append new text to buffer
         self._buffer += new_text
@@ -125,17 +152,19 @@ class BaseFormatDetector:
             new_text = new_text.replace(self.eot_token, "")
             return StreamingParseResult(normal_text=new_text)

-        #
-
-
-
+        # Build tool indices if not already built
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
         flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
         try:
             tool_call_arr = []
             is_complete = []
             try:
-                # depending on the prompt format the Llama model may or may not
-                # prefix the output with the <|python_tag|> token
                 start_idx = (
                     len(self.bot_token)
                     if current_text.startswith(self.bot_token)
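The streaming parser now builds self._tool_indices once per detector instance, guarded by hasattr. The lazy-initialization idiom on its own (Detector here is a toy class, not the real BaseFormatDetector):

class Detector:
    def tool_indices(self, tool_names):
        # Built on first call, reused afterwards.
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {name: i for i, name in enumerate(tool_names)}
        return self._tool_indices


d = Detector()
print(d.tool_indices(["search", "calculator"]))  # {'search': 0, 'calculator': 1}
print(d.tool_indices(["something_else"]))        # cached mapping, unchanged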
@@ -149,8 +178,18 @@ class BaseFormatDetector:
                         _is_complete_json(current_text[start_idx : start_idx + end_idx])
                     )
                     start_idx += end_idx + len("; ")
-
-                    #
+
+                    # Validate tool name if present
+                    if "name" in obj and obj["name"] not in self._tool_indices:
+                        # Invalid tool name - reset state
+                        self._buffer = ""
+                        self.current_tool_id = -1
+                        self.current_tool_name_sent = False
+                        if self.streamed_args_for_tool:
+                            self.streamed_args_for_tool.pop()
+                        return StreamingParseResult()
+
+                    # Handle parameters/arguments consistency
                     if "parameters" in obj:
                         assert (
                             "arguments" not in obj
@@ -159,29 +198,17 @@ class BaseFormatDetector:
                     tool_call_arr.append(obj)

             except partial_json_parser.core.exceptions.MalformedJSON:
-                # not enough tokens to parse into JSON yet
                 return StreamingParseResult()

-            # select as the current tool call the one we're on the state at
-            current_tool_call: Dict = (
-                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
-            )
-
-            # case -- if no tokens have been streamed for the tool, e.g.
-            # only the array brackets, stream nothing
             if len(tool_call_arr) == 0:
                 return StreamingParseResult()

-
-
-
-                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
-            ):
+            current_tool_call: Dict = (
+                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+            )

-
-
-            # auto-generated due to JSON completions, but wasn't
-            # streamed to the client yet.
+            # Handle new tool in array
+            if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
                 if self.current_tool_id >= 0:
                     cur_arguments = current_tool_call.get("arguments")
                     if cur_arguments:
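The branches above stream only the not-yet-sent part of the serialized arguments, by diffing the previous and current JSON strings. A small sketch of that delta computation; _find_common_prefix is a local stand-in for the helper used in the file, and the argument values are invented:

import json


def _find_common_prefix(a: str, b: str) -> str:
    n = 0
    while n < min(len(a), len(b)) and a[n] == b[n]:
        n += 1
    return a[:n]


prev_json = json.dumps({"city": "Par"})
cur_json = json.dumps({"city": "Paris", "units": "C"})
sent = 9  # characters of the argument JSON already emitted to the client
prefix = _find_common_prefix(prev_json, cur_json)
print(repr(prefix[sent:]))  # only the newly confirmed characters are streamed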
@@ -190,7 +217,6 @@ class BaseFormatDetector:
                         argument_diff = cur_args_json[sent:]

                         res = StreamingParseResult(
-                            normal_text=None,
                             calls=[
                                 ToolCallItem(
                                     tool_index=self.current_tool_id,
@@ -206,23 +232,20 @@ class BaseFormatDetector:
                         res = StreamingParseResult()
                 else:
                     res = StreamingParseResult()
-
+
                 self.current_tool_id = len(tool_call_arr) - 1
                 self.current_tool_name_sent = False
                 self.streamed_args_for_tool.append("")
-                print("starting on new tool %d", self.current_tool_id)
                 return res

-            #
-            # - otherwise send nothing
+            # Handle tool name
             elif not self.current_tool_name_sent:
                 function_name = current_tool_call.get("name")
-                if function_name:
+                if function_name and function_name in self._tool_indices:
                     res = StreamingParseResult(
-                        normal_text=None,
                         calls=[
                             ToolCallItem(
-                                tool_index=self.
+                                tool_index=self._tool_indices[function_name],
                                 name=function_name,
                                 parameters="",
                             )
@@ -232,8 +255,7 @@ class BaseFormatDetector:
                 else:
                     res = StreamingParseResult()

-            #
-            # arguments
+            # Handle streaming arguments
             else:
                 cur_arguments = current_tool_call.get("arguments")
                 res = StreamingParseResult()
@@ -250,13 +272,12 @@ class BaseFormatDetector:
                         argument_diff = cur_args_json[sent:]
                         self._buffer = ""
                         self.prev_tool_call_arr[self.current_tool_id].clear()
-                        self.current_tool_name_sent
+                        self.current_tool_name_sent = False
                         self.streamed_args_for_tool[self.current_tool_id] = ""

                     elif prev_arguments:
                         prev_args_json = json.dumps(prev_arguments)
                         if cur_args_json != prev_args_json:
-
                             prefix = _find_common_prefix(prev_args_json, cur_args_json)
                             argument_diff = prefix[sent:]

@@ -279,8 +300,7 @@ class BaseFormatDetector:
             return res

         except Exception as e:
-
-            # Skipping chunk as a result of tool streaming extraction error
+            logger.error(f"Error in parse_streaming_increment: {e}")
             return StreamingParseResult()


@@ -372,31 +392,38 @@ class Llama32Detector(BaseFormatDetector):
     Detector for Llama 3.2 models.
     Assumes function call format:
       <|python_tag|>{"name":"xxx", "arguments":{...}}
-    Does not require a closing tag "</python_tag|>",
-    relies on json.loads(...) success to determine if JSON is complete.
     """

     def __init__(self):
-        """
-        Initializes the detector with necessary state variables.
-        """
         super().__init__()
         self.bot_token = "<|python_tag|>"

     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
-        """
-        One-time parsing: Detects and parses tool calls in the provided text.
-
-        :param text: The complete text to parse.
-        :param tools: List of available tools.
-        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
-        """
-
+        """Parse function calls from text, handling multiple JSON objects."""
         if "<|python_tag|>" not in text:
             return []
-
-
-
+
+        _, action_text = text.split("<|python_tag|>")
+
+        # Split by semicolon and process each part
+        json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
+
+        all_actions = []
+        for part in json_parts:
+            try:
+                # Parse each individual JSON object
+                action = json.loads(part)
+                all_actions.append(action)
+            except json.JSONDecodeError as e:
+                logger.warning(f"Failed to parse JSON part: {part}")
+                logger.warning(f"JSON parse error: {str(e)}")
+                continue
+
+        # Only process if we found valid JSON objects
+        if all_actions:
+            return self.parse_base_json(all_actions, tools)
+
+        return []


 class MultiFormatParser:
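The new Llama32Detector.detect_and_parse accepts several semicolon-separated JSON objects after the <|python_tag|> marker. The splitting and parsing step in isolation (the sample text is made up):

import json

text = (
    '<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}}; '
    '{"name": "get_time", "arguments": {}}'
)

actions = []
if "<|python_tag|>" in text:
    _, action_text = text.split("<|python_tag|>", 1)
    for part in (p.strip() for p in action_text.split(";") if p.strip()):
        try:
            actions.append(json.loads(part))
        except json.JSONDecodeError:
            continue  # malformed fragments are skipped, as in the detector

print([a["name"] for a in actions])  # ['get_weather', 'get_time']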
sglang/srt/layers/attention/double_sparsity_backend.py
CHANGED
@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+            extend_attention_fwd,
             flash_decode_attention_fwd,
             flash_decode_sparse_attention_fwd,
         )
-        from sglang.srt.layers.attention.triton_ops.extend_attention import (
-            extend_attention_fwd,
-        )

         super().__init__()

sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()

+        self.is_multimodal = model_runner.model_config.is_multimodal
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]

-
-
-        self.prefill_wrapper_ragged = (
-            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
-            if self.num_wrappers == 1
-            else None
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
         )

         # Two wrappers: one for sliding window attention and one for full attention.
@@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens

-
-        if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
-            use_ragged = True
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-        else:
+        if self.is_multimodal:
             use_ragged = False
             extend_no_prefix = False
+        else:
+            use_ragged = True
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

         self.indices_updater_prefill.update(
             forward_batch.req_pool_indices,
@@ -409,9 +406,9 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.
-                k.
-                v.
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
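The ragged prefill call now reshapes the flat projections into (tokens, heads, head_dim) views before handing them to the wrapper. The reshape on its own, with invented head counts:

import torch

num_tokens, num_q_heads, num_kv_heads, head_dim = 8, 4, 2, 16
q = torch.randn(num_tokens, num_q_heads * head_dim)
k = torch.randn(num_tokens, num_kv_heads * head_dim)

# .view only reinterprets the contiguous layout; no data is copied.
q3 = q.view(-1, num_q_heads, head_dim)
k3 = k.view(-1, num_kv_heads, head_dim)
print(q3.shape, k3.shape)  # torch.Size([8, 4, 16]) torch.Size([8, 2, 16])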
@@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

-        wrapper.end_forward()
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode:
             1,
             data_type=self.data_type,
             q_data_type=self.q_data_type,
+            non_blocking=True,
         )


@@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill:

         # extend part
         if use_ragged:
-            wrapper_ragged.end_forward()
             wrapper_ragged.begin_forward(
                 qo_indptr,
                 qo_indptr,
@@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill:
             )

         # cached part
-        wrapper_paged.end_forward()
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -883,6 +878,7 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             custom_mask=custom_mask,
+            non_blocking=True,
         )


@@ -924,38 +920,50 @@ class FlashInferMultiStepDraftBackend:
         self.max_context_len = self.attn_backends[0].max_context_len
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
-        self.kv_indptr_stride = self.kv_indptr.shape[1]

-    def common_template(
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
+
         self.generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.seq_lens,
-
+            kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
             num_seqs,
             self.topk,
             self.pool_len,
-
+            kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             triton.next_power_of_2(num_seqs),
             triton.next_power_of_2(self.speculative_num_steps),
             triton.next_power_of_2(bs),
         )
+
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
-            forward_batch.spec_info.kv_indices =
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
             call_fn(i, forward_batch)

     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         def call_fn(i, forward_batch):
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
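common_template now receives a preallocated kv_indices buffer and gives each speculative step a slice of its own row instead of allocating inside the loop. A rough sketch of that buffer-and-slice pattern with made-up sizes (the real buffer lives on the GPU):

import torch

speculative_num_steps, batch_size, topk, max_context_len = 3, 2, 4, 64

# One int32 row per draft step, sized for the worst case.
kv_indices_buffer = torch.zeros(
    (speculative_num_steps, batch_size * topk * max_context_len),
    dtype=torch.int32,
)

seq_lens_sum, bs = 20, batch_size * topk
for i in range(speculative_num_steps):
    # Each step works on a view into its own row; nothing is reallocated.
    step_view = kv_indices_buffer[i][: seq_lens_sum * topk + bs * (i + 1)]
    print(i, step_view.shape)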
@@ -965,7 +973,7 @@ class FlashInferMultiStepDraftBackend:
             )
             self.attn_backends[i].init_forward_metadata(forward_batch)

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, kv_indices, call_fn)

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indices = torch.zeros(
@@ -973,7 +981,6 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
-        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -995,7 +1002,7 @@ class FlashInferMultiStepDraftBackend:
         ][0]
         decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
         def call_fn(i, forward_batch):
@@ -1009,7 +1016,7 @@ class FlashInferMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


 @triton.jit
@@ -1070,21 +1077,6 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"

-    # Try to use _grouped_size_compiled_for_decode_kernels if available
-    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-    try:
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            num_attention_heads,
-            num_kv_heads,
-        ):
-            return True
-        else:
-            return False
-    except (ImportError, AttributeError):
-        pass
-
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads

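With the legacy FlashInfer probe gone, should_use_tensor_core falls through to a decision based on the GQA group size computed above. A sketch of that style of heuristic; the cutoff value here is illustrative, not the exact rule in the file:

def use_tensor_cores(num_attention_heads: int, num_kv_heads: int) -> bool:
    # Grouped-query attention: several query heads share one KV head.
    gqa_group_size = num_attention_heads // num_kv_heads
    return gqa_group_size > 4  # illustrative cutoff


print(use_tensor_cores(32, 8), use_tensor_cores(32, 4))  # False True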
@@ -1114,6 +1106,7 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
+    **kwargs,
 ) -> None:
     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
     batch_size = len(last_page_len)
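fast_decode_plan is installed over the wrapper's begin_forward, so the added **kwargs lets it absorb keyword arguments (such as the non_blocking=True passed by the updated callers) that it does not use itself. The idea in miniature, with a toy function:

def fast_plan(batch_size, *, data_type=None, **kwargs):
    # Extra keywords from newer call sites are accepted and ignored.
    return f"planned {batch_size} seqs, dtype={data_type}, extra={sorted(kwargs)}"


print(fast_plan(8, data_type="float16", non_blocking=True))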
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -37,6 +37,9 @@ class TritonAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )

         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -54,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+
         if forward_batch.forward_mode.is_decode():
             attn_logits = torch.empty(
                 (
@@ -68,31 +74,62 @@ class TritonAttnBackend(AttentionBackend):

             max_extend_len = None

-            kv_indptr = self.kv_indptr
-            bs = len(forward_batch.req_pool_indices)
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=
+                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
-
+                self.req_to_token,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 kv_indptr,
                 None,
                 kv_indices,
-
+                self.req_to_token.stride(0),
             )

+            qo_indptr = None
+            custom_mask = None
+            mask_offsets = None
         else:
+            kv_indptr[1 : bs + 1] = torch.cumsum(
+                forward_batch.extend_prefix_lens, dim=0
+            )
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.extend_prefix_lens.sum().item(),
+                dtype=torch.int32,
+                device=self.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.extend_prefix_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+            mask_offsets = None
+
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

-
-
-
-
+        self.forward_metadata = (
+            attn_logits,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+            mask_offsets,
+        )

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
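The extend branch above fills kv_indptr and qo_indptr by writing a cumulative sum of per-request lengths into a preallocated buffer. The core indexing trick shown on its own, with invented lengths:

import torch

max_bs = 8
seq_lens = torch.tensor([3, 5, 2])                    # per-request lengths
bs = seq_lens.numel()

indptr = torch.zeros(max_bs + 1, dtype=torch.int32)   # allocated once, reused
indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
indptr = indptr[: bs + 1]
print(indptr.tolist())                                # [0, 3, 8, 10]

# Request i owns the flattened index range [indptr[i], indptr[i + 1]).
print(indptr[1].item(), indptr[2].item())             # 3 8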
@@ -144,6 +181,9 @@ class TritonAttnBackend(AttentionBackend):
             None,
             kv_indptr,
             kv_indices,
+            None,
+            None,
+            None,
         )

     def init_forward_metadata_replay_cuda_graph(
@@ -197,7 +237,15 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

-
+        (
+            _,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+            mask_offsets,
+        ) = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -205,11 +253,11 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-
-
-
-
-
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            custom_mask,
+            mask_offsets,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -235,7 +283,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
|