sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +16 -7
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +21 -5
  11. sglang/srt/layers/linear.py +89 -47
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
  15. sglang/srt/layers/moe/topk.py +4 -2
  16. sglang/srt/layers/parameter.py +439 -0
  17. sglang/srt/layers/quantization/__init__.py +5 -2
  18. sglang/srt/layers/quantization/fp8.py +107 -53
  19. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  20. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  21. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  22. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  23. sglang/srt/layers/radix_attention.py +2 -0
  24. sglang/srt/layers/vocab_parallel_embedding.py +16 -3
  25. sglang/srt/managers/cache_controller.py +307 -0
  26. sglang/srt/managers/configure_logging.py +43 -0
  27. sglang/srt/managers/data_parallel_controller.py +2 -0
  28. sglang/srt/managers/detokenizer_manager.py +0 -2
  29. sglang/srt/managers/io_struct.py +29 -13
  30. sglang/srt/managers/schedule_batch.py +7 -1
  31. sglang/srt/managers/scheduler.py +58 -15
  32. sglang/srt/managers/session_controller.py +1 -1
  33. sglang/srt/managers/tokenizer_manager.py +109 -45
  34. sglang/srt/mem_cache/memory_pool.py +313 -53
  35. sglang/srt/metrics/collector.py +32 -35
  36. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  37. sglang/srt/model_executor/forward_batch_info.py +20 -15
  38. sglang/srt/model_executor/model_runner.py +53 -10
  39. sglang/srt/models/chatglm.py +1 -1
  40. sglang/srt/models/dbrx.py +1 -1
  41. sglang/srt/models/grok.py +25 -16
  42. sglang/srt/models/llama.py +46 -4
  43. sglang/srt/models/qwen2.py +11 -0
  44. sglang/srt/models/qwen2_eagle.py +131 -0
  45. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  46. sglang/srt/sampling/sampling_batch_info.py +15 -5
  47. sglang/srt/sampling/sampling_params.py +1 -1
  48. sglang/srt/server.py +125 -69
  49. sglang/srt/server_args.py +39 -19
  50. sglang/srt/speculative/eagle_utils.py +93 -85
  51. sglang/srt/speculative/eagle_worker.py +48 -33
  52. sglang/srt/torch_memory_saver_adapter.py +59 -0
  53. sglang/srt/utils.py +61 -5
  54. sglang/test/test_programs.py +23 -1
  55. sglang/test/test_utils.py +36 -7
  56. sglang/version.py +1 -1
  57. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
  58. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
  59. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  61. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py CHANGED
@@ -9,13 +9,12 @@ import triton.language as tl
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
 from sglang.srt.speculative.spec_info import SpecInfo
 
 if TYPE_CHECKING:
-    from python.sglang.srt.layers.sampler import SampleOutput
-    from python.sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.server_args import ServerArgs
 
 
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
 
 
 class EAGLEDraftInput(SpecInfo):
-    hidden_states: torch.Tensor = None
-    verified_id: torch.Tensor = None
-    positions: torch.Tensor = None
-    accept_length: torch.Tensor = None
-    has_finished: bool = False
-    unfinished_index: List[int] = None
-
-    def init(self, server_args: ServerArgs):
+    def __init__(self):
         self.prev_mode = ForwardMode.DECODE
         self.sample_output = None
-        self.topk: int = server_args.speculative_eagle_topk
-        self.num_verify_token: int = server_args.speculative_num_draft_tokens
-        self.spec_steps = server_args.speculative_num_steps
 
         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
         self.parents_list: List[torch.Tensor] = []
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0
-        self.root_token: int = None
 
-        assert self.topk <= 10, "topk should <= 10"
+        self.hidden_states: torch.Tensor = None
+        self.verified_id: torch.Tensor = None
+        self.positions: torch.Tensor = None
+        self.accept_length: torch.Tensor = None
+        self.has_finished: bool = False
+        self.unfinished_index: List[int] = None
+
+    def load_server_args(self, server_args: ServerArgs):
+        self.topk: int = server_args.speculative_eagle_topk
+        self.num_verify_token: int = server_args.speculative_num_draft_tokens
+        self.spec_steps = server_args.speculative_num_steps
 
-    def prepare_for_extend(self, batch: ForwardBatch):
+    def prepare_for_extend(self, batch: ScheduleBatch):
         req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
         out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
         batch.out_cache_loc = out_cache_loc
@@ -226,81 +224,73 @@ class EAGLEDraftInput(SpecInfo):
 
             pt += req.extend_input_len
 
-        seq_lens = [0] + batch.extend_lens
-        input_ids = batch.input_ids.tolist()
-        verified_id = batch.spec_info.verified_id.tolist()
-        model_input_ids = []
-        for i in range(len(seq_lens) - 1):
-            model_input_ids.extend(
-                input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
-            )
-        batch.input_ids = torch.tensor(
-            model_input_ids, dtype=torch.int32, device="cuda"
-        )
-
-    def capture_for_decode(
-        self,
-        sample_output: SampleOutput,
-        hidden_states: torch.Tensor,
-        prev_mode: ForwardMode,
-    ):
-        self.sample_output = sample_output
-        self.prev_mode = prev_mode
-        self.hidden_states = hidden_states
+        # TODO: support batching inputs
+        assert len(batch.extend_lens) == 1
+        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
 
     def prepare_for_decode(self, batch: ScheduleBatch):
-        prob = self.sample_output  # b * (1/topk), vocab
+        prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
         top = torch.topk(prob, self.topk, dim=-1)
-        topk_index, topk_p = top.indices, top.values  # b * (1/topk), topk
-        if self.prev_mode == ForwardMode.DECODE:
+        topk_index, topk_p = (
+            top.indices,
+            top.values,
+        )  # shape: (b * top_k, top_k) or (b, top_k)
+
+        if self.prev_mode.is_decode():
             scores = torch.mul(
                 self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
-            )  # (b, topk) mul (b * topk ,topk) -> b, topk, topk
+            )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
             topk_cs = torch.topk(
                 scores.flatten(start_dim=1), self.topk, dim=-1
             )  # (b, topk)
             topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
-            self.scores = topk_cs_p
 
-            selected_input_index = topk_cs_index.flatten() // self.topk  # b* topk
+            selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange(
+                0, batch.batch_size() * self.topk, step=self.topk, device="cuda"
+            ).repeat_interleave(self.topk)
 
             batch.spec_info.hidden_states = batch.spec_info.hidden_states[
                 selected_input_index, :
             ]
+
             topk_index = topk_index.reshape(-1, self.topk**2)
             batch.input_ids = torch.gather(
                 topk_index, index=topk_cs_index, dim=1
             ).flatten()
-            batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-            self.score_list.append(scores)  # b, topk, topk
-            self.token_list.append(topk_index)  # b, topk*topk
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+
+            self.scores = topk_cs_p
+            self.score_list.append(scores)  # (b, topk, topk)
+            self.token_list.append(topk_index)  # (b, topk * topk)
             self.origin_score_list.append(topk_p.reshape(topk_index.shape))
             self.parents_list.append(
                 topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
-            )  # b, topk
-
-        elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
-            self.scores = topk_p  # b, top_k
-            self.score_list.append(topk_p.unsqueeze(1))
-            self.token_list.append(topk_index)
-            self.origin_score_list.append(topk_p)
+            )  # shape: (b, topk)
+        else:
+            # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
             batch.spec_info.hidden_states = (
-                batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
+                batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
             )
+
             batch.input_ids = topk_index.flatten()
             batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
+
+            self.scores = topk_p  # shape: (b, topk)
+            self.score_list.append(topk_p.unsqueeze(1))  # shape: (b, 1, topk)
+            self.token_list.append(topk_index)  # shape: (b, topk)
+            self.origin_score_list.append(topk_p)
             self.parents_list.append(
                 torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
                 .unsqueeze(0)
                 .repeat(self.scores.shape[0], 1)
-            )  # b, topk+1
+            )  # shape: (b, topk + 1)
         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
             + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
         ).flatten()
 
-        bs = batch.seq_lens.numel()
+        bs = len(batch.seq_lens)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
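
Note on the hunk above: the new torch.arange term in selected_input_index converts a per-row top-k index into a row index of the flattened (batch * topk, ...) hidden-state tensor, which is what the old version got wrong for batch sizes above 1. A minimal sketch of that arithmetic on dummy CPU data (the batch size, topk, and random scores are illustrative, not from the package):

    import torch

    b, topk = 2, 3
    scores = torch.rand(b, topk)          # cumulative path scores from the previous step
    topk_p = torch.rand(b * topk, topk)   # per-draft-token top-k probabilities

    expanded = scores.unsqueeze(2) * topk_p.reshape(-1, topk, topk)  # (b, topk, topk)
    topk_cs_index = torch.topk(expanded.flatten(start_dim=1), topk, dim=-1).indices  # (b, topk)

    # topk_cs_index // topk is the row within one batch element; the arange term adds
    # the per-batch offset so the result indexes rows of the (b * topk, ...) tensor.
    selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
        0, b * topk, step=topk
    ).repeat_interleave(topk)
    assert selected_input_index.shape == (b * topk,)
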
@@ -347,6 +337,7 @@ class EAGLEDraftInput(SpecInfo):
             triton.next_power_of_2(self.spec_steps + 1),
         )
 
+        batch.seq_lens_sum = sum(batch.seq_lens)
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id
 
@@ -419,11 +410,6 @@ class EAGLEDraftInput(SpecInfo):
         )
         return bs, kv_indices, cum_kv_seq_len
 
-    def clear(self):
-        self.iter = 0
-        self.score_list.clear()
-        self.positions = None
-
     def clear_draft_cache(self, batch):
         draft_cache = torch.cat(self.cache_list, dim=0)
         batch.token_to_kv_pool.free(draft_cache)
@@ -455,12 +441,18 @@ class EAGLEDraftInput(SpecInfo):
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
     def merge_batch(self, spec_info: EAGLEDraftInput):
-
+        if self.hidden_states is None:
+            self.hidden_states = spec_info.hidden_states
+            self.verified_id = spec_info.verified_id
+            self.sample_output = spec_info.sample_output
+            self.prev_mode = spec_info.prev_mode
+            return
+        if spec_info.hidden_states is None:
+            return
         self.hidden_states = torch.cat(
             [self.hidden_states, spec_info.hidden_states], axis=0
         )
         self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
-        # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
         self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
 
 
@@ -567,11 +559,37 @@ class EagleVerifyInput(SpecInfo):
             triton.next_power_of_2(max_draft_len),
         )
 
-        accept_index = accept_index[accept_index != -1]
-        # extract_index = extract_index[extract_index != 0]
-
         draft_input = EAGLEDraftInput()
+        new_accept_index = []
+        unfinished_index = []
+        finished_extend_len = {}  # {rid:accept_length + 1}
+        accept_index_cpu = accept_index.tolist()
+        predict_cpu = predict.tolist()
+        # iterate every accepted token and check if req has finished after append the token
+        # should be checked BEFORE free kv cache slots
+        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+            new_accept_index_ = []
+            for j, idx in enumerate(accept_index_row):
+                if idx == -1:
+                    break
+                id = predict_cpu[idx]
+                # if not found_finished:
+                req.output_ids.append(id)
+                finished_extend_len[req.rid] = j + 1
+                req.check_finished()
+                if req.finished():
+                    draft_input.has_finished = True
+                    # set all tokens after finished token to -1 and break
+                    accept_index[i, j + 1 :] = -1
+                    break
+                else:
+                    new_accept_index_.append(idx)
+            if not req.finished():
+                new_accept_index.extend(new_accept_index_)
+                unfinished_index.append(i)
+        accept_length = (accept_index != -1).sum(dim=1) - 1
 
+        accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
         verified_id_cpu = verified_id.tolist()
@@ -590,29 +608,19 @@ class EagleVerifyInput(SpecInfo):
             triton.next_power_of_2(bs),
         )
         batch.seq_lens.add_(accept_length + 1)
-        new_accept_index = []
-        unfinished_index = []
-        finished_extend_len = {}  # {rid:accept_length + 1}
-        # retracted_reqs, new_token_ratio = batch.retract_decode()
-
-        low = 0
-        for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
-            req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
-            req.check_finished()
-            if req.finished():
-                draft_input.has_finished = True
-            else:
-                new_accept_index.append(accept_index[low : low + verified_len + 1])
-                unfinished_index.append(i)
-            low += verified_len + 1
-            finished_extend_len[req.rid] = verified_len + 1
 
         if len(new_accept_index) > 0:
-            new_accept_index = torch.cat(new_accept_index, dim=0)
+            new_accept_index = torch.tensor(new_accept_index, device="cuda")
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
             draft_input.unfinished_index = unfinished_index
 
         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return draft_input, logits_output, verified_id, finished_extend_len
+        return (
+            draft_input,
+            logits_output,
+            verified_id,
+            finished_extend_len,
+            accept_length_cpu,
+        )
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -40,6 +40,7 @@ class EAGLEWorker(TpModelWorker):
         )
         self.target_worker = target_worker
         self.server_args = server_args
+        self.finish_extend_len = []
 
         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -51,63 +52,72 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_for_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
 
     def forward_draft_extend(self, batch: ScheduleBatch):
-        self._swap_mem_pool(batch, self.model_runner)
+        self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
-        self._swap_mem_pool(batch, self.target_worker.model_runner)
+        self._set_mem_pool(batch, self.target_worker.model_runner)
 
     def forward_batch_speculative_generation(self, batch: ScheduleBatch):
         if batch.forward_mode.is_decode():
-            prev_spec_info = batch.spec_info
-            self._swap_mem_pool(batch, self.model_runner)
+            # Draft
+            self._set_mem_pool(batch, self.model_runner)
             for i in range(self.server_args.speculative_num_steps):
                 self.forward_draft_decode(batch)
             batch.spec_info.clear_draft_cache(batch)
-            self._swap_mem_pool(batch, self.target_worker.model_runner)
+            self._set_mem_pool(batch, self.target_worker.model_runner)
+
+            # Verify
             (
                 next_draft_input,
                 logits_output,
                 verified_id,
                 self.finish_extend_len,
+                accept_length_cpu,
                 model_worker_batch,
             ) = self.verify(batch)
-            next_draft_input.init(self.server_args)
+            next_draft_input.load_server_args(self.server_args)
            batch.spec_info = next_draft_input
             # if it is None, means all requsets are finished
             if batch.spec_info.verified_id is not None:
-                self.forward_extend_after_decode(batch)
-            batch.spec_info = prev_spec_info
-            return logits_output, verified_id, model_worker_batch, next_draft_input
+                self.forward_draft_extend_after_decode(batch)
+            return (
+                logits_output,
+                verified_id,
+                model_worker_batch,
+                sum(accept_length_cpu),
+            )
 
         else:
-            spec_info = EAGLEDraftInput()
-            spec_info.init(self.server_args)
+            # Forward with the target model and get hidden states.
+            # We need the full hidden states to prefill the KV cache of the draft model.
             model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.spec_info = spec_info
-            spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
             logits_output, next_token_ids = self.target_worker.forward_batch_generation(
                 model_worker_batch
             )
-            model_worker_batch.spec_info.verified_id = next_token_ids
-            model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
+
+            # Forward with the draft model.
+            spec_info = EAGLEDraftInput()
+            spec_info.load_server_args(self.server_args)
+            spec_info.hidden_states = logits_output.hidden_states
+            spec_info.verified_id = next_token_ids
             batch.spec_info = spec_info
             self.forward_draft_extend(batch)
-            batch.spec_info = None
-            return logits_output, next_token_ids, model_worker_batch, spec_info
+            return logits_output, next_token_ids, model_worker_batch, 0
 
     def verify(self, batch: ScheduleBatch):
         verify_input = batch.spec_info.prepare_for_verify(batch)
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
         verify_input.prepare_for_verify(batch)
+        batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = verify_input
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
         model_worker_batch = batch.get_model_worker_batch()
@@ -119,44 +129,49 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.DECODE
         return res + (model_worker_batch,)
 
-    def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
+    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
         batch.token_to_kv_pool = runner.token_to_kv_pool
         batch.req_to_token_pool = runner.req_to_token_pool
 
-    def forward_extend_after_decode(self, batch: ScheduleBatch):
-        self._swap_mem_pool(batch, self.model_runner)
+    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         if batch.spec_info.has_finished:
             index = batch.spec_info.unfinished_index
             seq_lens = batch.seq_lens
             batch.seq_lens = batch.seq_lens[index]
+
         batch.spec_info.prepare_extend_after_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
+
         batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
         batch.forward_mode = ForwardMode.DECODE
         if batch.spec_info.has_finished:
             batch.seq_lens = seq_lens
-        self._swap_mem_pool(batch, self.target_worker.model_runner)
+        self._set_mem_pool(batch, self.target_worker.model_runner)
 
-    def capture_for_decode(self, logits_output, forward_batch):
-        if isinstance(logits_output, LogitsProcessorOutput):
-            logits = logits_output.next_token_logits
+    def capture_for_decode(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ):
         sample_output = torch.softmax(
-            logits, dim=-1
-        )  # TODO: Support more sampling method @kavioyu
-        forward_batch.spec_info.capture_for_decode(
-            sample_output, logits_output.hidden_states, forward_batch.forward_mode
-        )
+            logits_output.next_token_logits, dim=-1
+        )  # TODO(kavioyu): Support more sampling methods
+        spec_info = forward_batch.spec_info
+        spec_info.sample_output = sample_output
+        spec_info.hidden_states = logits_output.hidden_states
+        spec_info.prev_mode = forward_batch.forward_mode
 
 
     # Don't support prefix share now.
     def finish_request(self, reqs: Union[Req, List[Req]]):
         if not isinstance(reqs, List):
             reqs = [reqs]
         for req in reqs:
+            if req.rid not in self.finish_extend_len:
+                continue
             req_len = (
                 len(req.origin_input_ids)
sglang/srt/torch_memory_saver_adapter.py ADDED
@@ -0,0 +1,59 @@
+from abc import ABC
+from contextlib import contextmanager
+
+try:
+    import torch_memory_saver
+
+    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+except ImportError:
+    pass
+
+
+class TorchMemorySaverAdapter(ABC):
+    @staticmethod
+    def create(enable: bool):
+        return (
+            _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
+        )
+
+    def configure_subprocess(self):
+        raise NotImplementedError
+
+    def region(self):
+        raise NotImplementedError
+
+    def pause(self):
+        raise NotImplementedError
+
+    def resume(self):
+        raise NotImplementedError
+
+
+class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    def configure_subprocess(self):
+        return torch_memory_saver.configure_subprocess()
+
+    def region(self):
+        return _primary_memory_saver.region()
+
+    def pause(self):
+        return _primary_memory_saver.pause()
+
+    def resume(self):
+        return _primary_memory_saver.resume()
+
+
+class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
+    @contextmanager
+    def configure_subprocess(self):
+        yield
+
+    @contextmanager
+    def region(self):
+        yield
+
+    def pause(self):
+        pass
+
+    def resume(self):
+        pass
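
Note: the new adapter is a thin wrapper around the optional torch_memory_saver package, with a no-op fallback when the feature is disabled or the package is missing. A hedged usage sketch (the tensor allocation is illustrative; only the adapter API above comes from the diff, and pausing has an effect only with the real backend on a CUDA device):

    import torch

    from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

    saver = TorchMemorySaverAdapter.create(enable=True)  # no-op variant if enable=False

    with saver.region():
        # allocations made inside the region are tracked by the saver
        kv_buffer = torch.empty(1 << 20, device="cuda")

    saver.pause()   # release the GPU memory backing tracked allocations
    saver.resume()  # restore it before the buffer is touched again

The configure_subprocess() context manager, not shown in the sketch, is meant to wrap subprocess setup so child processes inherit the same behavior.
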
sglang/srt/utils.py CHANGED
@@ -97,6 +97,10 @@ def is_flashinfer_available():
     return torch.cuda.is_available() and torch.version.cuda
 
 
+def is_cuda_available():
+    return torch.cuda.is_available() and torch.version.cuda
+
+
 def is_ipv6(address):
     try:
         ipaddress.IPv6Address(address)
@@ -335,6 +339,8 @@ def is_port_available(port):
             return True
         except socket.error:
             return False
+        except OverflowError:
+            return False
 
 
 def decode_video_base64(video_base64):
@@ -709,13 +715,14 @@ def broadcast_pyobj(
     data: List[Any],
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
+    src: int = 0,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
     if rank == 0:
         if len(data) == 0:
             tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_size, src=src, group=dist_group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
@@ -724,19 +731,19 @@ def broadcast_pyobj(
             )
             tensor_size = torch.tensor([size], dtype=torch.long)
 
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
+            dist.broadcast(tensor_size, src=src, group=dist_group)
+            dist.broadcast(tensor_data, src=src, group=dist_group)
         return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
+        dist.broadcast(tensor_size, src=src, group=dist_group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
         tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
+        dist.broadcast(tensor_data, src=src, group=dist_group)
 
         serialized_data = bytes(tensor_data.cpu().numpy())
         data = pickle.loads(serialized_data)
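
Note: broadcast_pyobj now takes a src argument instead of hardcoding source rank 0 in the dist.broadcast calls; the caller whose rank argument is 0 still serializes the data, while src names that process in the global rank numbering torch.distributed expects. A minimal single-process sketch just to show the call shape (the group setup here is illustrative, not how sglang wires it up):

    import torch.distributed as dist

    from sglang.srt.utils import broadcast_pyobj

    # One-process gloo group so the sketch is runnable; in sglang, dist_group is a
    # TP/DP process group and src is the global rank of that group's root.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29505", rank=0, world_size=1
    )
    group = dist.new_group(ranks=[0])

    data = ["hello", {"rid": "abc"}]
    out = broadcast_pyobj(data, rank=0, dist_group=group, src=0)
    assert out == data
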
@@ -1337,6 +1344,25 @@ def parse_tool_response(text, tools, **kwargs):
     return text, call_info_list
 
 
+def permute_weight(x: torch.Tensor) -> torch.Tensor:
+    b_ = x.shape[0]
+    n_ = x.shape[1]
+    k_ = x.shape[2]
+
+    x_ = x
+    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
+    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
+    else:
+        return x_
+
+    x_ = x_.permute(0, 1, 3, 4, 2, 5)
+    x_ = x_.contiguous()
+    x_ = x_.view(*x.shape)
+    return x_
+
+
 class MultiprocessingSerializer:
     @staticmethod
     def serialize(obj):
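
Note: permute_weight reorders a 3-D weight tensor into a tiled memory layout (different tile widths for 16-bit versus fp8/int8 dtypes) while keeping the logical shape; dtypes it does not recognize are returned unchanged. A small sketch under the assumption that the layout is (experts, n, k) with n a multiple of 16 and k a multiple of 32 (64 for fp8/int8):

    import torch

    from sglang.srt.utils import permute_weight

    w = torch.randn(8, 32, 64, dtype=torch.bfloat16)  # illustrative (experts, n, k) weight
    w_perm = permute_weight(w)

    # Same shape and dtype, but the elements are reordered into a tiled layout.
    assert w_perm.shape == w.shape and w_perm.dtype == w.dtype and w_perm.is_contiguous()
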
@@ -1348,3 +1374,33 @@ class MultiprocessingSerializer:
     @staticmethod
     def deserialize(data):
         return ForkingPickler.loads(data)
+
+
+def debug_timing(func):
+    # todo: replace with a more organized instrumentation
+    def wrapper(*args, **kwargs):
+        if logger.isEnabledFor(logging.DEBUG):
+            tic = torch.cuda.Event(enable_timing=True)
+            toc = torch.cuda.Event(enable_timing=True)
+            tic.record()
+            result = func(*args, **kwargs)
+            toc.record()
+            torch.cuda.synchronize()  # Ensure all CUDA operations are complete
+            elapsed = tic.elapsed_time(toc)
+            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
+            num_tokens = len(indices) if indices is not None else 0
+            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
+            logger.debug(
+                f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
+            )
+            return result
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
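
Note: debug_timing only measures and logs when the module logger is enabled for DEBUG; otherwise it calls the wrapped function untouched, and the timing path requires a CUDA device because it uses torch.cuda.Event. A hedged sketch of decorating a copy helper (the helper is illustrative; it passes the index tensor as the second positional argument so the decorator can report a token throughput):

    import logging

    import torch

    from sglang.srt.utils import debug_timing

    logging.basicConfig(level=logging.DEBUG)  # enable the DEBUG-gated timing path

    @debug_timing
    def copy_rows(dst: torch.Tensor, indices: torch.Tensor, src: torch.Tensor):
        # `indices` is args[1], which debug_timing reads for its token count
        dst[indices] = src
        return dst

    dst = torch.zeros(16, 8, device="cuda")
    idx = torch.tensor([0, 3, 5], device="cuda")
    copy_rows(dst, idx, torch.ones(3, 8, device="cuda"))
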
sglang/test/test_programs.py CHANGED
@@ -509,13 +509,35 @@ def test_hellaswag_select():
         temperature=0,
         num_threads=64,
         progress_bar=True,
+        generator_style=False,
     )
-    preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
+    preds = []
+    for i, ret in enumerate(rets):
+        preds.append(choices[i].index(ret["answer"]))
     latency = time.time() - tic
 
     # Compute accuracy
     accuracy = np.mean(np.array(preds) == np.array(labels))
 
+    # Test generator style of run_batch
+    tic = time.time()
+    rets = few_shot_hellaswag.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=64,
+        progress_bar=True,
+        generator_style=True,
+    )
+    preds_gen = []
+    for i, ret in enumerate(rets):
+        preds_gen.append(choices[i].index(ret["answer"]))
+    latency_gen = time.time() - tic
+
+    # Compute accuracy
+    accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
+    assert np.abs(accuracy_gen - accuracy) < 0.01
+    assert np.abs(latency_gen - latency) < 1
+
     return accuracy, latency
 
 