sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +2 -1
- sglang/lang/chat_template.py +17 -0
- sglang/launch_server_llavavid.py +1 -1
- sglang/srt/configs/__init__.py +3 -0
- sglang/srt/configs/model_config.py +27 -2
- sglang/srt/configs/qwen2vl.py +133 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/conversation.py +27 -0
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/__init__.py +16 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
- sglang/srt/layers/attention/flashinfer_backend.py +174 -54
- sglang/srt/layers/attention/triton_backend.py +22 -6
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
- sglang/srt/layers/linear.py +89 -63
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/rotary_embedding.py +112 -0
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/lora/lora.py +3 -1
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +4 -0
- sglang/srt/managers/image_processor.py +186 -13
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/schedule_batch.py +238 -68
- sglang/srt/managers/scheduler.py +69 -50
- sglang/srt/managers/tokenizer_manager.py +24 -4
- sglang/srt/managers/tp_worker.py +26 -111
- sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
- sglang/srt/mem_cache/memory_pool.py +56 -10
- sglang/srt/mem_cache/radix_cache.py +4 -3
- sglang/srt/model_executor/cuda_graph_runner.py +87 -28
- sglang/srt/model_executor/forward_batch_info.py +83 -3
- sglang/srt/model_executor/model_runner.py +32 -11
- sglang/srt/models/chatglm.py +3 -3
- sglang/srt/models/deepseek_v2.py +2 -2
- sglang/srt/models/mllama.py +1004 -0
- sglang/srt/models/qwen2_vl.py +724 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +13 -3
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +12 -0
- sglang/srt/server_args.py +10 -0
- sglang/srt/utils.py +22 -0
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +20 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +100 -3
- sglang/version.py +1 -1
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -227,8 +227,9 @@ def extend(reqs, model_runner):
|
|
227
227
|
req_to_token_pool=model_runner.req_to_token_pool,
|
228
228
|
token_to_kv_pool=model_runner.token_to_kv_pool,
|
229
229
|
tree_cache=None,
|
230
|
+
model_config=model_runner.model_config,
|
230
231
|
)
|
231
|
-
batch.prepare_for_extend(
|
232
|
+
batch.prepare_for_extend()
|
232
233
|
model_worker_batch = batch.get_model_worker_batch()
|
233
234
|
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
234
235
|
logits_output = model_runner.forward(forward_batch)
|
sglang/lang/chat_template.py
CHANGED
@@ -133,6 +133,22 @@ register_chat_template(
|
|
133
133
|
)
|
134
134
|
)
|
135
135
|
|
136
|
+
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
|
137
|
+
register_chat_template(
|
138
|
+
ChatTemplate(
|
139
|
+
name="qwen2-vl",
|
140
|
+
default_system_prompt="You are a helpful assistant.",
|
141
|
+
role_prefix_and_suffix={
|
142
|
+
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
143
|
+
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
144
|
+
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
145
|
+
},
|
146
|
+
style=ChatTemplateStyle.PLAIN,
|
147
|
+
stop_str=("<|im_end|>"),
|
148
|
+
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
149
|
+
)
|
150
|
+
)
|
151
|
+
|
136
152
|
|
137
153
|
register_chat_template(
|
138
154
|
ChatTemplate(
|
@@ -213,6 +229,7 @@ register_chat_template(
|
|
213
229
|
),
|
214
230
|
},
|
215
231
|
stop_str=("<|eot_id|>",),
|
232
|
+
image_token="<|image|>",
|
216
233
|
)
|
217
234
|
)
|
218
235
|
|
sglang/launch_server_llavavid.py
CHANGED
@@ -14,7 +14,7 @@ if __name__ == "__main__":
|
|
14
14
|
model_override_args["num_frames"] = 16
|
15
15
|
model_override_args["model_type"] = "llavavid"
|
16
16
|
if model_override_args["num_frames"] == 32:
|
17
|
-
model_override_args["rope_scaling"] = {"factor": 2.0, "
|
17
|
+
model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
|
18
18
|
model_override_args["max_sequence_length"] = 4096 * 2
|
19
19
|
model_override_args["tokenizer_model_max_length"] = 4096 * 2
|
20
20
|
model_override_args["model_max_length"] = 4096 * 2
|
sglang/srt/configs/__init__.py
CHANGED
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
|
13
13
|
limitations under the License.
|
14
14
|
"""
|
15
15
|
|
16
|
+
import logging
|
17
|
+
import os
|
16
18
|
from enum import IntEnum, auto
|
17
19
|
from typing import Optional
|
18
20
|
|
@@ -20,6 +22,8 @@ from transformers import PretrainedConfig
|
|
20
22
|
|
21
23
|
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
22
24
|
|
25
|
+
logger = logging.getLogger(__name__)
|
26
|
+
|
23
27
|
|
24
28
|
class AttentionArch(IntEnum):
|
25
29
|
MLA = auto()
|
@@ -46,10 +50,29 @@ class ModelConfig:
|
|
46
50
|
model_override_args=model_override_args,
|
47
51
|
)
|
48
52
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
53
|
+
derived_context_len = get_context_length(self.hf_text_config)
|
54
|
+
allow_long_context = os.environ.get(
|
55
|
+
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
|
56
|
+
)
|
57
|
+
|
49
58
|
if context_length is not None:
|
50
|
-
|
59
|
+
if context_length > derived_context_len:
|
60
|
+
if allow_long_context:
|
61
|
+
logger.warning(
|
62
|
+
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
63
|
+
f"This may lead to incorrect model outputs or CUDA errors."
|
64
|
+
)
|
65
|
+
self.context_len = context_length
|
66
|
+
else:
|
67
|
+
raise ValueError(
|
68
|
+
f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
69
|
+
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
|
70
|
+
f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
|
71
|
+
)
|
72
|
+
else:
|
73
|
+
self.context_len = context_length
|
51
74
|
else:
|
52
|
-
self.context_len =
|
75
|
+
self.context_len = derived_context_len
|
53
76
|
|
54
77
|
# Unify the config keys for hf_text_config
|
55
78
|
self.head_dim = getattr(
|
@@ -89,6 +112,8 @@ class ModelConfig:
|
|
89
112
|
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
90
113
|
self.vocab_size = self.hf_text_config.vocab_size
|
91
114
|
|
115
|
+
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
|
116
|
+
|
92
117
|
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
93
118
|
def get_total_num_kv_heads(self) -> int:
|
94
119
|
"""Returns the total number of KV heads."""
|
@@ -0,0 +1,133 @@
|
|
1
|
+
# coding=utf-8
|
2
|
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
|
3
|
+
# All rights reserved.
|
4
|
+
#
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6
|
+
# you may not use this file except in compliance with the License.
|
7
|
+
# You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14
|
+
# See the License for the specific language governing permissions and
|
15
|
+
# limitations under the License.
|
16
|
+
"""Qwen2VL model configuration"""
|
17
|
+
|
18
|
+
import os
|
19
|
+
from typing import Union
|
20
|
+
|
21
|
+
from transformers import PretrainedConfig
|
22
|
+
|
23
|
+
|
24
|
+
class Qwen2VLVisionConfig(PretrainedConfig):
|
25
|
+
model_type = "qwen2_vl"
|
26
|
+
|
27
|
+
def __init__(
|
28
|
+
self,
|
29
|
+
depth=32,
|
30
|
+
embed_dim=1280,
|
31
|
+
hidden_size=3584,
|
32
|
+
hidden_act="quick_gelu",
|
33
|
+
mlp_ratio=4,
|
34
|
+
num_heads=16,
|
35
|
+
in_channels=3,
|
36
|
+
patch_size=14,
|
37
|
+
spatial_merge_size=2,
|
38
|
+
temporal_patch_size=2,
|
39
|
+
**kwargs,
|
40
|
+
):
|
41
|
+
super().__init__(**kwargs)
|
42
|
+
|
43
|
+
self.depth = depth
|
44
|
+
self.embed_dim = embed_dim
|
45
|
+
self.hidden_size = hidden_size
|
46
|
+
self.hidden_act = hidden_act
|
47
|
+
self.mlp_ratio = mlp_ratio
|
48
|
+
self.num_heads = num_heads
|
49
|
+
self.in_channels = in_channels
|
50
|
+
self.patch_size = patch_size
|
51
|
+
self.spatial_merge_size = spatial_merge_size
|
52
|
+
self.temporal_patch_size = temporal_patch_size
|
53
|
+
|
54
|
+
@classmethod
|
55
|
+
def from_pretrained(
|
56
|
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
57
|
+
) -> "PretrainedConfig":
|
58
|
+
cls._set_token_in_kwargs(kwargs)
|
59
|
+
|
60
|
+
config_dict, kwargs = cls.get_config_dict(
|
61
|
+
pretrained_model_name_or_path, **kwargs
|
62
|
+
)
|
63
|
+
|
64
|
+
if config_dict.get("model_type") == "qwen2_vl":
|
65
|
+
config_dict = config_dict["vision_config"]
|
66
|
+
|
67
|
+
return cls.from_dict(config_dict, **kwargs)
|
68
|
+
|
69
|
+
|
70
|
+
class Qwen2VLConfig(PretrainedConfig):
|
71
|
+
model_type = "qwen2_vl"
|
72
|
+
|
73
|
+
def __init__(
|
74
|
+
self,
|
75
|
+
vocab_size=152064,
|
76
|
+
hidden_size=8192,
|
77
|
+
intermediate_size=29568,
|
78
|
+
num_hidden_layers=80,
|
79
|
+
num_attention_heads=64,
|
80
|
+
num_key_value_heads=8,
|
81
|
+
hidden_act="silu",
|
82
|
+
max_position_embeddings=32768,
|
83
|
+
initializer_range=0.02,
|
84
|
+
rms_norm_eps=1e-05,
|
85
|
+
use_cache=True,
|
86
|
+
tie_word_embeddings=False,
|
87
|
+
rope_theta=1000000.0,
|
88
|
+
use_sliding_window=False,
|
89
|
+
sliding_window=4096,
|
90
|
+
max_window_layers=80,
|
91
|
+
attention_dropout=0.0,
|
92
|
+
vision_config=None,
|
93
|
+
rope_scaling=None,
|
94
|
+
**kwargs,
|
95
|
+
):
|
96
|
+
if isinstance(vision_config, dict):
|
97
|
+
self.vision_config = Qwen2VLVisionConfig(**vision_config)
|
98
|
+
elif vision_config is None:
|
99
|
+
self.vision_config = Qwen2VLVisionConfig()
|
100
|
+
|
101
|
+
self.vocab_size = vocab_size
|
102
|
+
self.max_position_embeddings = max_position_embeddings
|
103
|
+
self.hidden_size = hidden_size
|
104
|
+
self.intermediate_size = intermediate_size
|
105
|
+
self.num_hidden_layers = num_hidden_layers
|
106
|
+
self.num_attention_heads = num_attention_heads
|
107
|
+
self.use_sliding_window = use_sliding_window
|
108
|
+
self.sliding_window = sliding_window
|
109
|
+
self.max_window_layers = max_window_layers
|
110
|
+
|
111
|
+
# for backward compatibility
|
112
|
+
if num_key_value_heads is None:
|
113
|
+
num_key_value_heads = num_attention_heads
|
114
|
+
|
115
|
+
self.num_key_value_heads = num_key_value_heads
|
116
|
+
self.hidden_act = hidden_act
|
117
|
+
self.initializer_range = initializer_range
|
118
|
+
self.rms_norm_eps = rms_norm_eps
|
119
|
+
self.use_cache = use_cache
|
120
|
+
self.rope_theta = rope_theta
|
121
|
+
self.attention_dropout = attention_dropout
|
122
|
+
self.rope_scaling = rope_scaling
|
123
|
+
|
124
|
+
# NOTE: the following section from original transformers config
|
125
|
+
# for Qwen2-VL is commented out to address rope config loading issue
|
126
|
+
#
|
127
|
+
# if self.rope_scaling is not None and "type" in self.rope_scaling:
|
128
|
+
# if self.rope_scaling["type"] == "mrope":
|
129
|
+
# self.rope_scaling["type"] = "default"
|
130
|
+
# self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
131
|
+
# rope_config_validation(self)
|
132
|
+
|
133
|
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
@@ -73,9 +73,16 @@ class FSMCache(BaseToolCache):
|
|
73
73
|
def init_value(self, key):
|
74
74
|
key_type, key_string = key
|
75
75
|
if key_type == "json":
|
76
|
-
|
77
|
-
|
78
|
-
|
76
|
+
try:
|
77
|
+
regex = build_regex_from_schema(
|
78
|
+
key_string,
|
79
|
+
whitespace_pattern=self.constrained_json_whitespace_pattern,
|
80
|
+
)
|
81
|
+
except NotImplementedError as e:
|
82
|
+
logger.warning(
|
83
|
+
f"skip invalid json schema: json_schema={key_string}, {e=}"
|
84
|
+
)
|
85
|
+
return None, key_string
|
79
86
|
elif key_type == "regex":
|
80
87
|
regex = key_string
|
81
88
|
else:
|
sglang/srt/conversation.py
CHANGED
@@ -509,6 +509,19 @@ register_conv_template(
|
|
509
509
|
)
|
510
510
|
)
|
511
511
|
|
512
|
+
register_conv_template(
|
513
|
+
Conversation(
|
514
|
+
name="llama_3_vision",
|
515
|
+
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
|
516
|
+
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
517
|
+
roles=("user", "assistant"),
|
518
|
+
sep_style=SeparatorStyle.LLAMA3,
|
519
|
+
sep="",
|
520
|
+
stop_str=["<|end_of_text|>", "<|eot_id|>"],
|
521
|
+
image_token="<|image|>",
|
522
|
+
)
|
523
|
+
)
|
524
|
+
|
512
525
|
register_conv_template(
|
513
526
|
Conversation(
|
514
527
|
name="llava_llama_3",
|
@@ -530,3 +543,17 @@ register_conv_template(
|
|
530
543
|
stop_str=["<|im_end|>", "<|action_end|>"],
|
531
544
|
)
|
532
545
|
)
|
546
|
+
|
547
|
+
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
|
548
|
+
register_conv_template(
|
549
|
+
Conversation(
|
550
|
+
name="qwen2-vl",
|
551
|
+
system_message="You are a helpful assistant.",
|
552
|
+
system_template="<|im_start|>system\n{system_message}",
|
553
|
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
554
|
+
sep="<|im_end|>\n",
|
555
|
+
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
556
|
+
stop_str=["<|im_end|>"],
|
557
|
+
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
558
|
+
)
|
559
|
+
)
|
@@ -33,12 +33,13 @@ from transformers import (
|
|
33
33
|
try:
|
34
34
|
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
35
35
|
|
36
|
-
from sglang.srt.configs import ExaoneConfig
|
36
|
+
from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig
|
37
37
|
|
38
38
|
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
39
39
|
ChatGLMConfig.model_type: ChatGLMConfig,
|
40
40
|
DbrxConfig.model_type: DbrxConfig,
|
41
41
|
ExaoneConfig.model_type: ExaoneConfig,
|
42
|
+
Qwen2VLConfig.model_type: Qwen2VLConfig,
|
42
43
|
}
|
43
44
|
except ImportError:
|
44
45
|
# We want this file to run without vllm dependency
|
@@ -162,6 +163,8 @@ def get_tokenizer(
|
|
162
163
|
"Using a slow tokenizer. This might cause a significant "
|
163
164
|
"slowdown. Consider using a fast tokenizer instead."
|
164
165
|
)
|
166
|
+
|
167
|
+
attach_additional_stop_token_ids(tokenizer)
|
165
168
|
return tokenizer
|
166
169
|
|
167
170
|
|
@@ -180,4 +183,16 @@ def get_processor(
|
|
180
183
|
tokenizer_revision=tokenizer_revision,
|
181
184
|
**kwargs,
|
182
185
|
)
|
186
|
+
|
187
|
+
attach_additional_stop_token_ids(processor.tokenizer)
|
183
188
|
return processor
|
189
|
+
|
190
|
+
|
191
|
+
def attach_additional_stop_token_ids(tokenizer):
|
192
|
+
# Special handling for stop token <|eom_id|> generated by llama 3 tool use.
|
193
|
+
if "<|eom_id|>" in tokenizer.get_added_vocab():
|
194
|
+
tokenizer.additional_stop_token_ids = set(
|
195
|
+
[tokenizer.get_added_vocab()["<|eom_id|>"]]
|
196
|
+
)
|
197
|
+
else:
|
198
|
+
tokenizer.additional_stop_token_ids = None
|
@@ -1,8 +1,10 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
+
from typing import Optional
|
2
3
|
|
3
4
|
import torch
|
4
5
|
from torch import nn
|
5
6
|
|
7
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
6
8
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
7
9
|
|
8
10
|
|
@@ -19,13 +21,22 @@ class AttentionBackend(ABC):
|
|
19
21
|
raise NotImplementedError()
|
20
22
|
|
21
23
|
def init_forward_metadata_capture_cuda_graph(
|
22
|
-
self,
|
24
|
+
self,
|
25
|
+
bs: int,
|
26
|
+
req_pool_indices: torch.Tensor,
|
27
|
+
seq_lens: torch.Tensor,
|
28
|
+
encoder_lens: Optional[torch.Tensor] = None,
|
23
29
|
):
|
24
30
|
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
25
31
|
raise NotImplementedError()
|
26
32
|
|
27
33
|
def init_forward_metadata_replay_cuda_graph(
|
28
|
-
self,
|
34
|
+
self,
|
35
|
+
bs: int,
|
36
|
+
req_pool_indices: torch.Tensor,
|
37
|
+
seq_lens: torch.Tensor,
|
38
|
+
seq_lens_sum: int,
|
39
|
+
encoder_lens: Optional[torch.Tensor] = None,
|
29
40
|
):
|
30
41
|
"""Init the metadata for a forward pass for replying a cuda graph."""
|
31
42
|
raise NotImplementedError()
|
@@ -39,7 +50,7 @@ class AttentionBackend(ABC):
|
|
39
50
|
q: torch.Tensor,
|
40
51
|
k: torch.Tensor,
|
41
52
|
v: torch.Tensor,
|
42
|
-
layer:
|
53
|
+
layer: RadixAttention,
|
43
54
|
forward_batch: ForwardBatch,
|
44
55
|
):
|
45
56
|
"""Run forward on an attention layer."""
|
@@ -53,7 +64,7 @@ class AttentionBackend(ABC):
|
|
53
64
|
q: torch.Tensor,
|
54
65
|
k: torch.Tensor,
|
55
66
|
v: torch.Tensor,
|
56
|
-
layer:
|
67
|
+
layer: RadixAttention,
|
57
68
|
forward_batch: ForwardBatch,
|
58
69
|
):
|
59
70
|
"""Run a forward for decode."""
|
@@ -64,7 +75,7 @@ class AttentionBackend(ABC):
|
|
64
75
|
q: torch.Tensor,
|
65
76
|
k: torch.Tensor,
|
66
77
|
v: torch.Tensor,
|
67
|
-
layer:
|
78
|
+
layer: RadixAttention,
|
68
79
|
forward_batch: ForwardBatch,
|
69
80
|
):
|
70
81
|
"""Run a forward for extend."""
|
@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
10
10
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
11
11
|
|
12
12
|
if TYPE_CHECKING:
|
13
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
13
14
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
14
15
|
|
15
16
|
|
@@ -134,8 +135,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|
134
135
|
)
|
135
136
|
|
136
137
|
def init_forward_metadata_capture_cuda_graph(
|
137
|
-
self,
|
138
|
+
self,
|
139
|
+
bs: int,
|
140
|
+
req_pool_indices: torch.Tensor,
|
141
|
+
seq_lens: torch.Tensor,
|
142
|
+
encoder_lens=None,
|
138
143
|
):
|
144
|
+
# NOTE: encoder_lens expected to be zeros or None
|
139
145
|
self.forward_metadata = (
|
140
146
|
self.cuda_graph_start_loc,
|
141
147
|
self.cuda_graph_attn_logits,
|
@@ -144,15 +150,23 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|
144
150
|
)
|
145
151
|
|
146
152
|
def init_forward_metadata_replay_cuda_graph(
|
147
|
-
self,
|
153
|
+
self,
|
154
|
+
bs: int,
|
155
|
+
req_pool_indices: torch.Tensor,
|
156
|
+
seq_lens: torch.Tensor,
|
157
|
+
seq_lens_sum: int,
|
158
|
+
encoder_lens=None,
|
148
159
|
):
|
160
|
+
# NOTE: encoder_lens expected to be zeros or None
|
149
161
|
self.cuda_graph_start_loc.zero_()
|
150
162
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
151
163
|
|
152
164
|
def get_cuda_graph_seq_len_fill_value(self):
|
153
165
|
return 1
|
154
166
|
|
155
|
-
def forward_extend(
|
167
|
+
def forward_extend(
|
168
|
+
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
169
|
+
):
|
156
170
|
# TODO: reuse the buffer across layers
|
157
171
|
if layer.qk_head_dim != layer.v_head_dim:
|
158
172
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
@@ -168,7 +182,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|
168
182
|
)
|
169
183
|
|
170
184
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
171
|
-
layer
|
185
|
+
layer, forward_batch.out_cache_loc, k, v, k_label
|
172
186
|
)
|
173
187
|
|
174
188
|
(
|
@@ -197,7 +211,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|
197
211
|
)
|
198
212
|
return o
|
199
213
|
|
200
|
-
def forward_decode(
|
214
|
+
def forward_decode(
|
215
|
+
self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
|
216
|
+
):
|
201
217
|
# During torch.compile, there is a bug in rotary_emb that causes the
|
202
218
|
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
203
219
|
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
@@ -227,7 +243,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|
227
243
|
)
|
228
244
|
|
229
245
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
230
|
-
layer
|
246
|
+
layer, forward_batch.out_cache_loc, k, v, k_label
|
231
247
|
)
|
232
248
|
|
233
249
|
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
|