sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +100 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/logits_processor.py +56 -19
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +101 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +46 -166
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +118 -24
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +32 -8
- sglang/srt/model_executor/forward_batch_info.py +51 -26
- sglang/srt/model_executor/model_runner.py +201 -58
- sglang/srt/models/gemma2.py +10 -6
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +11 -1
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/qwen2.py +9 -3
- sglang/srt/openai_api/adapter.py +200 -39
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +136 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
- sglang/srt/server.py +92 -57
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +22 -30
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
- sglang-0.2.14.post1.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -17,6 +17,7 @@ limitations under the License.
 
 import bisect
 from contextlib import contextmanager
+from typing import Callable, List
 
 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -51,12 +52,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module,
+    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
     backup_ca_comm = None
 
     try:
-        if
+        if enable_compile:
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
@@ -65,7 +66,7 @@ def patch_model(
     else:
         yield model.forward
     finally:
-        if
+        if enable_compile:
             _to_torch(model, reverse=True)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
@@ -84,13 +85,20 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-    def __init__(
+    def __init__(
+        self,
+        model_runner: "ModelRunner",
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding
 
         # Common inputs
         self.max_bs = max_batch_size_to_capture
@@ -141,10 +149,13 @@ class CudaGraphRunner:
         if use_torch_compile:
             set_torch_compile_config()
 
-    def can_run(self, batch_size):
-
+    def can_run(self, batch_size: int):
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs
 
-    def capture(self, batch_size_list):
+    def capture(self, batch_size_list: List[int]):
         self.batch_size_list = batch_size_list
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
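can_run now has two modes: with padding disabled (disable_padding=True), only batch sizes that were captured exactly can replay a graph; otherwise any batch up to the largest captured size is accepted and padded. A rough sketch of that lookup as a standalone helper (hypothetical, not the exact sglang code), using the bisect module the file imports:

import bisect
from typing import List, Optional


def pick_graph_batch_size(
    captured_sizes: List[int], batch_size: int, disable_padding: bool
) -> Optional[int]:
    # captured_sizes must be sorted, e.g. [1, 2, 4, 8, 16].
    if disable_padding:
        # Only exact matches may replay a captured graph.
        return batch_size if batch_size in captured_sizes else None
    if batch_size > captured_sizes[-1]:
        return None
    # Pad the batch up to the smallest captured size that fits.
    return captured_sizes[bisect.bisect_left(captured_sizes, batch_size)]


assert pick_graph_batch_size([1, 2, 4, 8], 3, disable_padding=False) == 4
assert pick_graph_batch_size([1, 2, 4, 8], 3, disable_padding=True) is None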
@@ -165,7 +176,7 @@ class CudaGraphRunner:
                 self.output_buffers[bs] = output_buffers
                 self.flashinfer_handlers[bs] = flashinfer_handler
 
-    def capture_one_batch_size(self, bs, forward):
+    def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
 
@@ -239,12 +250,23 @@ class CudaGraphRunner:
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()
 
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper
 
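The added torch.cuda.synchronize() / tp_group.barrier() pairs keep all tensor-parallel ranks in lock-step around warm-up and capture, so no rank starts recording while another is still issuing warm-up kernels. A single-GPU sketch of the same warm-up, synchronize, capture, replay sequence with a toy computation in place of the model forward (no TP barrier needed here):

import torch


def demo_cuda_graph_capture():
    if not torch.cuda.is_available():
        return None
    static_input = torch.randn(8, 64, device="cuda")
    weight = torch.randn(64, 64, device="cuda")

    def run_once():
        return static_input @ weight

    # Warm up outside the graph so lazy initialization is not recorded.
    for _ in range(2):
        torch.cuda.synchronize()
        run_once()
        torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = run_once()
    torch.cuda.synchronize()

    # Replay: refresh the static input buffer in place, then re-run the recorded kernels.
    static_input.copy_(torch.randn(8, 64, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    return static_output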
@@ -278,7 +300,9 @@ class CudaGraphRunner:
         )
 
         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
+        torch.cuda.synchronize()
         output = self.output_buffers[bs]
 
         # Unpad
sglang/srt/model_executor/forward_batch_info.py

@@ -61,9 +61,11 @@ class InputMetadata:
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
 
-    #
+    # For logprob
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
@@ -86,14 +88,19 @@ class InputMetadata:
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = [
-
-
-
-
-
-
-
+        self.image_offsets = []
+        for r in reqs:
+            if isinstance(r.image_offset, list):
+                self.image_offsets.append(
+                    [
+                        (image_offset - len(r.prefix_indices))
+                        for image_offset in r.image_offset
+                    ]
+                )
+            elif isinstance(r.image_offset, int):
+                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
+            elif r.image_offset is None:
+                self.image_offsets.append(0)
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
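With prefix caching, the image token offsets stored on each request are absolute positions in the full prompt, so the extend pass now shifts them by the length of the cached prefix and also accepts a list of offsets for multi-image requests. The same normalization as a standalone sketch (hypothetical inputs, not sglang's Req objects):

from typing import List, Optional, Union


def normalize_image_offsets(
    image_offset: Union[List[int], int, None], prefix_len: int
) -> Union[List[int], int]:
    # Shift absolute image offsets into the coordinate system of the
    # uncached (extend) part of the prompt.
    if isinstance(image_offset, list):
        return [off - prefix_len for off in image_offset]
    if isinstance(image_offset, int):
        return image_offset - prefix_len
    return 0  # no image in this request


assert normalize_image_offsets([10, 50], prefix_len=8) == [2, 42]
assert normalize_image_offsets(10, prefix_len=8) == 2
assert normalize_image_offsets(None, prefix_len=8) == 0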
@@ -109,8 +116,8 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(
-                        for req in batch.reqs
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
                     ],
                     axis=0,
                 ),
@@ -123,7 +130,7 @@ class InputMetadata:
                 np.concatenate(
                     [
                         np.arange(
-
+                            batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                             len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)
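compute_positions now derives the extend positions from batch.prefix_lens_cpu: each request contributes the positions from its cached prefix length up to its total fill length, and the per-request ranges are concatenated into one flat array. A small numpy sketch of that construction:

import numpy as np


def build_extend_positions(prefix_lens: list, fill_lens: list) -> np.ndarray:
    # One contiguous position array for a batch of extend requests:
    # request i contributes positions [prefix_lens[i], fill_lens[i]).
    return np.concatenate(
        [np.arange(prefix_lens[i], fill_lens[i]) for i in range(len(fill_lens))],
        axis=0,
    )


# Two requests: one with a 3-token cached prefix and 5 tokens total,
# one with no cached prefix and 2 tokens total.
print(build_extend_positions([3, 0], [5, 2]))  # [3 4 0 1]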
@@ -139,14 +146,29 @@ class InputMetadata:
     def compute_extend_infos(self, batch: ScheduleBatch):
         if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) -
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
+
+            self.extend_seq_lens_cpu = extend_lens_cpu
+            self.logprob_start_lens_cpu = [
+                (
+                    min(
+                        req.logprob_start_len - batch.prefix_lens_cpu[i],
+                        extend_lens_cpu[i] - 1,
+                    )
+                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
+                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
+                )
+                for i, req in enumerate(batch.reqs)
+            ]
 
     @classmethod
     def from_schedule_batch(
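The new logprob_start_lens_cpu clips each request's logprob_start_len into the extend region: if the requested start falls inside the cached prefix (a "fake extend" that is really a decode step), it falls back to the last extend position. A plain-Python sketch of that clipping (hypothetical request objects, not sglang's Req class):

from dataclasses import dataclass
from typing import List


@dataclass
class FakeReq:  # hypothetical stand-in for sglang's Req
    fill_len: int           # total tokens to fill for this request
    logprob_start_len: int  # absolute position where logprobs should start


def compute_logprob_start_lens(reqs: List[FakeReq], prefix_lens: List[int]) -> List[int]:
    starts = []
    for i, req in enumerate(reqs):
        extend_len = req.fill_len - prefix_lens[i]
        if req.logprob_start_len >= prefix_lens[i]:
            # Start inside the extend region, clipped to the last extend position.
            starts.append(min(req.logprob_start_len - prefix_lens[i], extend_len - 1))
        else:
            # Requested start is already cached: only the last token matters.
            starts.append(extend_len - 1)
    return starts


print(compute_logprob_start_lens([FakeReq(10, 6), FakeReq(4, 0)], [4, 2]))  # [2, 1]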
@@ -180,14 +202,8 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)
 
-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch
+            ret.init_triton_args(batch)
 
         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
@@ -198,30 +214,35 @@ class InputMetadata:
             ):
                 flashinfer_use_ragged = True
             ret.init_flashinfer_handlers(
-                model_runner,
+                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
             )
 
         return ret
 
-    def init_triton_args(self, batch: ScheduleBatch
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
 
         if self.forward_mode == ForwardMode.DECODE:
             self.triton_max_extend_len = None
         else:
-
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))
 
     def init_flashinfer_handlers(
         self,
         model_runner,
-
+        prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
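init_triton_args still builds triton_start_loc as an exclusive prefix sum of the sequence lengths, which gives each request's starting offset in the flattened token buffer; the prefix lengths are now simply materialized on the GPU lazily. The prefix-sum idiom on its own:

import torch


def exclusive_cumsum_start_loc(seq_lens: torch.Tensor) -> torch.Tensor:
    # start_loc[i] = sum(seq_lens[:i]); request i's tokens live at
    # [start_loc[i], start_loc[i] + seq_lens[i]) in the flattened buffer.
    start_loc = torch.zeros_like(seq_lens)
    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
    return start_loc


print(exclusive_cumsum_start_loc(torch.tensor([3, 5, 2])))  # tensor([0, 3, 8])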
@@ -294,6 +315,8 @@ def update_flashinfer_indices(
                 num_kv_heads,
                 head_dim,
                 1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
             )
         else:
             # extend part
@@ -372,6 +395,8 @@ def update_flashinfer_indices(
                 num_kv_heads,
                 head_dim,
                 1,
+                data_type=model_runner.kv_cache_dtype,
+                q_data_type=model_runner.dtype,
             )
         else:
             # extend part
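Both call sites now pass the KV-cache dtype and the query dtype separately, which matters once the KV cache is stored in a lower-precision format (for example FP8) than the model's compute dtype. A hedged sketch of mapping a server option to the two dtypes (hypothetical option values, not sglang's exact parsing; requires a PyTorch build with float8 support):

import torch


def resolve_kv_cache_dtype(kv_cache_dtype_flag: str, model_dtype: torch.dtype) -> torch.dtype:
    # "auto" keeps the KV cache in the model's compute dtype;
    # an explicit fp8 flag stores keys/values in 8-bit floating point.
    if kv_cache_dtype_flag == "auto":
        return model_dtype
    if kv_cache_dtype_flag in ("fp8", "fp8_e5m2"):
        return torch.float8_e5m2
    raise ValueError(f"unsupported kv cache dtype: {kv_cache_dtype_flag}")


model_dtype = torch.float16
kv_dtype = resolve_kv_cache_dtype("fp8_e5m2", model_dtype)
# The attention kernel is then told both: queries in model_dtype, cache in kv_dtype.
print(model_dtype, kv_dtype)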