sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +60 -23
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +52 -167
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +130 -43
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +49 -11
- sglang/srt/model_executor/forward_batch_info.py +59 -27
- sglang/srt/model_executor/model_runner.py +210 -61
- sglang/srt/models/chatglm.py +4 -12
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +5 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +15 -7
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +16 -2
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +5 -1
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +13 -3
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +117 -37
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
- sglang/srt/server.py +84 -56
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +23 -31
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
@@ -25,16 +25,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather

@@ -84,13 +86,20 @@ def set_torch_compile_config():


 class CudaGraphRunner:
-    def __init__(
+    def __init__(
+        self,
+        model_runner,
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding

         # Common inputs
         self.max_bs = max_batch_size_to_capture
@@ -136,13 +145,20 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]

+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
+
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

         if use_torch_compile:
             set_torch_compile_config()

     def can_run(self, batch_size):
-
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs

     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list
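The two additions above change how the runner decides whether a decode batch can go through a captured CUDA graph: with padding allowed, any batch no larger than max_bs can be padded up to a captured size; with the new disable_padding switch, only exactly-captured sizes qualify. A minimal sketch of that selection rule, using a hypothetical pick_graph_bs helper that is not part of sglang:

    def pick_graph_bs(batch_size, captured_sizes, max_bs, disable_padding):
        # Hedged sketch of the eligibility rule implied by can_run() above.
        if disable_padding:
            # only replay graphs captured for exactly this batch size
            return batch_size if batch_size in captured_sizes else None
        if batch_size > max_bs:
            return None
        # otherwise pad the batch up to the smallest captured size that fits
        return min(s for s in captured_sizes if s >= batch_size)

    assert pick_graph_bs(3, [1, 2, 4, 8], 8, disable_padding=False) == 4
    assert pick_graph_bs(3, [1, 2, 4, 8], 8, disable_padding=True) is None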
@@ -224,6 +240,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -239,12 +256,23 @@
             return forward(input_ids, input_metadata.positions, input_metadata)

         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()

+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper

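The synchronize/barrier pairs added around warmup and capture keep all tensor-parallel ranks in lockstep, so no rank starts recording its graph while another is still finishing eager warmup work. A minimal single-GPU sketch of the same warmup-then-capture pattern (torch API only, no TP barrier), assuming a callable run_once that performs one decode step:

    import torch

    def capture_decode_graph(run_once, warmup_iters=2):
        graph = torch.cuda.CUDAGraph()
        for _ in range(warmup_iters):
            run_once()                   # warm up allocator and lazy initializations
        torch.cuda.synchronize()         # nothing may still be in flight before capture
        with torch.cuda.graph(graph):
            out = run_once()             # recorded into the graph, not executed eagerly
        torch.cuda.synchronize()
        return graph, out                # later: graph.replay() reruns the recorded work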
@@ -277,25 +305,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )

+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
-
+        torch.cuda.synchronize()
+        sample_output, logits_output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:
-
-            next_token_logits=
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )

         # Extract logprobs
         if batch.return_logprob:
-
-
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -303,8 +341,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-
-
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]

-        return
+        return sample_output, logits_output
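With sampling fused into the captured graph, replay now hands back both the sampled token ids and the logits bundle, and both are sliced from the padded graph batch size back to the real one. A small sketch of that un-padding step, mirroring the SampleOutput fields used in the hunks above (the NamedTuple here is a stand-in, not the actual class definition):

    from typing import NamedTuple
    import torch

    class SampleOutput(NamedTuple):      # stand-in mirroring the fields used above
        success: torch.Tensor
        probs: torch.Tensor
        batch_next_token_ids: torch.Tensor

    def unpad_sample_output(out: SampleOutput, raw_bs: int) -> SampleOutput:
        # outputs come back at the padded batch size; keep only the real requests
        return SampleOutput(
            out.success[:raw_bs],
            out.probs[:raw_bs],
            out.batch_next_token_ids[:raw_bs],
        )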
sglang/srt/model_executor/forward_batch_info.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +18,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List

 import numpy as np
 import torch
@@ -26,6 +28,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):
@@ -42,6 +45,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""

     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -61,9 +65,11 @@
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None

-    #
+    # For logprob
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None

     # For multimodal
     pixel_values: List[torch.Tensor] = None
@@ -86,14 +92,19 @@ class InputMetadata:
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-
-
-
-
-
-
-
+        self.image_offsets = []
+        for r in reqs:
+            if isinstance(r.image_offset, list):
+                self.image_offsets.append(
+                    [
+                        (image_offset - len(r.prefix_indices))
+                        for image_offset in r.image_offset
+                    ]
+                )
+            elif isinstance(r.image_offset, int):
+                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
+            elif r.image_offset is None:
+                self.image_offsets.append(0)

     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
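The rewritten image-offset handling above accepts offsets stored as a single int, a list of ints, or None, and shifts each by the length of the cached prefix so the offsets stay valid inside the extend region. The same logic as a standalone sketch (hypothetical helper name, not sglang code):

    def shift_image_offset(image_offset, prefix_len):
        # offsets were recorded against the full prompt; rebase them to the
        # uncached (extend) part by subtracting the prefix length
        if isinstance(image_offset, list):
            return [off - prefix_len for off in image_offset]
        if isinstance(image_offset, int):
            return image_offset - prefix_len
        return 0  # request has no image

    assert shift_image_offset([10, 30], 8) == [2, 22]
    assert shift_image_offset(10, 8) == 2
    assert shift_image_offset(None, 8) == 0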
@@ -109,8 +120,8 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(
-                        for req in batch.reqs
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
                     ],
                     axis=0,
                 ),
@@ -123,7 +134,7 @@ class InputMetadata:
                 np.concatenate(
                     [
                         np.arange(
-
+                            batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                             len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)
@@ -139,14 +150,29 @@ class InputMetadata:
     def compute_extend_infos(self, batch: ScheduleBatch):
         if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) -
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
+
+            self.extend_seq_lens_cpu = extend_lens_cpu
+            self.logprob_start_lens_cpu = [
+                (
+                    min(
+                        req.logprob_start_len - batch.prefix_lens_cpu[i],
+                        extend_lens_cpu[i] - 1,
+                    )
+                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
+                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
+                )
+                for i, req in enumerate(batch.reqs)
+            ]

     @classmethod
     def from_schedule_batch(
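compute_extend_infos now also keeps CPU-side copies of the extend lengths and derives, per request, where logprob extraction starts inside the extend region: the requested start is rebased by the prefix length and clamped to the last extend position, and when it falls inside the cached prefix (a "fake extend" that is effectively a decode step) it falls back to that last position. The same rule as a standalone sketch (hypothetical helper name):

    def logprob_start_in_extend(logprob_start_len, prefix_len, extend_len):
        # rebase the absolute start position into the extend region, clamped so it
        # never points past the last extend token
        if logprob_start_len >= prefix_len:
            return min(logprob_start_len - prefix_len, extend_len - 1)
        # start falls inside the cached prefix: fake extend, actually decode
        return extend_len - 1

    assert logprob_start_in_extend(12, 8, 6) == 4
    assert logprob_start_in_extend(30, 8, 6) == 5
    assert logprob_start_in_extend(3, 8, 6) == 5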
@@ -157,6 +183,7 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -167,6 +194,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )

+        ret.sampling_info.prepare_penalties()
+
         ret.compute_positions(batch)

         ret.compute_extend_infos(batch)
@@ -180,14 +209,8 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)

-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch
+            ret.init_triton_args(batch)

         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
@@ -198,30 +221,35 @@ class InputMetadata:
         ):
             flashinfer_use_ragged = True
             ret.init_flashinfer_handlers(
-                model_runner,
+                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
             )

         return ret

-    def init_triton_args(self, batch: ScheduleBatch
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)

         if self.forward_mode == ForwardMode.DECODE:
             self.triton_max_extend_len = None
         else:
-
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))

     def init_flashinfer_handlers(
         self,
         model_runner,
-
+        prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
@@ -294,6 +322,8 @@ def update_flashinfer_indices(
             num_kv_heads,
             head_dim,
             1,
+            data_type=model_runner.kv_cache_dtype,
+            q_data_type=model_runner.dtype,
         )
     else:
         # extend part
@@ -372,6 +402,8 @@ def update_flashinfer_indices(
             num_kv_heads,
             head_dim,
             1,
+            data_type=model_runner.kv_cache_dtype,
+            q_data_type=model_runner.dtype,
         )
     else:
         # extend part