sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +16 -7
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +21 -5
- sglang/srt/layers/linear.py +89 -47
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +439 -0
- sglang/srt/layers/quantization/__init__.py +5 -2
- sglang/srt/layers/quantization/fp8.py +107 -53
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +16 -3
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +58 -15
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +109 -45
- sglang/srt/mem_cache/memory_pool.py +313 -53
- sglang/srt/metrics/collector.py +32 -35
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +53 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +46 -4
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +15 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +125 -69
- sglang/srt/server_args.py +39 -19
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +48 -33
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +61 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -9,13 +9,12 @@ import triton.language as tl
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
 from sglang.srt.speculative.spec_info import SpecInfo
 
 if TYPE_CHECKING:
-    from
-    from python.sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.server_args import ServerArgs
 
 
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
 
 
 class EAGLEDraftInput(SpecInfo):
-
-    verified_id: torch.Tensor = None
-    positions: torch.Tensor = None
-    accept_length: torch.Tensor = None
-    has_finished: bool = False
-    unfinished_index: List[int] = None
-
-    def init(self, server_args: ServerArgs):
+    def __init__(self):
         self.prev_mode = ForwardMode.DECODE
         self.sample_output = None
-        self.topk: int = server_args.speculative_eagle_topk
-        self.num_verify_token: int = server_args.speculative_num_draft_tokens
-        self.spec_steps = server_args.speculative_num_steps
 
         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
         self.parents_list: List[torch.Tensor] = []
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0
-        self.root_token: int = None
 
-
+        self.hidden_states: torch.Tensor = None
+        self.verified_id: torch.Tensor = None
+        self.positions: torch.Tensor = None
+        self.accept_length: torch.Tensor = None
+        self.has_finished: bool = False
+        self.unfinished_index: List[int] = None
+
+    def load_server_args(self, server_args: ServerArgs):
+        self.topk: int = server_args.speculative_eagle_topk
+        self.num_verify_token: int = server_args.speculative_num_draft_tokens
+        self.spec_steps = server_args.speculative_num_steps
 
-    def prepare_for_extend(self, batch:
+    def prepare_for_extend(self, batch: ScheduleBatch):
         req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
         out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
         batch.out_cache_loc = out_cache_loc
@@ -226,81 +224,73 @@ class EAGLEDraftInput(SpecInfo):
 
             pt += req.extend_input_len
 
-
-
-
-        model_input_ids = []
-        for i in range(len(seq_lens) - 1):
-            model_input_ids.extend(
-                input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
-            )
-        batch.input_ids = torch.tensor(
-            model_input_ids, dtype=torch.int32, device="cuda"
-        )
-
-    def capture_for_decode(
-        self,
-        sample_output: SampleOutput,
-        hidden_states: torch.Tensor,
-        prev_mode: ForwardMode,
-    ):
-        self.sample_output = sample_output
-        self.prev_mode = prev_mode
-        self.hidden_states = hidden_states
+        # TODO: support batching inputs
+        assert len(batch.extend_lens) == 1
+        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
 
     def prepare_for_decode(self, batch: ScheduleBatch):
-        prob = self.sample_output  # b * (
+        prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
         top = torch.topk(prob, self.topk, dim=-1)
-        topk_index, topk_p =
-
+        topk_index, topk_p = (
+            top.indices,
+            top.values,
+        )  # shape: (b * top_k, top_k) or (b, top_k)
+
+        if self.prev_mode.is_decode():
             scores = torch.mul(
                 self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
-            )  # (b, topk)
+            )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
             topk_cs = torch.topk(
                 scores.flatten(start_dim=1), self.topk, dim=-1
             )  # (b, topk)
             topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
-            self.scores = topk_cs_p
 
-            selected_input_index = topk_cs_index.flatten() // self.topk
+            selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
+                0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
+            ).repeat_interleave(self.topk)
 
             batch.spec_info.hidden_states = batch.spec_info.hidden_states[
                 selected_input_index, :
             ]
+
             topk_index = topk_index.reshape(-1, self.topk**2)
             batch.input_ids = torch.gather(
                 topk_index, index=topk_cs_index, dim=1
             ).flatten()
-            batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids
-
-            self.
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+
+            self.scores = topk_cs_p
+            self.score_list.append(scores)  # (b, topk, topk)
+            self.token_list.append(topk_index)  # (b, topk * topk)
             self.origin_score_list.append(topk_p.reshape(topk_index.shape))
             self.parents_list.append(
                 topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
-            )  # b, topk
-
-
-            self.scores = topk_p  # b, top_k
-            self.score_list.append(topk_p.unsqueeze(1))
-            self.token_list.append(topk_index)
-            self.origin_score_list.append(topk_p)
+            )  # shape: (b, topk)
+        else:
+            # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
             batch.spec_info.hidden_states = (
-                batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
+                batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
             )
+
             batch.input_ids = topk_index.flatten()
             batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
+
+            self.scores = topk_p  # shape: (b, topk)
+            self.score_list.append(topk_p.unsqueeze(1))  # shape: (b, 1, topk)
+            self.token_list.append(topk_index)  # shape: (b, topk)
+            self.origin_score_list.append(topk_p)
             self.parents_list.append(
                 torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
                 .unsqueeze(0)
                 .repeat(self.scores.shape[0], 1)
-            )  # b, topk+1
+            )  # shape: (b, topk + 1)
         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
             + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
         ).flatten()
 
-        bs = batch.seq_lens
+        bs = len(batch.seq_lens)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
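The new selected_input_index arithmetic maps each cumulative top-k choice back to the row of the flattened (batch * topk) hidden-state tensor it came from. A minimal worked example (CPU tensors, illustrative values, not code from this diff):

import torch

topk = 2
batch_size = 2
# topk_cs_index indexes into a (batch, topk * topk) flattened score matrix
topk_cs_index = torch.tensor([[1, 3], [0, 2]])

# per-batch row offsets into the flattened (batch * topk) hidden states
offsets = torch.arange(0, batch_size * topk, step=topk).repeat_interleave(topk)
selected_input_index = topk_cs_index.flatten() // topk + offsets
print(selected_input_index)  # tensor([0, 1, 2, 3])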
@@ -347,6 +337,7 @@ class EAGLEDraftInput(SpecInfo):
             triton.next_power_of_2(self.spec_steps + 1),
         )
 
+        batch.seq_lens_sum = sum(batch.seq_lens)
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id
 
@@ -419,11 +410,6 @@ class EAGLEDraftInput(SpecInfo):
         )
         return bs, kv_indices, cum_kv_seq_len
 
-    def clear(self):
-        self.iter = 0
-        self.score_list.clear()
-        self.positions = None
-
     def clear_draft_cache(self, batch):
         draft_cache = torch.cat(self.cache_list, dim=0)
         batch.token_to_kv_pool.free(draft_cache)
@@ -455,12 +441,18 @@ class EAGLEDraftInput(SpecInfo):
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
     def merge_batch(self, spec_info: EAGLEDraftInput):
-
+        if self.hidden_states is None:
+            self.hidden_states = spec_info.hidden_states
+            self.verified_id = spec_info.verified_id
+            self.sample_output = spec_info.sample_output
+            self.prev_mode = spec_info.prev_mode
+            return
+        if spec_info.hidden_states is None:
+            return
         self.hidden_states = torch.cat(
             [self.hidden_states, spec_info.hidden_states], axis=0
         )
         self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
-        # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
         self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
 
 
@@ -567,11 +559,37 @@ class EagleVerifyInput(SpecInfo):
             triton.next_power_of_2(max_draft_len),
         )
 
-        accept_index = accept_index[accept_index != -1]
-        # extract_index = extract_index[extract_index != 0]
-
         draft_input = EAGLEDraftInput()
+        new_accept_index = []
+        unfinished_index = []
+        finished_extend_len = {}  # {rid:accept_length + 1}
+        accept_index_cpu = accept_index.tolist()
+        predict_cpu = predict.tolist()
+        # iterate every accepted token and check if req has finished after append the token
+        # should be checked BEFORE free kv cache slots
+        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+            new_accept_index_ = []
+            for j, idx in enumerate(accept_index_row):
+                if idx == -1:
+                    break
+                id = predict_cpu[idx]
+                # if not found_finished:
+                req.output_ids.append(id)
+                finished_extend_len[req.rid] = j + 1
+                req.check_finished()
+                if req.finished():
+                    draft_input.has_finished = True
+                    # set all tokens after finished token to -1 and break
+                    accept_index[i, j + 1 :] = -1
+                    break
+                else:
+                    new_accept_index_.append(idx)
+            if not req.finished():
+                new_accept_index.extend(new_accept_index_)
+                unfinished_index.append(i)
+        accept_length = (accept_index != -1).sum(dim=1) - 1
 
+        accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
         verified_id_cpu = verified_id.tolist()
@@ -590,29 +608,19 @@ class EagleVerifyInput(SpecInfo):
             triton.next_power_of_2(bs),
         )
         batch.seq_lens.add_(accept_length + 1)
-        new_accept_index = []
-        unfinished_index = []
-        finished_extend_len = {}  # {rid:accept_length + 1}
-        # retracted_reqs, new_token_ratio = batch.retract_decode()
-
-        low = 0
-        for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
-            req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
-            req.check_finished()
-            if req.finished():
-                draft_input.has_finished = True
-            else:
-                new_accept_index.append(accept_index[low : low + verified_len + 1])
-                unfinished_index.append(i)
-            low += verified_len + 1
-            finished_extend_len[req.rid] = verified_len + 1
 
         if len(new_accept_index) > 0:
-            new_accept_index = torch.
+            new_accept_index = torch.tensor(new_accept_index, device="cuda")
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
             draft_input.unfinished_index = unfinished_index
 
         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return
+        return (
+            draft_input,
+            logits_output,
+            verified_id,
+            finished_extend_len,
+            accept_length_cpu,
+        )
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -40,6 +40,7 @@ class EAGLEWorker(TpModelWorker):
         )
         self.target_worker = target_worker
         self.server_args = server_args
+        self.finish_extend_len = []
 
         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -51,63 +52,72 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_for_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
 
     def forward_draft_extend(self, batch: ScheduleBatch):
-        self.
+        self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
-        self.
+        self._set_mem_pool(batch, self.target_worker.model_runner)
 
     def forward_batch_speculative_generation(self, batch: ScheduleBatch):
         if batch.forward_mode.is_decode():
-
-            self.
+            # Draft
+            self._set_mem_pool(batch, self.model_runner)
             for i in range(self.server_args.speculative_num_steps):
                 self.forward_draft_decode(batch)
             batch.spec_info.clear_draft_cache(batch)
-            self.
+            self._set_mem_pool(batch, self.target_worker.model_runner)
+
+            # Verify
             (
                 next_draft_input,
                 logits_output,
                 verified_id,
                 self.finish_extend_len,
+                accept_length_cpu,
                 model_worker_batch,
             ) = self.verify(batch)
-            next_draft_input.
+            next_draft_input.load_server_args(self.server_args)
             batch.spec_info = next_draft_input
             # if it is None, means all requsets are finished
             if batch.spec_info.verified_id is not None:
-                self.
-
-
+                self.forward_draft_extend_after_decode(batch)
+            return (
+                logits_output,
+                verified_id,
+                model_worker_batch,
+                sum(accept_length_cpu),
+            )
 
         else:
-
-
+            # Forward with the target model and get hidden states.
+            # We need the full hidden states to prefill the KV cache of the draft model.
             model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.
-            spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
             logits_output, next_token_ids = self.target_worker.forward_batch_generation(
                 model_worker_batch
             )
-
-
+
+            # Forward with the draft model.
+            spec_info = EAGLEDraftInput()
+            spec_info.load_server_args(self.server_args)
+            spec_info.hidden_states = logits_output.hidden_states
+            spec_info.verified_id = next_token_ids
             batch.spec_info = spec_info
             self.forward_draft_extend(batch)
-
-            return logits_output, next_token_ids, model_worker_batch, spec_info
+            return logits_output, next_token_ids, model_worker_batch, 0
 
     def verify(self, batch: ScheduleBatch):
         verify_input = batch.spec_info.prepare_for_verify(batch)
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
         verify_input.prepare_for_verify(batch)
+        batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = verify_input
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
         model_worker_batch = batch.get_model_worker_batch()
@@ -119,44 +129,49 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.DECODE
         return res + (model_worker_batch,)
 
-    def
+    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
         batch.token_to_kv_pool = runner.token_to_kv_pool
         batch.req_to_token_pool = runner.req_to_token_pool
 
-    def
-        self.
+    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         if batch.spec_info.has_finished:
             index = batch.spec_info.unfinished_index
             seq_lens = batch.seq_lens
             batch.seq_lens = batch.seq_lens[index]
+
         batch.spec_info.prepare_extend_after_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
+
         batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
         batch.forward_mode = ForwardMode.DECODE
         if batch.spec_info.has_finished:
             batch.seq_lens = seq_lens
-        self.
+        self._set_mem_pool(batch, self.target_worker.model_runner)
 
-    def capture_for_decode(
-
-
+    def capture_for_decode(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ):
         sample_output = torch.softmax(
-
-        )  # TODO: Support more sampling
-        forward_batch.spec_info
-
-
+            logits_output.next_token_logits, dim=-1
+        )  # TODO(kavioyu): Support more sampling methods
+        spec_info = forward_batch.spec_info
+        spec_info.sample_output = sample_output
+        spec_info.hidden_states = logits_output.hidden_states
+        spec_info.prev_mode = forward_batch.forward_mode
 
     # Don't support prefix share now.
     def finish_request(self, reqs: Union[Req, List[Req]]):
         if not isinstance(reqs, List):
             reqs = [reqs]
         for req in reqs:
+            if req.rid not in self.finish_extend_len:
+                continue
             req_len = (
                 len(req.origin_input_ids)
                 + len(req.output_ids)
sglang/srt/torch_memory_saver_adapter.py
ADDED
@@ -0,0 +1,59 @@
+from abc import ABC
+from contextlib import contextmanager
+
+try:
+    import torch_memory_saver
+
+    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+except ImportError:
+    pass
+
+
+class TorchMemorySaverAdapter(ABC):
+    @staticmethod
+    def create(enable: bool):
+        return (
+            _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
+        )
+
+    def configure_subprocess(self):
+        raise NotImplementedError
+
+    def region(self):
+        raise NotImplementedError
+
+    def pause(self):
+        raise NotImplementedError
+
+    def resume(self):
+        raise NotImplementedError
+
+
+class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    def configure_subprocess(self):
+        return torch_memory_saver.configure_subprocess()
+
+    def region(self):
+        return _primary_memory_saver.region()
+
+    def pause(self):
+        return _primary_memory_saver.pause()
+
+    def resume(self):
+        return _primary_memory_saver.resume()
+
+
+class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
+    @contextmanager
+    def configure_subprocess(self):
+        yield
+
+    @contextmanager
+    def region(self):
+        yield
+
+    def pause(self):
+        pass
+
+    def resume(self):
+        pass
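The new adapter gives callers one entry point whether or not torch_memory_saver is installed. A minimal usage sketch; the enable flag and the allocation helper below are hypothetical, not part of this diff:

from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

# Pick the real or no-op implementation once, based on a boolean flag
# (the flag name here is illustrative only).
memory_saver = TorchMemorySaverAdapter.create(enable=enable_memory_saver)

# Allocations made inside the region can later be paused and resumed.
with memory_saver.region():
    kv_buffer = allocate_kv_cache()  # hypothetical allocation helper

memory_saver.pause()   # release the GPU memory held by the region
memory_saver.resume()  # re-materialize it before the next forward pass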
sglang/srt/utils.py
CHANGED
@@ -97,6 +97,10 @@ def is_flashinfer_available():
     return torch.cuda.is_available() and torch.version.cuda
 
 
+def is_cuda_available():
+    return torch.cuda.is_available() and torch.version.cuda
+
+
 def is_ipv6(address):
     try:
         ipaddress.IPv6Address(address)
@@ -335,6 +339,8 @@ def is_port_available(port):
         return True
     except socket.error:
         return False
+    except OverflowError:
+        return False
 
 
 def decode_video_base64(video_base64):
@@ -709,13 +715,14 @@ def broadcast_pyobj(
     data: List[Any],
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
+    src: int = 0,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
     if rank == 0:
         if len(data) == 0:
             tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=
+            dist.broadcast(tensor_size, src=src, group=dist_group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
@@ -724,19 +731,19 @@ def broadcast_pyobj(
             )
             tensor_size = torch.tensor([size], dtype=torch.long)
 
-            dist.broadcast(tensor_size, src=
-            dist.broadcast(tensor_data, src=
+            dist.broadcast(tensor_size, src=src, group=dist_group)
+            dist.broadcast(tensor_data, src=src, group=dist_group)
         return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=
+        dist.broadcast(tensor_size, src=src, group=dist_group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
         tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=
+        dist.broadcast(tensor_data, src=src, group=dist_group)
 
         serialized_data = bytes(tensor_data.cpu().numpy())
         data = pickle.loads(serialized_data)
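A minimal sketch of calling the updated helper with its new src parameter; torch.distributed initialization is elided and the payload is illustrative:

import torch.distributed as dist
from sglang.srt.utils import broadcast_pyobj

# The broadcasting rank serializes the object list; other ranks pass an
# empty list and receive the deserialized copy.
objs = [{"req_id": 1, "prompt": "hello"}] if dist.get_rank() == 0 else []
objs = broadcast_pyobj(objs, rank=dist.get_rank(), dist_group=None, src=0)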
@@ -1337,6 +1344,25 @@ def parse_tool_response(text, tools, **kwargs):
     return text, call_info_list
 
 
+def permute_weight(x: torch.Tensor) -> torch.Tensor:
+    b_ = x.shape[0]
+    n_ = x.shape[1]
+    k_ = x.shape[2]
+
+    x_ = x
+    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
+    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
+    else:
+        return x_
+
+    x_ = x_.permute(0, 1, 3, 4, 2, 5)
+    x_ = x_.contiguous()
+    x_ = x_.view(*x.shape)
+    return x_
+
+
 class MultiprocessingSerializer:
     @staticmethod
     def serialize(obj):
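permute_weight reorders the elements of each weight tile without changing the tensor's overall shape or dtype. A quick sanity-check sketch, with shapes chosen to satisfy the divisibility the reshape assumes (n divisible by 16, k by 32 for 16-bit dtypes):

import torch
from sglang.srt.utils import permute_weight

w = torch.randn(2, 64, 128, dtype=torch.bfloat16)
w_perm = permute_weight(w)
assert w_perm.shape == w.shape   # same shape, different element order
assert w_perm.dtype == w.dtype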
@@ -1348,3 +1374,33 @@ class MultiprocessingSerializer:
     @staticmethod
     def deserialize(data):
         return ForkingPickler.loads(data)
+
+
+def debug_timing(func):
+    # todo: replace with a more organized instrumentation
+    def wrapper(*args, **kwargs):
+        if logger.isEnabledFor(logging.DEBUG):
+            tic = torch.cuda.Event(enable_timing=True)
+            toc = torch.cuda.Event(enable_timing=True)
+            tic.record()
+            result = func(*args, **kwargs)
+            toc.record()
+            torch.cuda.synchronize()  # Ensure all CUDA operations are complete
+            elapsed = tic.elapsed_time(toc)
+            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
+            num_tokens = len(indices) if indices is not None else 0
+            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
+            logger.debug(
+                f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
+            )
+            return result
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
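A usage sketch for the debug_timing decorator; the decorated class and method below are hypothetical, and timing only fires when the sglang logger is at DEBUG level:

import logging
from sglang.srt.utils import debug_timing

logging.getLogger("sglang").setLevel(logging.DEBUG)

class KVTransfer:
    @debug_timing
    def copy_tokens(self, indices, buffer):
        # the second positional argument (or an `indices` kwarg) is what the
        # decorator uses to estimate tokens-per-second throughput
        return buffer[indices]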
sglang/test/test_programs.py
CHANGED
@@ -509,13 +509,35 @@ def test_hellaswag_select():
         temperature=0,
         num_threads=64,
         progress_bar=True,
+        generator_style=False,
     )
-    preds = [
+    preds = []
+    for i, ret in enumerate(rets):
+        preds.append(choices[i].index(ret["answer"]))
     latency = time.time() - tic
 
     # Compute accuracy
     accuracy = np.mean(np.array(preds) == np.array(labels))
 
+    # Test generator style of run_batch
+    tic = time.time()
+    rets = few_shot_hellaswag.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=64,
+        progress_bar=True,
+        generator_style=True,
+    )
+    preds_gen = []
+    for i, ret in enumerate(rets):
+        preds_gen.append(choices[i].index(ret["answer"]))
+    latency_gen = time.time() - tic
+
+    # Compute accuracy
+    accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
+    assert np.abs(accuracy_gen - accuracy) < 0.01
+    assert np.abs(latency_gen - latency) < 1
+
     return accuracy, latency
 
 