sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +13 -6
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +2 -4
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +40 -35
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +110 -74
- sglang/srt/managers/tokenizer_manager.py +24 -15
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +60 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +118 -141
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +6 -8
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +8 -43
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/{llama2.py → llama.py} +48 -26
- sglang/srt/models/llama_classification.py +14 -40
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +38 -16
- sglang/srt/models/llavavid.py +7 -8
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +67 -58
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +70 -0
- sglang/test/test_utils.py +89 -1
- sglang/utils.py +38 -4
- sglang/version.py +1 -1
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.2.15.dist-info/RECORD +0 -118
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,15 +15,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Run the model with cuda graph."""
+"""Run the model with cuda graph and torch.compile."""

 import bisect
 from contextlib import contextmanager
-from typing import
+from typing import TYPE_CHECKING, Callable

 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

@@ -30,24 +30,23 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import
-    ForwardMode,
-    InputMetadata,
-    update_flashinfer_indices,
-)
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import monkey_patch_vllm_all_gather

+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+

 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
                 sub._forward_method = sub.forward_cuda
+                setattr(sub, "is_torch_compile", False)
             else:
                 sub._forward_method = sub.forward_native
+                setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)

@@ -56,6 +55,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 def patch_model(
     model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
+    """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None

     try:
@@ -87,28 +87,33 @@ def set_torch_compile_config():


 class CudaGraphRunner:
-
-
-
-
-        use_torch_compile: bool,
-        disable_padding: bool,
-    ):
+    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
+
+    def __init__(self, model_runner: "ModelRunner"):
+        # Parse args
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
-        self.
+        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+
+        # Batch sizes to capture
+        if self.model_runner.server_args.disable_cuda_graph_padding:
+            self.capture_bs = list(range(1, 32)) + [64, 128]
+        else:
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []

         # Common inputs
-        self.max_bs =
+        self.max_bs = max(self.capture_bs)
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
         self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
@@ -116,56 +121,38 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )

-        #
-        self.
-
-
-        self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,),
-            dtype=torch.int32,
-            device="cuda",
+        # Attention backend
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-        self.flashinfer_kv_last_page_len = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        if model_runner.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-        else:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-
-            self.flashinfer_kv_indptr = [
-                self.flashinfer_kv_indptr,
-                self.flashinfer_kv_indptr.clone(),
-            ]
-            self.flashinfer_kv_indices = [
-                self.flashinfer_kv_indices,
-                self.flashinfer_kv_indices.clone(),
-            ]

-
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
-
-        if use_torch_compile:
+        if self.use_torch_compile:
             set_torch_compile_config()

+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+            )
+
     def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs

-    def capture(self
-        self.batch_size_list = batch_size_list
+    def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in
+            for bs in self.capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -173,14 +160,10 @@ class CudaGraphRunner:
                 ) as forward:
                     (
                         graph,
-                        input_buffers,
                         output_buffers,
-                        flashinfer_handler,
                     ) = self.capture_one_batch_size(bs, forward)
                     self.graphs[bs] = graph
-                    self.input_buffers[bs] = input_buffers
                     self.output_buffers[bs] = output_buffers
-                    self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
@@ -193,67 +176,26 @@ class CudaGraphRunner:
         position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]

-        #
-
-
-            // self.model_runner.tp_size,
-            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-        if self.model_runner.sliding_window_size is None:
-            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=use_tensor_cores,
-                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.flashinfer_kv_indices,
-                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-            )
-        else:
-            flashinfer_decode_wrapper = []
-            for i in range(2):
-                flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=use_tensor_cores,
-                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
-                            :bs
-                        ],
-                    )
-                )
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            flashinfer_decode_wrapper,
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
+            bs, req_pool_indices, seq_lens
         )

         # Run and capture
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
-                top_logprobs_nums=0,
+                top_logprobs_nums=[0] * bs,
                 positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
-                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-
             return forward(input_ids, input_metadata.positions, input_metadata)

         for _ in range(2):
@@ -275,17 +217,17 @@ class CudaGraphRunner:
             self.model_runner.tp_group.barrier()

         self.graph_memory_pool = graph.pool()
-        return graph,
+        return graph, out

     def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)

         # Pad
-        index = bisect.bisect_left(self.
-        bs = self.
+        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()

@@ -296,24 +238,14 @@ class CudaGraphRunner:
         self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc

-        #
-
-
-            self.model_runner,
-            self.req_pool_indices[:bs],
-            self.seq_lens[:bs],
-            None,
-            self.flashinfer_handlers[bs],
+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
+            bs, self.req_pool_indices, self.seq_lens
         )

-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
-        torch.cuda.synchronize()
         self.graphs[bs].replay()
-
-        sample_output, logits_output = self.output_buffers[bs]
+        logits_output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:
@@ -325,11 +257,6 @@ class CudaGraphRunner:
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )

         # Extract logprobs
         if batch.return_logprob:
@@ -346,4 +273,4 @@ class CudaGraphRunner:
                 logits_output.next_token_logprobs, logits_metadata
             )[1]

-        return
+        return logits_output
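
The main structural change in the hunks above is that CudaGraphRunner no longer allocates flashinfer buffers and decode wrappers itself; it delegates to model_runner.attn_backend. The sketch below is not package code: it only collects the four hook names and arguments visible at the call sites in this diff into one illustrative interface, with an assumed class name and assumed comments.

# Sketch only: hook interface inferred from the CudaGraphRunner call sites above.
# "AttentionBackendHooks" and the comments are assumptions, not sglang source.
import torch


class AttentionBackendHooks:
    def init_cuda_graph_state(self, max_bs: int) -> None:
        # Allocate the fixed-size buffers shared by all captured graphs.
        raise NotImplementedError

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        # Padding value written into seq_lens when the real batch is smaller
        # than the captured graph's batch size.
        raise NotImplementedError

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ) -> None:
        # Build decode metadata once per batch size, before graph capture.
        raise NotImplementedError

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
    ) -> None:
        # Refresh the metadata in place before replaying a captured graph.
        raise NotImplementedError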