sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- sglang/bench_latency.py +17 -8
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +5 -17
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -4
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +33 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +38 -122
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +259 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +105 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +188 -121
- sglang/srt/model_executor/cuda_graph_runner.py +69 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +123 -154
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +669 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +46 -80
- sglang/srt/server.py +30 -15
- sglang/srt/server_args.py +163 -28
- sglang/srt/utils.py +19 -51
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -2
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
- sglang-0.3.1.post1.dist-info/RECORD +130 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,15 +15,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Run the model with cuda graph."""
+"""Run the model with cuda graph and torch.compile."""
 
 import bisect
 from contextlib import contextmanager
-from typing import
+from typing import TYPE_CHECKING, Callable
 
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
@@ -30,20 +30,20 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import
-    ForwardMode,
-    InputMetadata,
-    update_flashinfer_indices,
-)
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
+            # NOTE: FusedMoE torch native implementaiton is not efficient
+            if "FusedMoE" in sub.__class__.__name__:
+                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
@@ -58,6 +58,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 def patch_model(
     model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
+    """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
@@ -89,28 +90,41 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-
-
-
-
-        use_torch_compile: bool,
-        disable_padding: bool,
-    ):
+    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
+
+    def __init__(self, model_runner: "ModelRunner"):
+        # Parse args
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
-        self.
+        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+
+        # Batch sizes to capture
+        if self.model_runner.server_args.disable_cuda_graph_padding:
+            self.capture_bs = list(range(1, 32)) + [64, 128]
+        else:
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.max_torch_compile_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
 
         # Common inputs
-        self.max_bs =
+        self.max_bs = max(self.capture_bs)
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
         self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
@@ -118,56 +132,38 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
-        #
-        self.
-
+        # Attention backend
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-        self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,),
-            dtype=torch.int32,
-            device="cuda",
-        )
-        self.flashinfer_kv_last_page_len = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        if model_runner.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-        else:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-
-            self.flashinfer_kv_indptr = [
-                self.flashinfer_kv_indptr,
-                self.flashinfer_kv_indptr.clone(),
-            ]
-            self.flashinfer_kv_indices = [
-                self.flashinfer_kv_indices,
-                self.flashinfer_kv_indices.clone(),
-            ]
-
-        # Sampling inputs
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
-        if use_torch_compile:
+        if self.use_torch_compile:
             set_torch_compile_config()
 
+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+            )
+
     def can_run(self, batch_size: int):
         if self.disable_padding:
            return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs
 
-    def capture(self
-        self.batch_size_list = batch_size_list
+    def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in
+            for bs in self.capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -175,14 +171,10 @@ class CudaGraphRunner:
                 ) as forward:
                     (
                         graph,
-                        input_buffers,
                        output_buffers,
-                        flashinfer_handler,
                     ) = self.capture_one_batch_size(bs, forward)
                     self.graphs[bs] = graph
-                    self.input_buffers[bs] = input_buffers
                     self.output_buffers[bs] = output_buffers
-                    self.flashinfer_handlers[bs] = flashinfer_handler
 
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
@@ -195,67 +187,26 @@ class CudaGraphRunner:
         position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
 
-        #
-
-
-            // self.model_runner.tp_size,
-            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-        if self.model_runner.sliding_window_size is None:
-            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=use_tensor_cores,
-                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.flashinfer_kv_indices,
-                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-            )
-        else:
-            flashinfer_decode_wrapper = []
-            for i in range(2):
-                flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=use_tensor_cores,
-                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
-                            :bs
-                        ],
-                    )
-                )
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            flashinfer_decode_wrapper,
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
+            bs, req_pool_indices, seq_lens
         )
 
         # Run and capture
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
-                top_logprobs_nums=0,
+                top_logprobs_nums=[0] * bs,
                 positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
-                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
@@ -277,17 +228,17 @@ class CudaGraphRunner:
         self.model_runner.tp_group.barrier()
 
         self.graph_memory_pool = graph.pool()
-        return graph,
+        return graph, out
 
     def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)
 
         # Pad
-        index = bisect.bisect_left(self.
-        bs = self.
+        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
@@ -298,24 +249,14 @@ class CudaGraphRunner:
         self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc
 
-        #
-
-
-            self.model_runner,
-            self.req_pool_indices[:bs],
-            self.seq_lens[:bs],
-            None,
-            self.flashinfer_handlers[bs],
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
+            bs, self.req_pool_indices, self.seq_lens
         )
 
-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
-        torch.cuda.synchronize()
         self.graphs[bs].replay()
-
-        sample_output, logits_output = self.output_buffers[bs]
+        logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
@@ -327,11 +268,6 @@ class CudaGraphRunner:
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )
 
         # Extract logprobs
         if batch.return_logprob:
@@ -348,4 +284,4 @@ class CudaGraphRunner:
                 logits_output.next_token_logprobs, logits_metadata
             )[1]
 
-        return
+        return logits_output
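
Two notes on the hunks above, for readers comparing the versions. First, the flashinfer-specific wrapper and buffer setup is gone from this file; the cuda-graph path now goes through the attention backend via the four hooks visible in the diff. The sketch below is a hypothetical interface inferred only from those call sites; the real definitions live in the new sglang/srt/layers/attention_backend.py, whose contents are not shown here.

# Hypothetical interface sketch (not sglang code), inferred from the calls in this diff.
from typing import Protocol

import torch


class CudaGraphAttnBackend(Protocol):
    def init_cuda_graph_state(self, max_bs: int) -> None: ...

    def get_cuda_graph_seq_len_fill_value(self) -> int: ...

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ) -> None: ...

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ) -> None: ...

Second, the constructor no longer takes use_torch_compile, disable_padding, or a batch-size list from the caller; it derives the capture and compile batch sizes from server_args and captures immediately, raising an error that points at --disable-cuda-graph, a smaller --mem-fraction-static, or dropping --enable-torch-compile on failure. A minimal standalone sketch of that selection logic follows; the helper name and free-function form are illustrative, not sglang API.

# Illustrative sketch (not sglang code): mirrors the capture_bs / compile_bs
# selection that the new CudaGraphRunner.__init__ performs from server_args.
from typing import List, Tuple


def select_graph_batch_sizes(
    disable_cuda_graph_padding: bool,
    enable_torch_compile: bool,
    max_torch_compile_bs: int,
) -> Tuple[List[int], List[int]]:
    # Batch sizes captured as cuda graphs
    if disable_cuda_graph_padding:
        capture_bs = list(range(1, 32)) + [64, 128]
    else:
        capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
    # Subset of the captured sizes that is also run through torch.compile
    compile_bs = (
        [bs for bs in capture_bs if bs <= max_torch_compile_bs]
        if enable_torch_compile
        else []
    )
    return capture_bs, compile_bs


# With padding enabled and a torch.compile batch-size cap of 32, the compiled
# sizes come out to the same list that 0.3.0 hard-coded.
capture_bs, compile_bs = select_graph_batch_sizes(False, True, 32)
assert max(capture_bs) == 160
assert compile_bs == [1, 2, 4, 8, 16, 24, 32]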