sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +1 -1
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +4 -2
  11. sglang/srt/layers/linear.py +159 -55
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -3
  15. sglang/srt/layers/parameter.py +431 -0
  16. sglang/srt/layers/quantization/__init__.py +3 -2
  17. sglang/srt/layers/quantization/fp8.py +1 -1
  18. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  20. sglang/srt/managers/cache_controller.py +307 -0
  21. sglang/srt/managers/data_parallel_controller.py +2 -0
  22. sglang/srt/managers/schedule_batch.py +7 -1
  23. sglang/srt/managers/scheduler.py +10 -6
  24. sglang/srt/managers/session_controller.py +1 -1
  25. sglang/srt/managers/tokenizer_manager.py +6 -2
  26. sglang/srt/mem_cache/memory_pool.py +206 -1
  27. sglang/srt/metrics/collector.py +22 -30
  28. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  29. sglang/srt/model_executor/forward_batch_info.py +20 -15
  30. sglang/srt/model_executor/model_runner.py +10 -4
  31. sglang/srt/models/chatglm.py +1 -1
  32. sglang/srt/models/dbrx.py +1 -1
  33. sglang/srt/models/grok.py +25 -16
  34. sglang/srt/models/llama.py +9 -2
  35. sglang/srt/sampling/sampling_batch_info.py +1 -0
  36. sglang/srt/server.py +11 -8
  37. sglang/srt/server_args.py +12 -1
  38. sglang/srt/speculative/eagle_utils.py +93 -85
  39. sglang/srt/speculative/eagle_worker.py +47 -33
  40. sglang/srt/utils.py +32 -5
  41. sglang/test/test_programs.py +23 -1
  42. sglang/test/test_utils.py +36 -7
  43. sglang/version.py +1 -1
  44. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +6 -7
  45. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +48 -43
  46. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  47. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  48. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py CHANGED
@@ -9,13 +9,12 @@ import triton.language as tl
  from sglang.srt.layers.attention.flashinfer_backend import (
  create_flashinfer_kv_indices_triton,
  )
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
  from sglang.srt.speculative.spec_info import SpecInfo

  if TYPE_CHECKING:
- from python.sglang.srt.layers.sampler import SampleOutput
- from python.sglang.srt.managers.schedule_batch import ScheduleBatch
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
  from sglang.srt.server_args import ServerArgs


@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(


  class EAGLEDraftInput(SpecInfo):
- hidden_states: torch.Tensor = None
- verified_id: torch.Tensor = None
- positions: torch.Tensor = None
- accept_length: torch.Tensor = None
- has_finished: bool = False
- unfinished_index: List[int] = None
-
- def init(self, server_args: ServerArgs):
+ def __init__(self):
  self.prev_mode = ForwardMode.DECODE
  self.sample_output = None
- self.topk: int = server_args.speculative_eagle_topk
- self.num_verify_token: int = server_args.speculative_num_draft_tokens
- self.spec_steps = server_args.speculative_num_steps

  self.scores: torch.Tensor = None
  self.score_list: List[torch.Tensor] = []
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
  self.parents_list: List[torch.Tensor] = []
  self.cache_list: List[torch.Tenor] = []
  self.iter = 0
- self.root_token: int = None

- assert self.topk <= 10, "topk should <= 10"
+ self.hidden_states: torch.Tensor = None
+ self.verified_id: torch.Tensor = None
+ self.positions: torch.Tensor = None
+ self.accept_length: torch.Tensor = None
+ self.has_finished: bool = False
+ self.unfinished_index: List[int] = None
+
+ def load_server_args(self, server_args: ServerArgs):
+ self.topk: int = server_args.speculative_eagle_topk
+ self.num_verify_token: int = server_args.speculative_num_draft_tokens
+ self.spec_steps = server_args.speculative_num_steps

- def prepare_for_extend(self, batch: ForwardBatch):
+ def prepare_for_extend(self, batch: ScheduleBatch):
  req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
  out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
  batch.out_cache_loc = out_cache_loc
@@ -226,81 +224,73 @@ class EAGLEDraftInput(SpecInfo):

  pt += req.extend_input_len

- seq_lens = [0] + batch.extend_lens
- input_ids = batch.input_ids.tolist()
- verified_id = batch.spec_info.verified_id.tolist()
- model_input_ids = []
- for i in range(len(seq_lens) - 1):
- model_input_ids.extend(
- input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
- )
- batch.input_ids = torch.tensor(
- model_input_ids, dtype=torch.int32, device="cuda"
- )
-
- def capture_for_decode(
- self,
- sample_output: SampleOutput,
- hidden_states: torch.Tensor,
- prev_mode: ForwardMode,
- ):
- self.sample_output = sample_output
- self.prev_mode = prev_mode
- self.hidden_states = hidden_states
+ # TODO: support batching inputs
+ assert len(batch.extend_lens) == 1
+ batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))

  def prepare_for_decode(self, batch: ScheduleBatch):
- prob = self.sample_output # b * (1/topk), vocab
+ prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
  top = torch.topk(prob, self.topk, dim=-1)
- topk_index, topk_p = top.indices, top.values # b * (1/topk), topk
- if self.prev_mode == ForwardMode.DECODE:
+ topk_index, topk_p = (
+ top.indices,
+ top.values,
+ ) # shape: (b * top_k, top_k) or (b, top_k)
+
+ if self.prev_mode.is_decode():
  scores = torch.mul(
  self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
- ) # (b, topk) mul (b * topk ,topk) -> b, topk, topk
+ ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
  topk_cs = torch.topk(
  scores.flatten(start_dim=1), self.topk, dim=-1
  ) # (b, topk)
  topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
- self.scores = topk_cs_p

- selected_input_index = topk_cs_index.flatten() // self.topk # b* topk
+ selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
+ 0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
+ ).repeat_interleave(self.topk)

  batch.spec_info.hidden_states = batch.spec_info.hidden_states[
  selected_input_index, :
  ]
+
  topk_index = topk_index.reshape(-1, self.topk**2)
  batch.input_ids = torch.gather(
  topk_index, index=topk_cs_index, dim=1
  ).flatten()
- batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
- self.score_list.append(scores) # b, topk, topk
- self.token_list.append(topk_index) # b, topk*topk
+ batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+
+ self.scores = topk_cs_p
+ self.score_list.append(scores) # (b, topk, topk)
+ self.token_list.append(topk_index) # (b, topk * topk)
  self.origin_score_list.append(topk_p.reshape(topk_index.shape))
  self.parents_list.append(
  topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
- ) # b, topk
-
- elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
- self.scores = topk_p # b, top_k
- self.score_list.append(topk_p.unsqueeze(1))
- self.token_list.append(topk_index)
- self.origin_score_list.append(topk_p)
+ ) # shape: (b, topk)
+ else:
+ # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
  batch.spec_info.hidden_states = (
- batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
+ batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
  )
+
  batch.input_ids = topk_index.flatten()
  batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
+
+ self.scores = topk_p # shape: (b, topk)
+ self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk)
+ self.token_list.append(topk_index) # shape: (b, topk)
+ self.origin_score_list.append(topk_p)
  self.parents_list.append(
  torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
  .unsqueeze(0)
  .repeat(self.scores.shape[0], 1)
- ) # b, topk+1
+ ) # shape: (b, topk + 1)
  self.cache_list.append(batch.out_cache_loc)
  self.positions = (
  batch.seq_lens[:, None]
  + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
  ).flatten()

- bs = batch.seq_lens.numel()
+ bs = len(batch.seq_lens)
  assign_req_to_token_pool[(bs,)](
  batch.req_pool_indices,
  batch.req_to_token_pool.req_to_token,
@@ -347,6 +337,7 @@ class EAGLEDraftInput(SpecInfo):
  triton.next_power_of_2(self.spec_steps + 1),
  )

+ batch.seq_lens_sum = sum(batch.seq_lens)
  batch.input_ids = self.verified_id
  self.verified_id = new_verified_id

@@ -419,11 +410,6 @@ class EAGLEDraftInput(SpecInfo):
  )
  return bs, kv_indices, cum_kv_seq_len

- def clear(self):
- self.iter = 0
- self.score_list.clear()
- self.positions = None
-
  def clear_draft_cache(self, batch):
  draft_cache = torch.cat(self.cache_list, dim=0)
  batch.token_to_kv_pool.free(draft_cache)
@@ -455,12 +441,18 @@ class EAGLEDraftInput(SpecInfo):
  return kv_indices, cum_kv_seq_len, qo_indptr, None

  def merge_batch(self, spec_info: EAGLEDraftInput):
-
+ if self.hidden_states is None:
+ self.hidden_states = spec_info.hidden_states
+ self.verified_id = spec_info.verified_id
+ self.sample_output = spec_info.sample_output
+ self.prev_mode = spec_info.prev_mode
+ return
+ if spec_info.hidden_states is None:
+ return
  self.hidden_states = torch.cat(
  [self.hidden_states, spec_info.hidden_states], axis=0
  )
  self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
- # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
  self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])


@@ -567,11 +559,37 @@ class EagleVerifyInput(SpecInfo):
  triton.next_power_of_2(max_draft_len),
  )

- accept_index = accept_index[accept_index != -1]
- # extract_index = extract_index[extract_index != 0]
-
  draft_input = EAGLEDraftInput()
+ new_accept_index = []
+ unfinished_index = []
+ finished_extend_len = {} # {rid:accept_length + 1}
+ accept_index_cpu = accept_index.tolist()
+ predict_cpu = predict.tolist()
+ # iterate every accepted token and check if req has finished after append the token
+ # should be checked BEFORE free kv cache slots
+ for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+ new_accept_index_ = []
+ for j, idx in enumerate(accept_index_row):
+ if idx == -1:
+ break
+ id = predict_cpu[idx]
+ # if not found_finished:
+ req.output_ids.append(id)
+ finished_extend_len[req.rid] = j + 1
+ req.check_finished()
+ if req.finished():
+ draft_input.has_finished = True
+ # set all tokens after finished token to -1 and break
+ accept_index[i, j + 1 :] = -1
+ break
+ else:
+ new_accept_index_.append(idx)
+ if not req.finished():
+ new_accept_index.extend(new_accept_index_)
+ unfinished_index.append(i)
+ accept_length = (accept_index != -1).sum(dim=1) - 1

+ accept_index = accept_index[accept_index != -1]
  accept_length_cpu = accept_length.tolist()
  verified_id = predict[accept_index]
  verified_id_cpu = verified_id.tolist()
@@ -590,29 +608,19 @@ class EagleVerifyInput(SpecInfo):
  triton.next_power_of_2(bs),
  )
  batch.seq_lens.add_(accept_length + 1)
- new_accept_index = []
- unfinished_index = []
- finished_extend_len = {} # {rid:accept_length + 1}
- # retracted_reqs, new_token_ratio = batch.retract_decode()
-
- low = 0
- for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
- req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
- req.check_finished()
- if req.finished():
- draft_input.has_finished = True
- else:
- new_accept_index.append(accept_index[low : low + verified_len + 1])
- unfinished_index.append(i)
- low += verified_len + 1
- finished_extend_len[req.rid] = verified_len + 1

  if len(new_accept_index) > 0:
- new_accept_index = torch.cat(new_accept_index, dim=0)
+ new_accept_index = torch.tensor(new_accept_index, device="cuda")
  draft_input.verified_id = predict[new_accept_index]
  draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
  draft_input.accept_length = accept_length[unfinished_index]
  draft_input.unfinished_index = unfinished_index

  logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
- return draft_input, logits_output, verified_id, finished_extend_len
+ return (
+ draft_input,
+ logits_output,
+ verified_id,
+ finished_extend_len,
+ accept_length_cpu,
+ )
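
Note on the reworked EAGLEDraftInput.prepare_for_decode above: at each draft step the cumulative path scores are multiplied into the new per-path top-k probabilities, a second top-k is taken over the flattened (topk x topk) candidates, and the integer arithmetic on topk_cs_index maps every surviving candidate back to the row of the expanded hidden-state tensor it was grown from. A minimal standalone sketch of that selection step (plain PyTorch on CPU; the sizes and tensors are illustrative, not the sglang code):

import torch

topk, bs = 4, 2
scores = torch.rand(bs, topk)              # cumulative path scores from the previous step
topk_p = torch.rand(bs * topk, topk)       # probabilities of the new top-k tokens per path

cand = scores.unsqueeze(2) * topk_p.reshape(-1, topk, topk)    # (bs, topk, topk)
topk_cs = torch.topk(cand.flatten(start_dim=1), topk, dim=-1)  # keep the best topk paths per request

# Map each kept candidate back to its parent row in the (bs * topk)-row hidden_states tensor:
# within-request parent index plus a per-request row offset.
parent_rows = topk_cs.indices.flatten() // topk + torch.arange(
    0, bs * topk, step=topk
).repeat_interleave(topk)
assert parent_rows.shape == (bs * topk,)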
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
  batch.spec_info.prepare_for_decode(batch)
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+ forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
  logits_output = self.model_runner.forward(forward_batch)
  self.capture_for_decode(logits_output, forward_batch)

  def forward_draft_extend(self, batch: ScheduleBatch):
- self._swap_mem_pool(batch, self.model_runner)
+ self._set_mem_pool(batch, self.model_runner)
  batch.spec_info.prepare_for_extend(batch)
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+ forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
  logits_output = self.model_runner.forward(forward_batch)
  self.capture_for_decode(logits_output, forward_batch)
- self._swap_mem_pool(batch, self.target_worker.model_runner)
+ self._set_mem_pool(batch, self.target_worker.model_runner)

  def forward_batch_speculative_generation(self, batch: ScheduleBatch):
  if batch.forward_mode.is_decode():
- prev_spec_info = batch.spec_info
- self._swap_mem_pool(batch, self.model_runner)
+ # Draft
+ self._set_mem_pool(batch, self.model_runner)
  for i in range(self.server_args.speculative_num_steps):
  self.forward_draft_decode(batch)
  batch.spec_info.clear_draft_cache(batch)
- self._swap_mem_pool(batch, self.target_worker.model_runner)
+ self._set_mem_pool(batch, self.target_worker.model_runner)
+
+ # Verify
  (
  next_draft_input,
  logits_output,
  verified_id,
  self.finish_extend_len,
+ accept_length_cpu,
  model_worker_batch,
  ) = self.verify(batch)
- next_draft_input.init(self.server_args)
+ next_draft_input.load_server_args(self.server_args)
  batch.spec_info = next_draft_input
  # if it is None, means all requsets are finished
  if batch.spec_info.verified_id is not None:
- self.forward_extend_after_decode(batch)
- batch.spec_info = prev_spec_info
- return logits_output, verified_id, model_worker_batch, next_draft_input
+ self.forward_draft_extend_after_decode(batch)
+ return (
+ logits_output,
+ verified_id,
+ model_worker_batch,
+ sum(accept_length_cpu),
+ )

  else:
- spec_info = EAGLEDraftInput()
- spec_info.init(self.server_args)
+ # Forward with the target model and get hidden states.
+ # We need the full hidden states to prefill the KV cache of the draft model.
  model_worker_batch = batch.get_model_worker_batch()
- model_worker_batch.spec_info = spec_info
- spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+ model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
  logits_output, next_token_ids = self.target_worker.forward_batch_generation(
  model_worker_batch
  )
- model_worker_batch.spec_info.verified_id = next_token_ids
- model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
+
+ # Forward with the draft model.
+ spec_info = EAGLEDraftInput()
+ spec_info.load_server_args(self.server_args)
+ spec_info.hidden_states = logits_output.hidden_states
+ spec_info.verified_id = next_token_ids
  batch.spec_info = spec_info
  self.forward_draft_extend(batch)
- batch.spec_info = None
- return logits_output, next_token_ids, model_worker_batch, spec_info
+ return logits_output, next_token_ids, model_worker_batch, 0

  def verify(self, batch: ScheduleBatch):
  verify_input = batch.spec_info.prepare_for_verify(batch)
- batch.forward_mode = ForwardMode.TARGET_VERIFY
  verify_input.prepare_for_verify(batch)
+ batch.forward_mode = ForwardMode.TARGET_VERIFY
  batch.spec_info = verify_input
  batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
  model_worker_batch = batch.get_model_worker_batch()
@@ -119,44 +128,49 @@ class EAGLEWorker(TpModelWorker):
  batch.forward_mode = ForwardMode.DECODE
  return res + (model_worker_batch,)

- def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
+ def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
  batch.token_to_kv_pool = runner.token_to_kv_pool
  batch.req_to_token_pool = runner.req_to_token_pool

- def forward_extend_after_decode(self, batch: ScheduleBatch):
- self._swap_mem_pool(batch, self.model_runner)
+ def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+ self._set_mem_pool(batch, self.model_runner)
  batch.forward_mode = ForwardMode.DRAFT_EXTEND
  if batch.spec_info.has_finished:
  index = batch.spec_info.unfinished_index
  seq_lens = batch.seq_lens
  batch.seq_lens = batch.seq_lens[index]
+
  batch.spec_info.prepare_extend_after_decode(batch)
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+ forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
  logits_output = self.model_runner.forward(forward_batch)
+
  batch.spec_info.hidden_states = logits_output.hidden_states
  self.capture_for_decode(logits_output, forward_batch)
  batch.forward_mode = ForwardMode.DECODE
  if batch.spec_info.has_finished:
  batch.seq_lens = seq_lens
- self._swap_mem_pool(batch, self.target_worker.model_runner)
+ self._set_mem_pool(batch, self.target_worker.model_runner)

- def capture_for_decode(self, logits_output, forward_batch):
- if isinstance(logits_output, LogitsProcessorOutput):
- logits = logits_output.next_token_logits
+ def capture_for_decode(
+ self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+ ):
  sample_output = torch.softmax(
- logits, dim=-1
- ) # TODO: Support more sampling method @kavioyu
- forward_batch.spec_info.capture_for_decode(
- sample_output, logits_output.hidden_states, forward_batch.forward_mode
- )
+ logits_output.next_token_logits, dim=-1
+ ) # TODO(kavioyu): Support more sampling methods
+ spec_info = forward_batch.spec_info
+ spec_info.sample_output = sample_output
+ spec_info.hidden_states = logits_output.hidden_states
+ spec_info.prev_mode = forward_batch.forward_mode

  # Don't support prefix share now.
  def finish_request(self, reqs: Union[Req, List[Req]]):
  if not isinstance(reqs, List):
  reqs = [reqs]
  for req in reqs:
+ if req.rid not in self.finish_extend_len:
+ continue
  req_len = (
  len(req.origin_input_ids)
  + len(req.output_ids)
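
For orientation, the decode branch of forward_batch_speculative_generation now follows a draft, verify, draft-extend sequence on two separate KV pools and reports the number of accepted tokens rather than the spec_info object. A rough outline of the control flow as it reads in the diff above (comments only, not runnable code):

# 1. _set_mem_pool(batch, draft model_runner); run forward_draft_decode() speculative_num_steps times
# 2. clear the draft KV cache; _set_mem_pool(batch, target model_runner)
# 3. verify(batch): the target model scores the draft tree and accepted tokens are kept
# 4. if any request is unfinished, forward_draft_extend_after_decode(batch)
# 5. return (logits_output, verified_id, model_worker_batch, sum(accept_length_cpu))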
sglang/srt/utils.py CHANGED
@@ -335,6 +335,8 @@ def is_port_available(port):
  return True
  except socket.error:
  return False
+ except OverflowError:
+ return False


  def decode_video_base64(video_base64):
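
The added except clause covers out-of-range port numbers: binding a socket to a port outside 0-65535 raises OverflowError rather than socket.error in CPython, so is_port_available would previously propagate the exception instead of returning False. A quick standalone illustration (not the sglang helper itself):

import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    s.bind(("", 70000))  # port outside the valid 0-65535 range
except OverflowError:
    print("port rejected before any bind is attempted")
finally:
    s.close()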
@@ -709,13 +711,14 @@ def broadcast_pyobj(
  data: List[Any],
  rank: int,
  dist_group: Optional[torch.distributed.ProcessGroup] = None,
+ src: int = 0,
  ):
  """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""

  if rank == 0:
  if len(data) == 0:
  tensor_size = torch.tensor([0], dtype=torch.long)
- dist.broadcast(tensor_size, src=0, group=dist_group)
+ dist.broadcast(tensor_size, src=src, group=dist_group)
  else:
  serialized_data = pickle.dumps(data)
  size = len(serialized_data)
@@ -724,19 +727,19 @@
  )
  tensor_size = torch.tensor([size], dtype=torch.long)

- dist.broadcast(tensor_size, src=0, group=dist_group)
- dist.broadcast(tensor_data, src=0, group=dist_group)
+ dist.broadcast(tensor_size, src=src, group=dist_group)
+ dist.broadcast(tensor_data, src=src, group=dist_group)
  return data
  else:
  tensor_size = torch.tensor([0], dtype=torch.long)
- dist.broadcast(tensor_size, src=0, group=dist_group)
+ dist.broadcast(tensor_size, src=src, group=dist_group)
  size = tensor_size.item()

  if size == 0:
  return []

  tensor_data = torch.empty(size, dtype=torch.uint8)
- dist.broadcast(tensor_data, src=0, group=dist_group)
+ dist.broadcast(tensor_data, src=src, group=dist_group)

  serialized_data = bytes(tensor_data.cpu().numpy())
  data = pickle.loads(serialized_data)
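
The new src parameter lets broadcast_pyobj name the source rank that dist.broadcast should read from inside dist_group, while existing call sites keep the old behavior through the src=0 default. A hedged usage sketch; the group and rank variables are placeholders, not actual sglang call sites:

from sglang.srt.utils import broadcast_pyobj

# The sending process passes rank=0 and the real objects; receivers pass an empty list.
recv_reqs = broadcast_pyobj(
    data=recv_reqs,
    rank=local_rank,        # 0 on the sender, nonzero on receivers
    dist_group=cpu_group,   # a CPU (gloo) process group, placeholder name
    src=src_global_rank,    # global rank of the sender within the group
)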
@@ -1348,3 +1351,27 @@ class MultiprocessingSerializer:
  @staticmethod
  def deserialize(data):
  return ForkingPickler.loads(data)
+
+
+ def debug_timing(func):
+ # todo: replace with a more organized instrumentation
+ def wrapper(*args, **kwargs):
+ if logger.isEnabledFor(logging.DEBUG):
+ tic = torch.cuda.Event(enable_timing=True)
+ toc = torch.cuda.Event(enable_timing=True)
+ tic.record()
+ result = func(*args, **kwargs)
+ toc.record()
+ torch.cuda.synchronize() # Ensure all CUDA operations are complete
+ elapsed = tic.elapsed_time(toc)
+ indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
+ num_tokens = len(indices) if indices is not None else 0
+ throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
+ logger.debug(
+ f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
+ )
+ return result
+ else:
+ return func(*args, **kwargs)
+
+ return wrapper
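
debug_timing is a no-op unless the logger is at DEBUG level; when enabled it brackets the wrapped call with CUDA events and derives a tokens/s figure from the indices argument (the second positional argument or the indices keyword). A hedged usage sketch; the class and method here are hypothetical, shaped only to match what the decorator expects:

from sglang.srt.utils import debug_timing

class HostKVBuffer:
    @debug_timing
    def transfer(self, indices, flat_data):
        # Copy flat_data rows into the host buffer at the given indices.
        # With DEBUG logging enabled, the decorator logs elapsed ms and tokens/s.
        ...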
sglang/test/test_programs.py CHANGED
@@ -509,13 +509,35 @@ def test_hellaswag_select():
  temperature=0,
  num_threads=64,
  progress_bar=True,
+ generator_style=False,
  )
- preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
+ preds = []
+ for i, ret in enumerate(rets):
+ preds.append(choices[i].index(ret["answer"]))
  latency = time.time() - tic

  # Compute accuracy
  accuracy = np.mean(np.array(preds) == np.array(labels))

+ # Test generator style of run_batch
+ tic = time.time()
+ rets = few_shot_hellaswag.run_batch(
+ arguments,
+ temperature=0,
+ num_threads=64,
+ progress_bar=True,
+ generator_style=True,
+ )
+ preds_gen = []
+ for i, ret in enumerate(rets):
+ preds_gen.append(choices[i].index(ret["answer"]))
+ latency_gen = time.time() - tic
+
+ # Compute accuracy
+ accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
+ assert np.abs(accuracy_gen - accuracy) < 0.01
+ assert np.abs(latency_gen - latency) < 1
+
  return accuracy, latency


sglang/test/test_utils.py CHANGED
@@ -36,7 +36,7 @@ DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
  DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
  DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
  DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
- DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
+ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
  DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
  DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
  DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
@@ -532,6 +532,8 @@ def run_bench_serving(
  request_rate,
  other_server_args,
  dataset_name="random",
+ dataset_path="",
+ tokenizer=None,
  random_input_len=4096,
  random_output_len=2048,
  disable_stream=False,
@@ -553,9 +555,9 @@
  host=None,
  port=None,
  dataset_name=dataset_name,
- dataset_path="",
+ dataset_path=dataset_path,
  model=None,
- tokenizer=None,
+ tokenizer=tokenizer,
  num_prompts=num_prompts,
  sharegpt_output_len=None,
  random_input_len=random_input_len,
@@ -657,16 +659,16 @@ STDERR_FILENAME = "stderr.txt"
  STDOUT_FILENAME = "stdout.txt"


- def read_output(output_lines):
+ def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
  """Print the output in real time with another thread."""
- while not os.path.exists(STDERR_FILENAME):
+ while not os.path.exists(filename):
  time.sleep(1)

  pt = 0
  while pt >= 0:
- if pt > 0 and not os.path.exists(STDERR_FILENAME):
+ if pt > 0 and not os.path.exists(filename):
  break
- lines = open(STDERR_FILENAME).readlines()
+ lines = open(filename).readlines()
  for line in lines[pt:]:
  print(line, end="", flush=True)
  output_lines.append(line)
@@ -747,6 +749,33 @@ def run_and_check_memory_leak(
  assert has_abort


+ def run_command_and_capture_output(command, env: Optional[dict] = None):
+ stdout = open(STDOUT_FILENAME, "w")
+ stderr = open(STDERR_FILENAME, "w")
+ process = subprocess.Popen(
+ command, stdout=stdout, stderr=stderr, env=env, text=True
+ )
+
+ # Launch a thread to stream the output
+ output_lines = []
+ t = threading.Thread(target=read_output, args=(output_lines, STDOUT_FILENAME))
+ t.start()
+
+ # Join the process
+ process.wait()
+
+ stdout.close()
+ stderr.close()
+ if os.path.exists(STDOUT_FILENAME):
+ os.remove(STDOUT_FILENAME)
+ if os.path.exists(STDERR_FILENAME):
+ os.remove(STDERR_FILENAME)
+ kill_process_tree(process.pid)
+ t.join()
+
+ return output_lines
+
+
  def run_mmlu_test(
  disable_radix_cache=False,
  enable_mixed_chunk=False,
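
run_command_and_capture_output launches the command with stdout/stderr redirected to the stdout.txt/stderr.txt files, streams stdout through read_output on a helper thread, and returns the collected lines after cleaning up the files and the process tree. A hedged usage sketch; the command shown is an arbitrary example, not a test from the suite:

import os
from sglang.test.test_utils import run_command_and_capture_output

output_lines = run_command_and_capture_output(
    ["python3", "-c", "print('hello from a subprocess')"],
    env=os.environ.copy(),
)
print("".join(output_lines))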
sglang/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.1.post4"
+ __version__ = "0.4.1.post5"