sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +4 -2
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -3
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +10 -6
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +6 -2
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +10 -4
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +9 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -0
- sglang/srt/server.py +11 -8
- sglang/srt/server_args.py +12 -1
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +47 -33
- sglang/srt/utils.py +32 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +6 -7
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +48 -43
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py
CHANGED
```diff
@@ -9,13 +9,12 @@ import triton.language as tl
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
 from sglang.srt.speculative.spec_info import SpecInfo
 
 if TYPE_CHECKING:
-    from
-    from python.sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.server_args import ServerArgs
 
 
```
```diff
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
 
 
 class EAGLEDraftInput(SpecInfo):
-
-    verified_id: torch.Tensor = None
-    positions: torch.Tensor = None
-    accept_length: torch.Tensor = None
-    has_finished: bool = False
-    unfinished_index: List[int] = None
-
-    def init(self, server_args: ServerArgs):
+    def __init__(self):
         self.prev_mode = ForwardMode.DECODE
         self.sample_output = None
-        self.topk: int = server_args.speculative_eagle_topk
-        self.num_verify_token: int = server_args.speculative_num_draft_tokens
-        self.spec_steps = server_args.speculative_num_steps
 
         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
```
```diff
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
         self.parents_list: List[torch.Tensor] = []
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0
-        self.root_token: int = None
 
-
+        self.hidden_states: torch.Tensor = None
+        self.verified_id: torch.Tensor = None
+        self.positions: torch.Tensor = None
+        self.accept_length: torch.Tensor = None
+        self.has_finished: bool = False
+        self.unfinished_index: List[int] = None
+
+    def load_server_args(self, server_args: ServerArgs):
+        self.topk: int = server_args.speculative_eagle_topk
+        self.num_verify_token: int = server_args.speculative_num_draft_tokens
+        self.spec_steps = server_args.speculative_num_steps
 
-    def prepare_for_extend(self, batch:
+    def prepare_for_extend(self, batch: ScheduleBatch):
         req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
         out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
         batch.out_cache_loc = out_cache_loc
```
```diff
@@ -226,81 +224,73 @@ class EAGLEDraftInput(SpecInfo):
 
             pt += req.extend_input_len
 
-
-
-
-        model_input_ids = []
-        for i in range(len(seq_lens) - 1):
-            model_input_ids.extend(
-                input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
-            )
-        batch.input_ids = torch.tensor(
-            model_input_ids, dtype=torch.int32, device="cuda"
-        )
-
-    def capture_for_decode(
-        self,
-        sample_output: SampleOutput,
-        hidden_states: torch.Tensor,
-        prev_mode: ForwardMode,
-    ):
-        self.sample_output = sample_output
-        self.prev_mode = prev_mode
-        self.hidden_states = hidden_states
+        # TODO: support batching inputs
+        assert len(batch.extend_lens) == 1
+        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
 
     def prepare_for_decode(self, batch: ScheduleBatch):
-        prob = self.sample_output  # b * (
+        prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
         top = torch.topk(prob, self.topk, dim=-1)
-        topk_index, topk_p =
-
+        topk_index, topk_p = (
+            top.indices,
+            top.values,
+        )  # shape: (b * top_k, top_k) or (b, top_k)
+
+        if self.prev_mode.is_decode():
             scores = torch.mul(
                 self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
-            )  # (b, topk)
+            )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
             topk_cs = torch.topk(
                 scores.flatten(start_dim=1), self.topk, dim=-1
             )  # (b, topk)
             topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
-            self.scores = topk_cs_p
 
-            selected_input_index = topk_cs_index.flatten() // self.topk
+            selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
+                0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
+            ).repeat_interleave(self.topk)
 
             batch.spec_info.hidden_states = batch.spec_info.hidden_states[
                 selected_input_index, :
             ]
+
             topk_index = topk_index.reshape(-1, self.topk**2)
             batch.input_ids = torch.gather(
                 topk_index, index=topk_cs_index, dim=1
             ).flatten()
-            batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids
-
-            self.
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+
+            self.scores = topk_cs_p
+            self.score_list.append(scores)  # (b, topk, topk)
+            self.token_list.append(topk_index)  # (b, topk * topk)
             self.origin_score_list.append(topk_p.reshape(topk_index.shape))
             self.parents_list.append(
                 topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
-            )  # b, topk
-
-
-            self.scores = topk_p  # b, top_k
-            self.score_list.append(topk_p.unsqueeze(1))
-            self.token_list.append(topk_index)
-            self.origin_score_list.append(topk_p)
+            )  # shape: (b, topk)
+        else:
+            # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
             batch.spec_info.hidden_states = (
-                batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
+                batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
             )
+
             batch.input_ids = topk_index.flatten()
             batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
+
+            self.scores = topk_p  # shape: (b, topk)
+            self.score_list.append(topk_p.unsqueeze(1))  # shape: (b, 1, topk)
+            self.token_list.append(topk_index)  # shape: (b, topk)
+            self.origin_score_list.append(topk_p)
             self.parents_list.append(
                 torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
                 .unsqueeze(0)
                 .repeat(self.scores.shape[0], 1)
-            )  # b, topk+1
+            )  # shape: (b, topk + 1)
         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
             + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
         ).flatten()
 
-        bs = batch.seq_lens
+        bs = len(batch.seq_lens)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
```
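The decode branch above now adds each request's block offset to the flattened cumulative top-k indices, so the gather into the `(batch * topk, hidden)` hidden-state tensor picks rows from the correct request. A minimal standalone sketch of that indexing arithmetic (illustrative shapes and names, not sglang code):

```python
# Illustrative sketch of the index-offset fix: per-request top-k indices must be
# shifted by that request's block start before indexing a (b * topk, ...) tensor.
import torch

b, topk, hidden = 2, 4, 8
scores = torch.rand(b, topk, topk)  # cumulative draft scores per request
topk_cs_index = torch.topk(scores.flatten(start_dim=1), topk, dim=-1).indices  # (b, topk)

selected = topk_cs_index.flatten() // topk + torch.arange(
    0, b * topk, step=topk
).repeat_interleave(topk)  # offsets look like [0, 0, 0, 0, topk, topk, topk, topk]

hidden_states = torch.randn(b * topk, hidden)
print(hidden_states[selected].shape)  # torch.Size([8, 8]) == (b * topk, hidden)
```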
```diff
@@ -347,6 +337,7 @@ class EAGLEDraftInput(SpecInfo):
             triton.next_power_of_2(self.spec_steps + 1),
         )
 
+        batch.seq_lens_sum = sum(batch.seq_lens)
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id
 
```
```diff
@@ -419,11 +410,6 @@ class EAGLEDraftInput(SpecInfo):
         )
         return bs, kv_indices, cum_kv_seq_len
 
-    def clear(self):
-        self.iter = 0
-        self.score_list.clear()
-        self.positions = None
-
     def clear_draft_cache(self, batch):
         draft_cache = torch.cat(self.cache_list, dim=0)
         batch.token_to_kv_pool.free(draft_cache)
```
```diff
@@ -455,12 +441,18 @@ class EAGLEDraftInput(SpecInfo):
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
     def merge_batch(self, spec_info: EAGLEDraftInput):
-
+        if self.hidden_states is None:
+            self.hidden_states = spec_info.hidden_states
+            self.verified_id = spec_info.verified_id
+            self.sample_output = spec_info.sample_output
+            self.prev_mode = spec_info.prev_mode
+            return
+        if spec_info.hidden_states is None:
+            return
         self.hidden_states = torch.cat(
             [self.hidden_states, spec_info.hidden_states], axis=0
         )
         self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
-        # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
         self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
 
 
```
```diff
@@ -567,11 +559,37 @@ class EagleVerifyInput(SpecInfo):
             triton.next_power_of_2(max_draft_len),
         )
 
-        accept_index = accept_index[accept_index != -1]
-        # extract_index = extract_index[extract_index != 0]
-
         draft_input = EAGLEDraftInput()
+        new_accept_index = []
+        unfinished_index = []
+        finished_extend_len = {}  # {rid:accept_length + 1}
+        accept_index_cpu = accept_index.tolist()
+        predict_cpu = predict.tolist()
+        # iterate every accepted token and check if req has finished after append the token
+        # should be checked BEFORE free kv cache slots
+        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+            new_accept_index_ = []
+            for j, idx in enumerate(accept_index_row):
+                if idx == -1:
+                    break
+                id = predict_cpu[idx]
+                # if not found_finished:
+                req.output_ids.append(id)
+                finished_extend_len[req.rid] = j + 1
+                req.check_finished()
+                if req.finished():
+                    draft_input.has_finished = True
+                    # set all tokens after finished token to -1 and break
+                    accept_index[i, j + 1 :] = -1
+                    break
+                else:
+                    new_accept_index_.append(idx)
+            if not req.finished():
+                new_accept_index.extend(new_accept_index_)
+                unfinished_index.append(i)
+        accept_length = (accept_index != -1).sum(dim=1) - 1
 
+        accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
         verified_id_cpu = verified_id.tolist()
```
```diff
@@ -590,29 +608,19 @@ class EagleVerifyInput(SpecInfo):
             triton.next_power_of_2(bs),
         )
         batch.seq_lens.add_(accept_length + 1)
-        new_accept_index = []
-        unfinished_index = []
-        finished_extend_len = {}  # {rid:accept_length + 1}
-        # retracted_reqs, new_token_ratio = batch.retract_decode()
-
-        low = 0
-        for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
-            req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
-            req.check_finished()
-            if req.finished():
-                draft_input.has_finished = True
-            else:
-                new_accept_index.append(accept_index[low : low + verified_len + 1])
-                unfinished_index.append(i)
-            low += verified_len + 1
-            finished_extend_len[req.rid] = verified_len + 1
 
         if len(new_accept_index) > 0:
-            new_accept_index = torch.
+            new_accept_index = torch.tensor(new_accept_index, device="cuda")
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
             draft_input.unfinished_index = unfinished_index
 
         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return
+        return (
+            draft_input,
+            logits_output,
+            verified_id,
+            finished_extend_len,
+            accept_length_cpu,
+        )
```
sglang/srt/speculative/eagle_worker.py
CHANGED
```diff
@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_for_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
 
     def forward_draft_extend(self, batch: ScheduleBatch):
-        self.
+        self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
-        self.
+        self._set_mem_pool(batch, self.target_worker.model_runner)
 
     def forward_batch_speculative_generation(self, batch: ScheduleBatch):
         if batch.forward_mode.is_decode():
-
-            self.
+            # Draft
+            self._set_mem_pool(batch, self.model_runner)
             for i in range(self.server_args.speculative_num_steps):
                 self.forward_draft_decode(batch)
             batch.spec_info.clear_draft_cache(batch)
-            self.
+            self._set_mem_pool(batch, self.target_worker.model_runner)
+
+            # Verify
             (
                 next_draft_input,
                 logits_output,
                 verified_id,
                 self.finish_extend_len,
+                accept_length_cpu,
                 model_worker_batch,
             ) = self.verify(batch)
-            next_draft_input.
+            next_draft_input.load_server_args(self.server_args)
             batch.spec_info = next_draft_input
             # if it is None, means all requsets are finished
             if batch.spec_info.verified_id is not None:
-                self.
-
-
+                self.forward_draft_extend_after_decode(batch)
+            return (
+                logits_output,
+                verified_id,
+                model_worker_batch,
+                sum(accept_length_cpu),
+            )
 
         else:
-
-
+            # Forward with the target model and get hidden states.
+            # We need the full hidden states to prefill the KV cache of the draft model.
             model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.
-            spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
             logits_output, next_token_ids = self.target_worker.forward_batch_generation(
                 model_worker_batch
             )
-
-
+
+            # Forward with the draft model.
+            spec_info = EAGLEDraftInput()
+            spec_info.load_server_args(self.server_args)
+            spec_info.hidden_states = logits_output.hidden_states
+            spec_info.verified_id = next_token_ids
             batch.spec_info = spec_info
             self.forward_draft_extend(batch)
-
-            return logits_output, next_token_ids, model_worker_batch, spec_info
+            return logits_output, next_token_ids, model_worker_batch, 0
 
     def verify(self, batch: ScheduleBatch):
         verify_input = batch.spec_info.prepare_for_verify(batch)
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
         verify_input.prepare_for_verify(batch)
+        batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = verify_input
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
         model_worker_batch = batch.get_model_worker_batch()
```
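With this refactor, the decode path returns `(logits_output, verified_id, model_worker_batch, sum(accept_length_cpu))`, while the prefill path returns the same tuple shape with `0` accepted tokens. A hypothetical caller-side sketch (not the scheduler's actual code) of consuming that count:

```python
# Hypothetical consumer of the worker's new return value; `worker` and `batch`
# stand in for an EAGLEWorker and a ScheduleBatch and are assumptions, not sglang API.
def run_speculative_step(worker, batch):
    logits_output, next_token_ids, model_worker_batch, num_accepted = (
        worker.forward_batch_speculative_generation(batch)
    )
    # num_accepted is 0 for prefill batches, the sum of per-request accept lengths otherwise.
    return next_token_ids, {"accepted_draft_tokens": num_accepted}
```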
```diff
@@ -119,44 +128,49 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.DECODE
         return res + (model_worker_batch,)
 
-    def
+    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
         batch.token_to_kv_pool = runner.token_to_kv_pool
         batch.req_to_token_pool = runner.req_to_token_pool
 
-    def
-        self.
+    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         if batch.spec_info.has_finished:
             index = batch.spec_info.unfinished_index
             seq_lens = batch.seq_lens
             batch.seq_lens = batch.seq_lens[index]
+
         batch.spec_info.prepare_extend_after_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
+
         batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
         batch.forward_mode = ForwardMode.DECODE
         if batch.spec_info.has_finished:
             batch.seq_lens = seq_lens
-        self.
+        self._set_mem_pool(batch, self.target_worker.model_runner)
 
-    def capture_for_decode(
-
-
+    def capture_for_decode(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ):
         sample_output = torch.softmax(
-
-        )  # TODO: Support more sampling
-        forward_batch.spec_info
-
-
+            logits_output.next_token_logits, dim=-1
+        )  # TODO(kavioyu): Support more sampling methods
+        spec_info = forward_batch.spec_info
+        spec_info.sample_output = sample_output
+        spec_info.hidden_states = logits_output.hidden_states
+        spec_info.prev_mode = forward_batch.forward_mode
 
     # Don't support prefix share now.
     def finish_request(self, reqs: Union[Req, List[Req]]):
         if not isinstance(reqs, List):
             reqs = [reqs]
         for req in reqs:
+            if req.rid not in self.finish_extend_len:
+                continue
             req_len = (
                 len(req.origin_input_ids)
                 + len(req.output_ids)
```
sglang/srt/utils.py
CHANGED
```diff
@@ -335,6 +335,8 @@ def is_port_available(port):
             return True
     except socket.error:
         return False
+    except OverflowError:
+        return False
 
 
 def decode_video_base64(video_base64):
```
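The added `except OverflowError` means a port outside the 0-65535 range (which makes the underlying `socket.bind` raise `OverflowError`) is now reported as unavailable instead of propagating the exception. A quick illustrative check:

```python
# Illustrative check of the new behavior; results depend on the local machine.
from sglang.srt.utils import is_port_available

print(is_port_available(30000))  # True if nothing is bound to this port
print(is_port_available(99999))  # False (previously this raised OverflowError)
```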
```diff
@@ -709,13 +711,14 @@ def broadcast_pyobj(
     data: List[Any],
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
+    src: int = 0,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
     if rank == 0:
         if len(data) == 0:
             tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=
+            dist.broadcast(tensor_size, src=src, group=dist_group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
```
```diff
@@ -724,19 +727,19 @@ def broadcast_pyobj(
             )
             tensor_size = torch.tensor([size], dtype=torch.long)
 
-            dist.broadcast(tensor_size, src=
-            dist.broadcast(tensor_data, src=
+            dist.broadcast(tensor_size, src=src, group=dist_group)
+            dist.broadcast(tensor_data, src=src, group=dist_group)
         return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=
+        dist.broadcast(tensor_size, src=src, group=dist_group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
         tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=
+        dist.broadcast(tensor_data, src=src, group=dist_group)
 
         serialized_data = bytes(tensor_data.cpu().numpy())
         data = pickle.loads(serialized_data)
```
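`broadcast_pyobj` now threads an explicit `src` rank into every `dist.broadcast` call instead of a hard-coded source. A minimal usage sketch, assuming `torch.distributed` is already initialized and keeping the default `src=0`:

```python
# Minimal sketch, assuming dist.init_process_group() has already been called and
# that every rank in the group invokes this collectively.
import torch.distributed as dist
from sglang.srt.utils import broadcast_pyobj

rank = dist.get_rank()
objs = [{"rid": "req-0", "prompt_len": 128}] if rank == 0 else []
objs = broadcast_pyobj(objs, rank=rank, dist_group=None, src=0)
```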
```diff
@@ -1348,3 +1351,27 @@ class MultiprocessingSerializer:
     @staticmethod
     def deserialize(data):
         return ForkingPickler.loads(data)
+
+
+def debug_timing(func):
+    # todo: replace with a more organized instrumentation
+    def wrapper(*args, **kwargs):
+        if logger.isEnabledFor(logging.DEBUG):
+            tic = torch.cuda.Event(enable_timing=True)
+            toc = torch.cuda.Event(enable_timing=True)
+            tic.record()
+            result = func(*args, **kwargs)
+            toc.record()
+            torch.cuda.synchronize()  # Ensure all CUDA operations are complete
+            elapsed = tic.elapsed_time(toc)
+            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
+            num_tokens = len(indices) if indices is not None else 0
+            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
+            logger.debug(
+                f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
+            )
+            return result
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
```
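The new `debug_timing` decorator wraps a call in CUDA events and, only when DEBUG logging is enabled, logs the elapsed time plus a tokens-per-second figure derived from an `indices` argument (the second positional argument or an `indices` keyword). A sketch of decorating a hypothetical row-copy helper (requires a CUDA device):

```python
# Sketch with a hypothetical helper; the log line only appears when the sglang
# logger is effectively at DEBUG level, and a CUDA device is required for the events.
import logging
import torch
from sglang.srt.utils import debug_timing

logging.basicConfig(level=logging.DEBUG)

@debug_timing
def copy_rows(dst, indices, src):
    dst[indices] = src  # the decorator reads len(indices) as the token count

dst = torch.zeros(1024, 8, device="cuda")
src = torch.ones(16, 8, device="cuda")
copy_rows(dst, torch.arange(16, device="cuda"), src)
```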
sglang/test/test_programs.py
CHANGED
```diff
@@ -509,13 +509,35 @@ def test_hellaswag_select():
         temperature=0,
         num_threads=64,
         progress_bar=True,
+        generator_style=False,
     )
-    preds = [
+    preds = []
+    for i, ret in enumerate(rets):
+        preds.append(choices[i].index(ret["answer"]))
     latency = time.time() - tic
 
     # Compute accuracy
     accuracy = np.mean(np.array(preds) == np.array(labels))
 
+    # Test generator style of run_batch
+    tic = time.time()
+    rets = few_shot_hellaswag.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=64,
+        progress_bar=True,
+        generator_style=True,
+    )
+    preds_gen = []
+    for i, ret in enumerate(rets):
+        preds_gen.append(choices[i].index(ret["answer"]))
+    latency_gen = time.time() - tic
+
+    # Compute accuracy
+    accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
+    assert np.abs(accuracy_gen - accuracy) < 0.01
+    assert np.abs(latency_gen - latency) < 1
+
     return accuracy, latency
 
 
```
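The test now runs HellaSwag selection twice, with `generator_style=False` and `generator_style=True`, and asserts that accuracy and latency stay essentially the same. A hedged sketch of the generator-style API from the frontend side (assumes a default backend has been configured, e.g. with `sgl.set_default_backend`):

```python
# Hedged sketch of generator-style batching; requires a configured sglang backend.
import sglang as sgl

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\nA:"
    s += sgl.gen("answer", max_tokens=16)

arguments = [{"question": q} for q in ["What is 1+1?", "Capital of France?"]]
for state in qa.run_batch(arguments, generator_style=True):  # yields states lazily
    print(state["answer"])
```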
sglang/test/test_utils.py
CHANGED
```diff
@@ -36,7 +36,7 @@ DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
-DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct
+DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
```
```diff
@@ -532,6 +532,8 @@ def run_bench_serving(
     request_rate,
     other_server_args,
     dataset_name="random",
+    dataset_path="",
+    tokenizer=None,
     random_input_len=4096,
     random_output_len=2048,
     disable_stream=False,
```
```diff
@@ -553,9 +555,9 @@
         host=None,
         port=None,
         dataset_name=dataset_name,
-        dataset_path=
+        dataset_path=dataset_path,
         model=None,
-        tokenizer=
+        tokenizer=tokenizer,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
         random_input_len=random_input_len,
```
```diff
@@ -657,16 +659,16 @@ STDERR_FILENAME = "stderr.txt"
 STDOUT_FILENAME = "stdout.txt"
 
 
-def read_output(output_lines):
+def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
     """Print the output in real time with another thread."""
-    while not os.path.exists(
+    while not os.path.exists(filename):
         time.sleep(1)
 
     pt = 0
     while pt >= 0:
-        if pt > 0 and not os.path.exists(
+        if pt > 0 and not os.path.exists(filename):
             break
-        lines = open(
+        lines = open(filename).readlines()
         for line in lines[pt:]:
             print(line, end="", flush=True)
             output_lines.append(line)
```
```diff
@@ -747,6 +749,33 @@ def run_and_check_memory_leak(
         assert has_abort
 
 
+def run_command_and_capture_output(command, env: Optional[dict] = None):
+    stdout = open(STDOUT_FILENAME, "w")
+    stderr = open(STDERR_FILENAME, "w")
+    process = subprocess.Popen(
+        command, stdout=stdout, stderr=stderr, env=env, text=True
+    )
+
+    # Launch a thread to stream the output
+    output_lines = []
+    t = threading.Thread(target=read_output, args=(output_lines, STDOUT_FILENAME))
+    t.start()
+
+    # Join the process
+    process.wait()
+
+    stdout.close()
+    stderr.close()
+    if os.path.exists(STDOUT_FILENAME):
+        os.remove(STDOUT_FILENAME)
+    if os.path.exists(STDERR_FILENAME):
+        os.remove(STDERR_FILENAME)
+    kill_process_tree(process.pid)
+    t.join()
+
+    return output_lines
+
+
 def run_mmlu_test(
     disable_radix_cache=False,
     enable_mixed_chunk=False,
```
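`run_command_and_capture_output` launches a command, streams its stdout via `read_output` on a background thread, cleans up the temporary files and the process tree, and returns the captured lines. A usage sketch with an illustrative command:

```python
# Usage sketch; the command is illustrative (in the test suite this helper is used
# to capture the output of long-running sglang server launches).
from sglang.test.test_utils import run_command_and_capture_output

command = ["python", "-m", "sglang.launch_server", "--model-path", "dummy-model", "--port", "30000"]
output_lines = run_command_and_capture_output(command)
print(f"captured {len(output_lines)} lines")
```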
sglang/version.py
CHANGED
```diff
@@ -1 +1 @@
-__version__ = "0.4.1.post4"
+__version__ = "0.4.1.post5"
```