sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_latency.py +28 -10
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/layers/attention/__init__.py +27 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  7. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  8. sglang/srt/layers/attention/triton_backend.py +6 -4
  9. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  12. sglang/srt/layers/sampler.py +6 -2
  13. sglang/srt/managers/detokenizer_manager.py +31 -10
  14. sglang/srt/managers/io_struct.py +4 -0
  15. sglang/srt/managers/schedule_batch.py +120 -43
  16. sglang/srt/managers/schedule_policy.py +2 -1
  17. sglang/srt/managers/scheduler.py +202 -140
  18. sglang/srt/managers/tokenizer_manager.py +5 -1
  19. sglang/srt/managers/tp_worker.py +111 -1
  20. sglang/srt/mem_cache/chunk_cache.py +8 -4
  21. sglang/srt/mem_cache/memory_pool.py +77 -4
  22. sglang/srt/mem_cache/radix_cache.py +15 -7
  23. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  24. sglang/srt/model_executor/forward_batch_info.py +16 -21
  25. sglang/srt/model_executor/model_runner.py +60 -1
  26. sglang/srt/models/baichuan.py +2 -3
  27. sglang/srt/models/chatglm.py +5 -6
  28. sglang/srt/models/commandr.py +1 -2
  29. sglang/srt/models/dbrx.py +1 -2
  30. sglang/srt/models/deepseek.py +4 -5
  31. sglang/srt/models/deepseek_v2.py +5 -6
  32. sglang/srt/models/exaone.py +1 -2
  33. sglang/srt/models/gemma.py +2 -2
  34. sglang/srt/models/gemma2.py +5 -5
  35. sglang/srt/models/gpt_bigcode.py +5 -5
  36. sglang/srt/models/grok.py +1 -2
  37. sglang/srt/models/internlm2.py +1 -2
  38. sglang/srt/models/llama.py +1 -2
  39. sglang/srt/models/llama_classification.py +1 -2
  40. sglang/srt/models/llama_reward.py +2 -3
  41. sglang/srt/models/llava.py +4 -8
  42. sglang/srt/models/llavavid.py +1 -2
  43. sglang/srt/models/minicpm.py +1 -2
  44. sglang/srt/models/minicpm3.py +5 -6
  45. sglang/srt/models/mixtral.py +1 -2
  46. sglang/srt/models/mixtral_quant.py +1 -2
  47. sglang/srt/models/olmo.py +352 -0
  48. sglang/srt/models/olmoe.py +1 -2
  49. sglang/srt/models/qwen.py +1 -2
  50. sglang/srt/models/qwen2.py +1 -2
  51. sglang/srt/models/qwen2_moe.py +4 -5
  52. sglang/srt/models/stablelm.py +1 -2
  53. sglang/srt/models/torch_native_llama.py +1 -2
  54. sglang/srt/models/xverse.py +1 -2
  55. sglang/srt/models/xverse_moe.py +4 -5
  56. sglang/srt/models/yivl.py +1 -2
  57. sglang/srt/openai_api/adapter.py +92 -49
  58. sglang/srt/openai_api/protocol.py +10 -2
  59. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  60. sglang/srt/sampling/sampling_batch_info.py +92 -58
  61. sglang/srt/sampling/sampling_params.py +2 -0
  62. sglang/srt/server.py +116 -17
  63. sglang/srt/server_args.py +121 -45
  64. sglang/srt/utils.py +11 -3
  65. sglang/test/few_shot_gsm8k.py +4 -1
  66. sglang/test/few_shot_gsm8k_engine.py +144 -0
  67. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  68. sglang/version.py +1 -1
  69. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
  70. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
  71. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  72. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  73. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  74. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -286,12 +288,12 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if CUDA_CAPABILITY[0] >= 9:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
             BLOCK_M, BLOCK_N = (128, 64)
         else:
             BLOCK_M, BLOCK_N = (32, 64)
-    elif CUDA_CAPABILITY[0] >= 8:
+    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         if Lq <= 128:
             BLOCK_M, BLOCK_N = (128, 128)
         elif Lq <= 256:
sglang/srt/layers/attention/triton_ops/prefill_attention.py
@@ -24,7 +24,9 @@ import torch
 import triton
 import triton.language as tl
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -145,7 +147,7 @@ def _fwd_kernel(
 
 
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    if CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
     else:
         BLOCK = 64
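Note: both Triton attention modules above now guard the device-capability lookup so they can be imported on machines without CUDA. A minimal standalone sketch of the same pattern; the helper name and fallback block size are illustrative, not part of the diff:

    import torch

    # Query the compute capability only when CUDA is actually present, so a
    # plain import of this module never raises on CPU-only hosts.
    is_cuda_available = torch.cuda.is_available()
    if is_cuda_available:
        CUDA_CAPABILITY = torch.cuda.get_device_capability()

    def pick_block_size() -> int:
        # Every capability check is additionally guarded by is_cuda_available,
        # mirroring the guarded branches in the hunks above.
        if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
            return 128
        return 64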
sglang/srt/layers/sampler.py
@@ -21,6 +21,10 @@ logger = logging.getLogger(__name__)
 
 
 class Sampler(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+
     def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
@@ -36,13 +40,13 @@ class Sampler(nn.Module):
         logits = None
         del logits
 
-        if torch.any(torch.isnan(probs)):
+        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
             logger.warning("Detected errors during sampling! NaN in the probability.")
             probs = torch.where(
                 torch.isnan(probs), torch.full_like(probs, 1e-10), probs
             )
 
-        if sampling_info.top_ks.max().item() <= 1:
+        if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(probs, -1)
         elif global_server_args_dict["sampling_backend"] == "flashinfer":
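Note: the sampler now reads a precomputed `is_all_greedy` flag from the sampling metadata instead of reducing `top_ks` on the GPU every step, and the NaN scrub can be turned off via the new `disable_nan_detection` server argument. A rough sketch of the greedy fast path, assuming `probs` is a `[batch, vocab]` probability tensor and `is_all_greedy` is a plain bool; `torch.multinomial` stands in here for the real flashinfer/torch sampling backends:

    import torch

    def sample_next_tokens(probs: torch.Tensor, is_all_greedy: bool) -> torch.Tensor:
        # Fast path: when every request decodes greedily, a single argmax
        # replaces the full top-k/top-p sampling machinery.
        if is_all_greedy:
            return torch.argmax(probs, dim=-1)
        # Simplified fallback: sample from the already-normalized probabilities.
        return torch.multinomial(probs, num_samples=1).squeeze(-1)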
sglang/srt/managers/detokenizer_manager.py
@@ -18,7 +18,7 @@ limitations under the License.
 import dataclasses
 import logging
 from collections import OrderedDict
-from typing import List
+from typing import List, Union
 
 import zmq
 
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     UpdateWeightReqOutput,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
+    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
+        if no_stop_trim:
+            return output
+
+        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
+        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
+            pos = output.find(finished_reason.matched)
+            return output[:pos] if pos != -1 else output
+        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
+            output, list
+        ):
+            assert len(output) > 0
+            return output[:-1]
+        return output
+
     def event_loop(self):
         """The event loop that handles requests"""
 
@@ -122,7 +137,13 @@ class DetokenizerManager:
                     s = self.decode_status[rid]
                     s.decode_ids = recv_obj.decode_ids[i]
 
-                read_ids.append(s.decode_ids[s.surr_offset :])
+                read_ids.append(
+                    self.trim_eos(
+                        s.decode_ids[s.surr_offset :],
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
                 surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ class DetokenizerManager:
                 else:
                     new_text = find_printable_text(new_text)
 
-                output_strs.append(s.decoded_text + new_text)
-
-                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
-                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
-                    if pos != -1:
-                        output_strs[i] = output_strs[i][:pos]
+                output_strs.append(
+                    self.trim_eos(
+                        s.decoded_text + new_text,
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
 
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
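Note: the new `trim_eos` helper centralizes stop trimming for both call sites above: string outputs are cut at the first occurrence of the matched stop string, token-id outputs drop the final (stop) token, and both are skipped when the request sets `no_stop_trim`. A small sketch of the two branches, using a hypothetical stand-in for the finish-reason object:

    # Hypothetical stand-in for the payload carried by FINISH_MATCHED_STR.
    class MatchedStr:
        def __init__(self, matched: str):
            self.matched = matched

    text = "The answer is 42.</s> extra"
    reason = MatchedStr("</s>")

    # String case: cut at the first occurrence of the matched stop string.
    pos = text.find(reason.matched)
    trimmed = text[:pos] if pos != -1 else text      # -> "The answer is 42."

    # Token-id case: drop the final (stop) token id.
    token_ids = [1012, 2003, 4413, 2]
    trimmed_ids = token_ids[:-1]                     # -> [1012, 2003, 4413]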
sglang/srt/managers/io_struct.py
@@ -56,6 +56,9 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -295,6 +298,7 @@ class BatchTokenIDOut:
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
+    no_stop_trim: List[bool]
 
 
 @dataclass
sglang/srt/managers/schedule_batch.py
@@ -53,6 +53,7 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
+    "disable_nan_detection": ServerArgs.disable_nan_detection,
 }
 
 
@@ -196,6 +197,9 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
+        # The number of cached tokens, that were already cached in the KV store
+        self.cached_tokens = 0
+
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None
 
@@ -203,6 +207,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
+        self.is_inflight_req = 0
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -391,25 +396,30 @@ class Req:
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
+bid = 0
+
+
 @dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
 
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids: List[int] = None
-    req_pool_indices: List[int] = None
-    seq_lens: List[int] = None
+    input_ids: torch.Tensor = None
+    req_pool_indices: torch.Tensor = None
+    seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
 
+    output_ids: torch.Tensor = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -419,6 +429,7 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     running_bs: int = None
+    decoding_reqs: List[Req] = None
 
     # Stream
     has_stream: bool = False
@@ -492,17 +503,24 @@
 
         pt = 0
         for i, req in enumerate(reqs):
+            already_computed = (
+                req.extend_logprob_start_len + 1 + req.cached_tokens
+                if req.extend_logprob_start_len > 0
+                else 0
+            )
+            req.cached_tokens += len(req.prefix_indices) - already_computed
+
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
+                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
+                    req.prefix_indices
+                )
 
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )
 
@@ -518,10 +536,15 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Set fields
-        with out_cache_loc.device:
-            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(req_pool_indices)
-            self.seq_lens = torch.tensor(seq_lens)
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
 
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
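Note: the new `cached_tokens` counter in the extend-preparation hunk above credits a request with the prefix tokens served from the KV cache, subtracting whatever was already counted by a previous (chunked) extend. A worked example with illustrative numbers only:

    # Illustrative numbers, mirroring the arithmetic in the hunk above.
    extend_logprob_start_len = 0    # first extend of this request
    cached_tokens = 0
    prefix_len = 30                 # 30 prompt tokens hit the prefix cache

    already_computed = (
        extend_logprob_start_len + 1 + cached_tokens
        if extend_logprob_start_len > 0
        else 0
    )                               # -> 0 on the first extend
    cached_tokens += prefix_len - already_computed   # -> 30 cached tokens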
@@ -531,7 +554,9 @@
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
 
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
@@ -586,9 +611,11 @@
 
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
+            or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
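Note: the added `or first_iter` condition turns the retraction loop into a do-while, so at least one request is retracted per call even if the free pool momentarily looks large enough. A compact sketch of the same control flow with simplified bookkeeping:

    def retract(sorted_indices, available, needed_per_req):
        retracted = []
        first_iter = True
        # Do-while emulation: run the body at least once, then continue while
        # the free pool is still too small for the remaining requests.
        while available < len(sorted_indices) * needed_per_req or first_iter:
            if len(sorted_indices) == 1:
                break  # corner case: never retract the last remaining request
            first_iter = False
            retracted.append(sorted_indices.pop())
            available += needed_per_req  # freed slots returned to the pool (illustrative)
        return retracted, sorted_indices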
@@ -597,6 +624,7 @@
                 ), "No space left for only one request"
                 break
 
+            first_iter = False
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
@@ -637,7 +665,7 @@
             req.last_update_decode_tokens = 0
             req.logprob_start_len = 10**9
 
-        self.filter_batch(sorted_indices)
+        self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
         total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
@@ -652,7 +680,7 @@
 
     def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
-        filter_indices = [i for i in range(len(self.reqs))]
+        keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
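Note: switching `filter_indices` (a list) to `keep_indices` (a set) makes the per-request removal in the jump-forward loop constant time; the set is converted back to a list when handed to `filter_batch`. A tiny illustration:

    # A set gives O(1) removal by value; list.remove() would scan the whole list.
    keep_indices = set(range(4))       # {0, 1, 2, 3}
    keep_indices.remove(2)             # drop request 2 in constant time
    as_list = list(keep_indices)       # converted back before calling filter_batch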
@@ -712,63 +740,71 @@
                     )
 
                 jump_forward_reqs.append(req)
-                keep: filter_indices.remove(i)
+                keep_indices.remove(i)
 
-        self.filter_batch(filter_indices)
+        self.filter_batch(keep_indices=list(keep_indices))
 
         return jump_forward_reqs
 
-    def prepare_for_decode(self, input_ids=None):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
-        if input_ids is None:
-            input_ids = [
-                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
-                for r in self.reqs
-            ]
-
-        self.input_ids = torch.tensor(
-            input_ids, dtype=torch.int32, device=self.seq_lens.device
-        )
-        self.seq_lens.add_(1)
+        self.input_ids = self.output_ids
+        self.output_ids = None
+        if self.sampling_info.penalizer_orchestrator:
+            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                self.input_ids
+            )
 
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
 
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
+        self.seq_lens.add_(1)
 
-    def filter_batch(self, unfinished_indices: List[int]):
-        if unfinished_indices is None or len(unfinished_indices) == 0:
+    def filter_batch(
+        self,
+        current_inflight_req: Optional[Req] = None,
+        keep_indices: Optional[List[int]] = None,
+    ):
+        if keep_indices is None:
+            keep_indices = [
+                i
+                for i in range(len(self.reqs))
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not current_inflight_req
+            ]
+
+        if keep_indices is None or len(keep_indices) == 0:
             # Filter out all requests
             self.reqs = []
             return
 
-        if len(unfinished_indices) == len(self.reqs):
+        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return
 
-        self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(
-            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+        self.reqs = [self.reqs[i] for i in keep_indices]
+        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
-            self.top_logprobs_nums = [
-                self.top_logprobs_nums[i] for i in unfinished_indices
-            ]
+            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
         self.has_regex = any(req.regex_fsm for req in self.reqs)
 
-        self.sampling_info.filter_batch(unfinished_indices, new_indices)
+        self.sampling_info.filter_batch(keep_indices, new_indices)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
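Note: `filter_batch` now computes its own `keep_indices` when none are passed, dropping finished requests and the current in-flight (chunked-prefill) request, while callers such as `retract_decode` and `check_for_jump_forward` pass `keep_indices=` explicitly. A minimal sketch of the default selection over toy request objects; all names here are illustrative:

    from typing import List, Optional

    class FakeReq:
        def __init__(self, done: bool):
            self.done = done

        def finished(self) -> bool:
            return self.done

    def default_keep_indices(
        reqs: List[FakeReq], current_inflight_req: Optional[FakeReq] = None
    ) -> List[int]:
        # Keep everything that is neither finished nor the in-flight request.
        return [
            i
            for i, r in enumerate(reqs)
            if not r.finished() and r is not current_inflight_req
        ]

    reqs = [FakeReq(False), FakeReq(True), FakeReq(False)]
    assert default_keep_indices(reqs, current_inflight_req=reqs[2]) == [0]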
@@ -781,6 +817,8 @@
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        if self.output_ids is not None:
+            self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
         elif self.return_logprob:
813
851
  else:
814
852
  self.sampling_info.regex_fsms = None
815
853
 
854
+ global bid
855
+ bid += 1
856
+
816
857
  return ModelWorkerBatch(
858
+ bid=bid,
817
859
  forward_mode=self.forward_mode,
818
860
  input_ids=self.input_ids,
819
861
  req_pool_indices=self.req_pool_indices,
@@ -829,9 +871,26 @@ class ScheduleBatch:
829
871
  sampling_info=self.sampling_info,
830
872
  )
831
873
 
874
+ def copy(self):
875
+ return ScheduleBatch(
876
+ reqs=self.reqs,
877
+ forward_mode=self.forward_mode,
878
+ out_cache_loc=self.out_cache_loc,
879
+ return_logprob=self.return_logprob,
880
+ decoding_reqs=self.decoding_reqs,
881
+ )
882
+
883
+ def __str__(self):
884
+ return (
885
+ f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
886
+ f"#req={(len(self.reqs))})"
887
+ )
888
+
832
889
 
833
890
  @dataclass
834
891
  class ModelWorkerBatch:
892
+ # The batch id
893
+ bid: int
835
894
  # The forward mode
836
895
  forward_mode: ForwardMode
837
896
  # The input ids
@@ -860,3 +919,21 @@
 
     # Sampling info
     sampling_info: SamplingBatchInfo
+
+    def copy(self):
+        return ModelWorkerBatch(
+            bid=self.bid,
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids.clone(),
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens.clone(),
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=self.extend_seq_lens,
+            extend_prefix_lens=self.extend_prefix_lens,
+            extend_logprob_start_lens=self.extend_logprob_start_lens,
+            image_inputs=self.image_inputs,
+            lora_paths=self.lora_paths,
+            sampling_info=self.sampling_info.copy(),
+        )
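Note: `ModelWorkerBatch.copy` clones only `input_ids` and `seq_lens` (and copies the sampling info) while every other field is shared by reference, presumably to keep the copy cheap when those tensors are later updated in place. A small sketch of the same selective-clone idea on a toy dataclass; all names here are illustrative:

    from dataclasses import dataclass, replace

    import torch

    @dataclass
    class ToyBatch:
        input_ids: torch.Tensor          # updated in place later -> clone it
        seq_lens: torch.Tensor           # updated in place later -> clone it
        req_pool_indices: torch.Tensor   # effectively read-only -> share it

        def copy(self) -> "ToyBatch":
            return replace(
                self, input_ids=self.input_ids.clone(), seq_lens=self.seq_lens.clone()
            )

    b = ToyBatch(torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([0]))
    c = b.copy()
    assert c.req_pool_indices is b.req_pool_indices   # shared reference
    assert c.input_ids is not b.input_ids             # independent clone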
sglang/srt/managers/schedule_policy.py
@@ -45,12 +45,13 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]):
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy in ["lpm", "dfs-weight"]:
+        if self.policy == "lpm" or self.policy == "dfs-weight":
            for r in waiting_queue:
                # NOTE: the prefix_indices must always be aligned with last_node
                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                    rid=r.rid, key=r.adjust_max_prefix_ids()
                )
+
            prefix_computed = True
 
        if self.policy == "lpm":