sglang 0.4.2.post4__py3-none-any.whl → 0.4.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/global_config.py +2 -0
- sglang/lang/backend/openai.py +5 -0
- sglang/lang/chat_template.py +22 -7
- sglang/lang/ir.py +1 -0
- sglang/srt/configs/__init__.py +6 -3
- sglang/srt/configs/model_config.py +2 -0
- sglang/srt/configs/qwen2_5_vl_config.py +1003 -0
- sglang/srt/entrypoints/engine.py +18 -3
- sglang/srt/hf_transformers_utils.py +2 -3
- sglang/srt/layers/attention/flashinfer_backend.py +235 -110
- sglang/srt/layers/attention/triton_backend.py +358 -72
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/linear.py +12 -5
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +51 -5
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
- sglang/srt/layers/quantization/fp8_kernel.py +123 -17
- sglang/srt/layers/quantization/fp8_utils.py +33 -4
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/image_processor.py +217 -122
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +16 -3
- sglang/srt/managers/scheduler.py +29 -0
- sglang/srt/managers/tokenizer_manager.py +6 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +12 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +12 -2
- sglang/srt/models/deepseek_nextn.py +295 -0
- sglang/srt/models/deepseek_v2.py +21 -8
- sglang/srt/models/llava.py +2 -1
- sglang/srt/models/qwen2_5_vl.py +722 -0
- sglang/srt/models/qwen2_vl.py +2 -1
- sglang/srt/openai_api/adapter.py +17 -3
- sglang/srt/server_args.py +26 -4
- sglang/srt/speculative/eagle_worker.py +35 -10
- sglang/srt/speculative/spec_info.py +11 -1
- sglang/srt/utils.py +7 -0
- sglang/utils.py +99 -19
- sglang/version.py +1 -1
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/METADATA +5 -4
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/RECORD +73 -55
- sglang/srt/configs/qwen2vl.py +0 -130
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/engine.py
CHANGED
@@ -115,6 +115,9 @@ class Engine:
         sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        # The image input. It can be a file name, a url, or base64 encoded string.
+        # See also python/sglang/srt/utils.py:load_image.
+        image_data: Optional[Union[List[str], str]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -126,14 +129,20 @@ class Engine:
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
         Please refer to `GenerateReqInput` for the documentation.
         """
+        modalities_list = []
+        if image_data is not None:
+            modalities_list.append("image")
+
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
             sampling_params=sampling_params,
+            image_data=image_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
             lora_path=lora_path,
+            modalities=modalities_list,
             custom_logit_processor=custom_logit_processor,
             stream=stream,
         )
@@ -162,6 +171,9 @@ class Engine:
         sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        # The image input. It can be a file name, a url, or base64 encoded string.
+        # See also python/sglang/srt/utils.py:load_image.
+        image_data: Optional[Union[List[str], str]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -177,6 +189,7 @@ class Engine:
             text=prompt,
             input_ids=input_ids,
             sampling_params=sampling_params,
+            image_data=image_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
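Taken together, the hunks above let images flow through the offline engine: `image_data` is forwarded into `GenerateReqInput`, and `"image"` is appended to `modalities` whenever it is set, for both `generate` and `async_generate`. A minimal usage sketch (the model path, prompt, and image URL are placeholders, not taken from this diff):

```python
# Hypothetical use of the image_data argument added to Engine.generate in 0.4.3.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="Qwen/Qwen2.5-VL-7B-Instruct")  # placeholder checkpoint
    out = llm.generate(
        prompt="Describe this image.",
        sampling_params={"temperature": 0, "max_new_tokens": 64},
        # A file name, URL, or base64-encoded string,
        # decoded via sglang/srt/utils.py:load_image.
        image_data="https://example.com/cat.png",
    )
    print(out["text"])
    llm.shutdown()
```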
@@ -297,7 +310,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] =
+    os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

@@ -317,7 +330,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.
+            "0.2.1.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -425,7 +438,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
     # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
-        load_chat_template_for_openai_api(
+        load_chat_template_for_openai_api(
+            tokenizer_manager, server_args.chat_template, server_args.model_path
+        )

     # Wait for the model to finish loading
     scheduler_infos = []
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -30,16 +30,15 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

-from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig,
+from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2_5_VLConfig

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
-
+    Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
 }

-
 for name, cls in _CONFIG_REGISTRY.items():
     with contextlib.suppress(ValueError):
         AutoConfig.register(name, cls)
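For context, the registry in this hunk leans on `transformers.AutoConfig.register`, which maps a `model_type` string to a config class so that `AutoConfig.from_pretrained` resolves checkpoints declaring that type; the `contextlib.suppress(ValueError)` guard tolerates transformers releases that already ship the type natively. A standalone sketch of the same pattern (`MyConfig` is purely illustrative, not part of sglang):

```python
# Illustrative sketch of the AutoConfig registration pattern used above.
import contextlib

from transformers import AutoConfig, PretrainedConfig


class MyConfig(PretrainedConfig):
    model_type = "my_custom_model"  # must match "model_type" in the checkpoint's config.json


# Registering a model_type this transformers version already knows raises ValueError,
# hence the suppress() guard mirrored from hf_transformers_utils.py.
with contextlib.suppress(ValueError):
    AutoConfig.register(MyConfig.model_type, MyConfig)
```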
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -20,6 +21,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

@@ -35,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.
+    from flashinfer.mla import BatchMLAPagedAttentionWrapper


 class WrapperDispatch(Enum):
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):

 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[
+    decode_wrappers: List[
+        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]


 @dataclass
@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

+        self.enable_flashinfer_mla = False
+        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                self.enable_flashinfer_mla = True
+                global_config.enable_flashinfer_mla = True
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
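The MLA code path is opt-in: it is gated on a new `enable_flashinfer_mla` server argument and only engages for `DeepseekV3ForCausalLM` architectures. A hedged sketch of switching it on through the offline engine; the keyword names mirror the ServerArgs fields referenced in this diff and should be treated as assumptions rather than documented API, and the checkpoint is a placeholder:

```python
# Sketch only: enabling the FlashInfer MLA decode path for a DeepSeek-V3-style model.
# enable_flashinfer_mla / attention_backend mirror ServerArgs fields seen in this diff.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",  # placeholder DeepseekV3ForCausalLM checkpoint
    attention_backend="flashinfer",
    enable_flashinfer_mla=True,            # falls back to the regular paged decode when False
    trust_remote_code=True,
)
```

A real deployment would also need a tensor-parallel size large enough for the checkpoint; the sketch only shows where the flag plugs in.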
@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
+            if self.enable_flashinfer_mla:
+                self.qo_indptr = [
+                    torch.zeros(
+                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                    )
+                    for _ in range(self.num_wrappers)
+                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
@@ -153,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_verify.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
-            self.
-
-                    self.workspace_buffer,
-
-
+            if self.enable_flashinfer_mla:
+                self.decode_wrappers.append(
+                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
+                )
+            else:
+                self.decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                    )
                 )
-            )

         # Create indices updater
         if not skip_prefill:
@@ -274,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-
-
-
-
-
-
-
-
-
-
-
+                if self.enable_flashinfer_mla:
+                    decode_wrappers.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            use_cuda_graph=True,
+                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
+                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
+                            kv_indices=self.cuda_graph_kv_indices[i],
+                            kv_len_arr=self.kv_last_page_len[:num_tokens],
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    decode_wrappers.append(
+                        BatchDecodeWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            use_cuda_graph=True,
+                            use_tensor_cores=self.decode_use_tensor_cores,
+                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                                :num_tokens
+                            ],
+                        )
                     )
-                )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -375,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-
-
-
-
-
-
-            else forward_batch.encoder_out_cache_loc
-        )
+        if global_config.enable_flashinfer_mla:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )

-
+            logits_soft_cap = layer.logit_cap

-
-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=not layer.is_cross_attention,
-                sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
-                logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
-            )
-        else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )

-
-
-
-
+            o = o1
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                self._get_wrapper_idx(layer)
+            ]
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+
+            logits_soft_cap = layer.logit_cap
+
+            if not self.forward_metadata.use_ragged:
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        )
+
+                o = prefill_wrapper_paged.forward(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=
+                    causal=not layer.is_cross_attention,
                     sm_scale=layer.scaling,
-
+                    window_left=layer.sliding_window_size,
+                    logits_soft_cap=logits_soft_cap,
+                    k_scale=layer.k_scale,
+                    v_scale=layer.v_scale,
+                )
+            else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
                 )

-
+                if self.forward_metadata.extend_no_prefix:
+                    o = o1
+                else:
+                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                        causal=False,
+                        sm_scale=layer.scaling,
+                        logits_soft_cap=layer.logit_cap,
+                    )

-
-
-
-
+                    o, _ = merge_state(o1, s1, o2, s2)
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

-
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
         self,
@@ -452,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

-        if
-
-
-
-
-
+        if self.enable_flashinfer_mla:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+            o = decode_wrapper.run(
+                reshaped_q[:, :, : layer.v_head_dim],
+                reshaped_q[:, :, layer.v_head_dim :],
+                reshaped_k[:, :, : layer.v_head_dim],
+                reshaped_k[:, :, layer.v_head_dim :],
+            )

-
-
-
-
-
-
-
-
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

-
+            o = decode_wrapper.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
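In the MLA decode branch above, each query head carries the compressed-KV ("nope") portion in its first `v_head_dim` channels and the rotary ("rope") portion in the remainder, and the single shared latent KV buffer is split the same way before being handed to the MLA wrapper. A small, self-contained sketch of that slicing using the 512 + 64 = 576 per-head layout visible in the planner hunk further below (the tensor values are dummies):

```python
# Standalone illustration of the q/k splitting done in the MLA decode path above.
# Dimensions follow the constants visible in this diff: kv_lora_rank=512, rope dim=64.
import torch

num_tokens, num_q_heads = 4, 16
v_head_dim, rope_dim = 512, 64
head_dim = v_head_dim + rope_dim  # 576

q = torch.randn(num_tokens, num_q_heads, head_dim)
k_buffer = torch.randn(num_tokens, 1, head_dim)  # one shared latent KV head

q_nope, q_rope = q[:, :, :v_head_dim], q[:, :, v_head_dim:]
ckv, k_pe = k_buffer[:, :, :v_head_dim], k_buffer[:, :, v_head_dim:]

assert q_nope.shape == (num_tokens, num_q_heads, 512)
assert q_rope.shape == (num_tokens, num_q_heads, 64)
assert ckv.shape == (num_tokens, 1, 512) and k_pe.shape == (num_tokens, 1, 64)
```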
@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -528,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:

     def call_begin_forward(
         self,
-        wrapper:
+        wrapper: Union[
+            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -637,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

-
-
-
-
-
-
-
-
-
-
-
-
+        if global_config.enable_flashinfer_mla:
+            sm_scale = 1.0 / math.sqrt(192)
+            q_indptr = torch.arange(0, bs + 1).to(0).int()
+            kv_lens = paged_kernel_lens.to(torch.int32)
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+            )


 class FlashInferIndicesUpdaterPrefill:
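One detail worth calling out in the MLA planning branch: the softmax scale is `1 / sqrt(192)`, i.e. it is derived from the original per-head query/key width (the same `head_dim_qk=192` used for the ragged prefill in the next hunk), not from the 512 + 64 layout the decode kernel actually consumes. A tiny sanity check, assuming DeepSeek-V3's 128-dim "nope" plus 64-dim rotary split (that decomposition is an assumption here, not stated in the diff):

```python
# Sanity check for the sm_scale used in the MLA plan above.
import math

qk_nope_head_dim, qk_rope_head_dim = 128, 64  # assumed DeepSeek-V3 split: 128 + 64 = 192
sm_scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)
assert math.isclose(sm_scale, 1.0 / math.sqrt(192))
```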
@@ -857,30 +969,42 @@ class FlashInferIndicesUpdaterPrefill:

         # extend part
         if use_ragged:
-
-
+            if global_config.enable_flashinfer_mla:
+                wrapper_ragged.begin_forward(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=qo_indptr,
+                    num_qo_heads=self.num_qo_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_dim_qk=192,
+                    head_dim_vo=128,
+                    q_data_type=self.q_data_type,
+                )
+            else:
+                wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    q_data_type=self.q_data_type,
+                )
+
+        if not global_config.enable_flashinfer_mla:
+            # cached part
+            wrapper_paged.begin_forward(
                 qo_indptr,
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
+                1,
                 q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
+                non_blocking=True,
             )

-        # cached part
-        wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            q_data_type=self.q_data_type,
-            custom_mask=custom_mask,
-            non_blocking=True,
-        )
-

 class FlashInferMultiStepDraftBackend:
     """
@@ -947,7 +1071,7 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )

-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1163,6 +1287,7 @@ def fast_decode_plan(
         window_left,
         logits_soft_cap,
         head_dim,
+        head_dim,
         empty_q_data,
         empty_kv_cache,
         stream.cuda_stream,