sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the publicly released contents of the two package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +6 -25
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +104 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +58 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +117 -131
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +57 -44
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -1
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,15 +15,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Run the model with cuda graph."""
+"""Run the model with cuda graph and torch.compile."""
 
 import bisect
 from contextlib import contextmanager
-from typing import
+from typing import TYPE_CHECKING, Callable
 
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
@@ -30,16 +30,13 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import
-    ForwardMode,
-    InputMetadata,
-    update_flashinfer_indices,
-)
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
@@ -58,6 +55,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 def patch_model(
     model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
+    """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
@@ -89,28 +87,33 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-
-
-
-
-        use_torch_compile: bool,
-        disable_padding: bool,
-    ):
+    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
+
+    def __init__(self, model_runner: "ModelRunner"):
+        # Parse args
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
-        self.
+        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+
+        # Batch sizes to capture
+        if self.model_runner.server_args.disable_cuda_graph_padding:
+            self.capture_bs = list(range(1, 32)) + [64, 128]
+        else:
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
 
         # Common inputs
-        self.max_bs =
+        self.max_bs = max(self.capture_bs)
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
         self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
@@ -118,56 +121,38 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
-        #
-        self.
-
-
-        self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,),
-            dtype=torch.int32,
-            device="cuda",
+        # Attention backend
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-        self.flashinfer_kv_last_page_len = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        if model_runner.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-        else:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-
-            self.flashinfer_kv_indptr = [
-                self.flashinfer_kv_indptr,
-                self.flashinfer_kv_indptr.clone(),
-            ]
-            self.flashinfer_kv_indices = [
-                self.flashinfer_kv_indices,
-                self.flashinfer_kv_indices.clone(),
-            ]
 
-
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
-
-        if use_torch_compile:
+        if self.use_torch_compile:
             set_torch_compile_config()
 
+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+            )
+
     def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs
 
-    def capture(self
-        self.batch_size_list = batch_size_list
+    def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in
+            for bs in self.capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -175,14 +160,10 @@ class CudaGraphRunner:
                 ) as forward:
                     (
                         graph,
-                        input_buffers,
                         output_buffers,
-                        flashinfer_handler,
                     ) = self.capture_one_batch_size(bs, forward)
                     self.graphs[bs] = graph
-                    self.input_buffers[bs] = input_buffers
                     self.output_buffers[bs] = output_buffers
-                    self.flashinfer_handlers[bs] = flashinfer_handler
 
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
@@ -195,67 +176,26 @@ class CudaGraphRunner:
         position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
 
-        #
-
-
-            // self.model_runner.tp_size,
-            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-        if self.model_runner.sliding_window_size is None:
-            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=use_tensor_cores,
-                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.flashinfer_kv_indices,
-                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-            )
-        else:
-            flashinfer_decode_wrapper = []
-            for i in range(2):
-                flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=use_tensor_cores,
-                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
-                            :bs
-                        ],
-                    )
-                )
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            flashinfer_decode_wrapper,
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
+            bs, req_pool_indices, seq_lens
        )
 
         # Run and capture
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
-                top_logprobs_nums=0,
+                top_logprobs_nums=[0] * bs,
                 positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
-                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
@@ -277,17 +217,17 @@ class CudaGraphRunner:
         self.model_runner.tp_group.barrier()
 
         self.graph_memory_pool = graph.pool()
-        return graph,
+        return graph, out
 
     def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)
 
         # Pad
-        index = bisect.bisect_left(self.
-        bs = self.
+        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
@@ -298,24 +238,14 @@ class CudaGraphRunner:
         self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc
 
-        #
-
-
-            self.model_runner,
-            self.req_pool_indices[:bs],
-            self.seq_lens[:bs],
-            None,
-            self.flashinfer_handlers[bs],
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
+            bs, self.req_pool_indices, self.seq_lens
        )
 
-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
-        torch.cuda.synchronize()
         self.graphs[bs].replay()
-
-        sample_output, logits_output = self.output_buffers[bs]
+        logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
@@ -327,11 +257,6 @@ class CudaGraphRunner:
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )
 
         # Extract logprobs
         if batch.return_logprob:
@@ -348,4 +273,4 @@ class CudaGraphRunner:
                 logits_output.next_token_logprobs, logits_metadata
             )[1]
 
-        return
+        return logits_output
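For context on what this diff does: in 0.3.1 `CudaGraphRunner` no longer builds flashinfer decode wrappers itself; it delegates backend-specific state to `model_runner.attn_backend` through the hooks visible above (`init_cuda_graph_state`, `get_cuda_graph_seq_len_fill_value`, `init_forward_metadata_capture_cuda_graph`, `init_forward_metadata_replay_cuda_graph`), and `replay` pads each incoming batch up to the nearest captured batch size. The sketch below only illustrates those two ideas; `ToyAttentionBackend` and `pick_capture_bs` are hypothetical names, not the actual sglang implementation, and only the hook names and the default `capture_bs` values are taken from the diff.

```python
import bisect
from typing import List


class ToyAttentionBackend:
    """Hypothetical stand-in for the backend object CudaGraphRunner now calls.

    The method names mirror the hooks used in the diff above; the bodies are
    simplified placeholders for illustration only.
    """

    def init_cuda_graph_state(self, max_bs: int) -> None:
        # Allocate fixed-size buffers sized for the largest captured batch.
        self.max_bs = max_bs

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        # Value used to pad seq_lens when the real batch is smaller
        # than the captured one.
        return 1

    def init_forward_metadata_capture_cuda_graph(self, bs, req_pool_indices, seq_lens):
        pass  # build decode metadata once, before graph capture

    def init_forward_metadata_replay_cuda_graph(self, bs, req_pool_indices, seq_lens):
        pass  # refresh metadata in place before each graph replay


def pick_capture_bs(capture_bs: List[int], raw_bs: int) -> int:
    """Pad a raw batch size up to the nearest captured batch size,
    using the same bisect_left lookup as CudaGraphRunner.replay."""
    index = bisect.bisect_left(capture_bs, raw_bs)
    return capture_bs[index]


if __name__ == "__main__":
    # Default (padding enabled) capture sizes from the diff: 1, 2, 4, 8, 16, ..., 160.
    capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
    for raw_bs in (3, 5, 9, 33):
        print(raw_bs, "->", pick_capture_bs(capture_bs, raw_bs))
```

Under these assumptions, a batch of 3 replays the graph captured for batch size 4, a batch of 9 uses the size-16 graph, and so on; the padded slots are filled with the backend's `seq_len_fill_value` so the captured graph stays valid.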