sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +234 -74
- sglang/check_env.py +25 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -40
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +24 -14
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +98 -323
- sglang/srt/managers/tokenizer_manager.py +34 -16
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +74 -38
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +51 -26
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +199 -17
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +151 -29
- sglang/srt/openai_api/protocol.py +7 -1
- sglang/srt/server.py +111 -84
- sglang/srt/server_args.py +12 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +95 -14
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
- sglang-0.2.11.dist-info/RECORD +102 -0
- sglang-0.2.9.post1.dist-info/RECORD +0 -97
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
ADDED
@@ -0,0 +1,256 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""ModelRunner runs the forward passes of the models."""
+from dataclasses import dataclass
+from enum import IntEnum, auto
+from typing import List
+
+import numpy as np
+import torch
+
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+
+
+class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
+    PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+    EXTEND = auto()
+    # Decode one token.
+    DECODE = auto()
+
+
+@dataclass
+class InputMetadata:
+    """Store all inforamtion of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: BaseTokenToKVPool
+
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool
+
+    # Output location of the KV cache
+    out_cache_loc: torch.Tensor = None
+
+    # Output options
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # Trition attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    flashinfer_use_ragged: bool = False
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        skip_flashinfer_init=False,
+    ):
+        flashinfer_use_ragged = False
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                flashinfer_use_ragged = True
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+                flashinfer_use_ragged,
+            )
+
+        batch_size = len(req_pool_indices)
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatbile with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        if model_runner.server_args.disable_flashinfer:
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+
+        return ret
+
+
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+    flashinfer_use_ragged=False,
+):
+    """Init auxiliary variables for FlashInfer attention backend."""
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))
+
+    if flashinfer_use_ragged:
+        paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
+
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+    if forward_mode == ForwardMode.DECODE:
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        if flashinfer_use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens

sglang/srt/model_executor/model_runner.py
CHANGED
@@ -41,13 +41,14 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.schedule_batch import
-
-
-
-
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
 )
-from sglang.srt.
+from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -86,6 +87,7 @@ class ModelRunner:
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "enable_mla": server_args.enable_mla,
             }
         )
 
@@ -193,15 +195,23 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-
-
-
-
-
-
-
-
-
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.dtype)
+            )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -241,13 +251,28 @@ class ModelRunner:
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-
-        self.
-
-
-
-
-
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementaion, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
         logger.info(
             f"[gpu={self.gpu_id}] Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
@@ -321,7 +346,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_decode(self, batch:
+    def forward_decode(self, batch: ScheduleBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
@@ -341,7 +366,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_extend(self, batch:
+    def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -358,7 +383,7 @@ class ModelRunner:
         )
 
     @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch:
+    def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -379,7 +404,7 @@ class ModelRunner:
             batch.image_offsets,
         )
 
-    def forward(self, batch:
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:

sglang/srt/models/chatglm.py
CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
 

sglang/srt/models/commandr.py
CHANGED
@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 @torch.compile

sglang/srt/models/dbrx.py
CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class DbrxRouter(nn.Module):

sglang/srt/models/deepseek.py
CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class DeepseekMLP(nn.Module):

sglang/srt/models/deepseek_v2.py
CHANGED
@@ -45,7 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class DeepseekV2MLP(nn.Module):
@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
         return output
 
 
+class DeepseekV2AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        rope_scaling["type"] = "deepseek_yarn"
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+        )
+
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        kv_b_proj = self.kv_b_proj
+        w_kc, w_vc = kv_b_proj.weight.unflatten(
+            0, (-1, qk_nope_head_dim + v_head_dim)
+        ).split([qk_nope_head_dim, v_head_dim], dim=1)
+        self.w_kc = w_kc
+        self.w_vc = w_vc
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope_out = q_input[..., : self.kv_lora_rank]
+        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
+        k_pe = k_input[..., self.kv_lora_rank :]
+        v_input = k_input[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous())
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+        attn_bmm_output = attn_output.new_empty(
+            q_len, self.num_local_heads, self.v_head_dim
+        )
+        torch.bmm(
+            attn_output.transpose(0, 1),
+            self.w_vc.transpose(1, 2).contiguous(),
+            out=attn_bmm_output.transpose(0, 1),
+        )
+
+        attn_output = attn_bmm_output.flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -326,22 +486,44 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if global_server_args_dict["enable_mla"]:
+            self.self_attn = DeepseekV2AttentionMLA(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
+            self.self_attn = DeepseekV2Attention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
         if (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
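
The new `DeepseekV2AttentionMLA` path never materializes per-head K/V tensors. It splits the `kv_b_proj` weight into `w_kc` and `w_vc`, uses `w_kc` to project the no-RoPE part of the query into the compressed `kv_lora_rank` space before attention, runs attention against the compressed cache (note `num_kv_heads=1` in the `RadixAttention` above), and then uses `w_vc` to map the attention output back to the per-head value dimension. A shape-only sketch of those two batched matmuls, with toy dimensions rather than the real model sizes:

```python
import torch

T, H = 4, 2                    # tokens in the batch, attention heads on this rank
d_nope, d_v, r = 8, 8, 16      # qk_nope_head_dim, v_head_dim, kv_lora_rank (toy values)

# w_kc: [H, d_nope, r] and w_vc: [H, d_v, r], sliced out of the kv_b_proj weight.
w_kc = torch.randn(H, d_nope, r)
w_vc = torch.randn(H, d_v, r)

# "Absorb" w_kc into the query: q_nope [T, H, d_nope] -> latent queries [H, T, r].
q_nope = torch.randn(T, H, d_nope)
q_latent = torch.bmm(q_nope.transpose(0, 1), w_kc)

# ... attention runs in the r-dimensional latent space against the compressed cache ...
attn_out = torch.randn(T, H, r)

# Map the latent attention output back to the per-head value dimension: [H, T, d_v].
out = torch.bmm(attn_out.transpose(0, 1), w_vc.transpose(1, 2))

print(q_latent.shape, out.shape)  # torch.Size([2, 4, 16]) torch.Size([2, 4, 8])
```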

sglang/srt/models/gemma.py
CHANGED
@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class GemmaMLP(nn.Module):

sglang/srt/models/gemma2.py
CHANGED
@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class GemmaRMSNorm(CustomOp):
|