sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
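For reference, a quick way to confirm which of the two wheels is active in an environment. This is a hedged example, not part of the diff: it assumes the package re-exports `__version__` at the root (`sglang/version.py`, entry 81 above, is the file that carries the version bump).

# Assumes sglang is installed and exposes sglang.__version__
import sglang

print(sglang.__version__)  # expected: "0.3.3" after upgrading, "0.3.1.post3" before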
sglang/srt/managers/schedule_batch.py

@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Meta data for requests and batches"""
+"""
+Store information about requests and batches.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""

 import logging
 from dataclasses import dataclass
@@ -31,6 +41,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -102,14 +113,50 @@ class FINISH_ABORT(BaseFinishReason):
         }


+@dataclass
+class ImageInputs:
+    """The image related inputs."""
+
+    pixel_values: torch.Tensor
+    image_hash: int
+    image_sizes: Optional[list] = None
+    image_offsets: Optional[list] = None
+    pad_values: Optional[list] = None
+    modalities: Optional[list] = None
+
+    image_embeds: Optional[List[torch.Tensor]] = None
+    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def from_dict(obj, vocab_size):
+        # Use image hash as fake token_ids, which is then used for prefix matching
+        ret = ImageInputs(
+            pixel_values=obj["pixel_values"],
+            image_hash=hash(tuple(obj["image_hashes"])),
+        )
+        image_hash = ret.image_hash
+        ret.pad_values = [
+            (image_hash) % vocab_size,
+            (image_hash >> 16) % vocab_size,
+            (image_hash >> 32) % vocab_size,
+            (image_hash >> 64) % vocab_size,
+        ]
+        ret.image_sizes = obj["image_sizes"]
+        # Only when pixel values is not None we have modalities
+        ret.modalities = obj["modalities"] or ["image"]
+        return ret
+
+
 class Req:
-    """Store all inforamtion of a request."""
+    """The input and output status of a request."""

     def __init__(
         self,
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info
@@ -119,6 +166,8 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+
+        self.sampling_params = sampling_params
         self.lora_path = lora_path

         # Memory info
@@ -127,6 +176,7 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -147,21 +197,13 @@ class Req:
         self.completion_tokens_wo_jump_forward = 0

         # For vision inputs
-        self.pixel_values = None
-        self.image_sizes = None
-        self.image_offsets = None
-        self.pad_value = None
-        self.modalities = None
+        self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None

-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -363,28 +405,32 @@ class ScheduleBatch:
     sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
-    position_ids_offsets: torch.Tensor = None
+    input_ids: List[int] = None
+    req_pool_indices: List[int] = None
+    seq_lens: List[int] = None
     out_cache_loc: torch.Tensor = None
-    extend_num_tokens: int = None
-
-    # For mixed chunekd prefill
-    prefix_lens_cpu: List[int] = None
-    running_bs: int = None

     # For processing logprobs
     return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
+    top_logprobs_nums: Optional[List[int]] = None
+
+    # For extend and mixed chunekd prefill
+    prefix_lens: List[int] = None
+    extend_lens: List[int] = None
+    extend_num_tokens: int = None
+    running_bs: int = None

     # Stream
     has_stream: bool = False

+    # Has regex
+    has_regex: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
         has_stream = any(req.stream for req in reqs)
+        has_regex = any(req.regex_fsm for req in reqs)

         return cls(
             reqs=reqs,
@@ -393,6 +439,7 @@
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            has_regex=has_regex,
         )

     def batch_size(self):
@@ -429,19 +476,19 @@
     def prepare_for_extend(self, vocab_size: int):
         self.forward_mode = ForwardMode.EXTEND

-        bs = self.batch_size()
+        bs = len(self.reqs)
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []

         # Allocate memory
-        req_pool_indices_cpu = self.alloc_req_slots(bs)
+        req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)

         pt = 0
         for i, req in enumerate(reqs):
-            req.req_pool_idx = req_pool_indices_cpu[i]
+            req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
@@ -467,18 +514,19 @@
             pt += req.extend_input_len

         # Set fields
-        with torch.device("cuda"):
+        with out_cache_loc.device:
             self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
-            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
-            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+            self.req_pool_indices = torch.tensor(req_pool_indices)
+            self.seq_lens = torch.tensor(seq_lens)

         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
-        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
-        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
-        self.extend_lens_cpu = [r.extend_input_len for r in reqs]
-        self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
+        if self.return_logprob:
+            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
+        self.extend_lens = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -493,23 +541,23 @@
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
         extend_num_tokens = self.extend_num_tokens + running_bs

-        self.merge(running_batch)
+        self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens = extend_num_tokens

         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
-        self.prefix_lens_cpu.extend(
+        self.prefix_lens.extend(
             [
                 len(r.origin_input_ids) + len(r.output_ids) - 1
                 for r in running_batch.reqs
             ]
         )
-        self.extend_lens_cpu.extend([1] * running_bs)
-        self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
+        self.extend_lens.extend([1] * running_bs)
+        self.extend_logprob_start_lens.extend([0] * running_bs)

     def check_decode_mem(self):
-        bs = self.batch_size()
+        bs = len(self.reqs)
         if self.token_to_kv_pool.available_size() >= bs:
             return True

@@ -598,7 +646,7 @@

         return retracted_reqs, new_estimate_ratio

-    def check_for_jump_forward(self, model_runner):
+    def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]

@@ -654,15 +702,9 @@
                 self.tree_cache.cache_finished_req(req, cur_all_ids)

                 # re-applying image padding
-                if req.pixel_values is not None:
-                    (
-                        req.origin_input_ids,
-                        req.image_offsets,
-                    ) = model_runner.model.pad_input_ids(
-                        req.origin_input_ids_unpadded,
-                        req.pad_value,
-                        req.pixel_values,
-                        req.image_sizes,
+                if req.image_inputs is not None:
+                    req.origin_input_ids = pad_input_ids_func(
+                        req.origin_input_ids_unpadded, req.image_inputs
                     )

                 jump_forward_reqs.append(req)
@@ -680,14 +722,14 @@
                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
                 for r in self.reqs
             ]
-        else:
-            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)

-        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
+        self.input_ids = torch.tensor(
+            input_ids, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.seq_lens.add_(1)

         # Alloc mem
-        bs = self.batch_size()
+        bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

         self.req_to_token_pool.req_to_token[
@@ -705,33 +747,110 @@
             return

         self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
-        self.seq_lens = self.seq_lens[new_indices]
-        self.input_ids = None
+        new_indices = torch.tensor(
+            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.position_ids_offsets = self.position_ids_offsets[new_indices]
+        self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
-        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.return_logprob:
+            self.top_logprobs_nums = [
+                self.top_logprobs_nums[i] for i in unfinished_indices
+            ]
+        else:
+            self.top_logprobs_nums = None
+
         self.has_stream = any(req.stream for req in self.reqs)
+        self.has_regex = any(req.regex_fsm for req in self.reqs)

-        self.sampling_info.filter(unfinished_indices, new_indices)
+        self.sampling_info.filter_batch(unfinished_indices, new_indices)

-    def merge(self, other: "ScheduleBatch"):
+    def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.sampling_info.merge(other.sampling_info)
+        self.sampling_info.merge_batch(other.sampling_info)

-        self.reqs.extend(other.reqs)
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.position_ids_offsets = torch.concat(
-            [self.position_ids_offsets, other.position_ids_offsets]
-        )
         self.out_cache_loc = None
-        self.top_logprobs_nums.extend(other.top_logprobs_nums)
-        self.return_logprob = any(req.return_logprob for req in self.reqs)
-        self.has_stream = any(req.stream for req in self.reqs)
+        if self.return_logprob and other.return_logprob:
+            self.top_logprobs_nums.extend(other.top_logprobs_nums)
+        elif self.return_logprob:
+            self.top_logprobs_nums.extend([0] * len(other.reqs))
+        elif other.return_logprob:
+            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
+        self.reqs.extend(other.reqs)
+
+        self.return_logprob = self.return_logprob or other.return_logprob
+        self.has_stream = self.has_stream or other.has_stream
+        self.has_regex = self.has_regex or other.has_regex
+
+    def get_model_worker_batch(self):
+        if self.forward_mode.is_decode():
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
+                image_inputs
+            ) = None
+        else:
+            extend_seq_lens = self.extend_lens
+            extend_prefix_lens = self.prefix_lens
+            extend_logprob_start_lens = self.extend_logprob_start_lens
+            image_inputs = [r.image_inputs for r in self.reqs]
+
+        lora_paths = [req.lora_path for req in self.reqs]
+        if self.has_regex:
+            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
+            self.sampling_info.regex_fsm_states = [
+                req.regex_fsm_state for req in self.reqs
+            ]
+
+        return ModelWorkerBatch(
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids,
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=extend_seq_lens,
+            extend_prefix_lens=extend_prefix_lens,
+            extend_logprob_start_lens=extend_logprob_start_lens,
+            image_inputs=image_inputs,
+            lora_paths=lora_paths,
+            sampling_info=self.sampling_info,
+        )
+
+
+@dataclass
+class ModelWorkerBatch:
+    # The forward mode
+    forward_mode: ForwardMode
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
+    req_pool_indices: torch.Tensor
+    # The sequence length
+    seq_lens: torch.Tensor
+    # The indices of output tokens in the token_to_kv_pool
+    out_cache_loc: torch.Tensor
+
+    # For logprob
+    return_logprob: bool
+    top_logprobs_nums: Optional[List[int]]
+
+    # For extend
+    extend_seq_lens: Optional[List[int]]
+    extend_prefix_lens: Optional[List[int]]
+    extend_logprob_start_lens: Optional[List[int]]
+
+    # For multimodal
+    image_inputs: Optional[List[ImageInputs]]
+
+    # For LoRA
+    lora_paths: Optional[List[str]]
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo
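The `ImageInputs.from_dict` helper added above folds a request's image hash into a handful of pseudo token ids (`pad_values`) so that requests carrying identical images can still hit the radix prefix cache. Below is a minimal, self-contained sketch of that idea, not sglang's code; the function name, the toy vocabulary size, and the example hashes are illustrative only.

# Minimal sketch (not sglang's implementation) of deriving prefix-cache-friendly
# pseudo token ids from an image hash, mirroring ImageInputs.from_dict above.
def pad_values_from_hash(image_hash: int, vocab_size: int) -> list:
    # Slice the (typically 64-bit) hash into chunks and map each chunk into the
    # vocabulary range, like the (image_hash >> k) % vocab_size pattern in the diff.
    return [
        image_hash % vocab_size,
        (image_hash >> 16) % vocab_size,
        (image_hash >> 32) % vocab_size,
        (image_hash >> 64) % vocab_size,
    ]

toy_vocab_size = 32000
image_hash = hash(("image-a.png", "image-b.png"))  # stands in for hash(tuple(image_hashes))
print(pad_values_from_hash(image_hash, toy_vocab_size))

Identical images produce the same hash and therefore the same pseudo tokens, so the radix cache can match them as a shared prefix even though the real pixel data never enters the token stream.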
sglang/srt/managers/{policy_scheduler.py → schedule_policy.py}

@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Request policy scheduler"""
+"""Request scheduler policy"""

 import os
 import random
 from collections import defaultdict
 from contextlib import contextmanager
+from enum import Enum, auto
 from typing import Dict, List, Optional

 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -32,7 +33,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


-class PolicyScheduler:
+class SchedulePolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
@@ -104,6 +105,12 @@ class PolicyScheduler:
         q.extend(last_node_to_reqs[cur_node])


+class AddReqResult(Enum):
+    CONTINUE = auto()  # Continue to add requests
+    NO_TOKEN = auto()  # No token left
+    OTHER = auto()  # Other reasons to stop adding requests
+
+
 class PrefillAdder:
     def __init__(
         self,
@@ -145,17 +152,16 @@ class PrefillAdder:
             ]
         )

-    def no_remaining_tokens(self):
-        return (
-            self.rem_total_tokens <= 0
-            or self.rem_input_tokens <= 0
-            or (
-                self.rem_chunk_tokens <= 0
-                if self.rem_chunk_tokens is not None
-                else False
-            )
-            or self.cur_rem_tokens <= 0
-        )
+    def budget_state(self):
+        if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
+            return AddReqResult.NO_TOKEN
+
+        if self.rem_input_tokens <= 0 or (
+            self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
+        ):
+            return AddReqResult.OTHER
+
+        return AddReqResult.CONTINUE

     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
@@ -212,6 +218,7 @@ class PrefillAdder:
         if not insert_sort:
             self.req_states.append((tokens_left, tokens_occupied))
         else:
+            i = 0
             for i in range(len(self.req_states)):
                 if tokens_left <= self.req_states[i][0]:
                     break
@@ -239,10 +246,13 @@ class PrefillAdder:
             )
             bs = len(self.req_states) - i
             if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
-                return False
+                return AddReqResult.NO_TOKEN
             tokens_freed += tokens_occupied

-        if req.extend_input_len <= self.rem_chunk_tokens:
+        if (
+            self.rem_chunk_tokens is None
+            or req.extend_input_len <= self.rem_chunk_tokens
+        ):
             self.can_run_list.append(req)
             self._prefill_one_req(
                 0,
@@ -258,7 +268,7 @@ class PrefillAdder:
             self.new_inflight_req = req
             self._prefill_one_req(0, trunc_len, 0)

-        return True
+        return self.budget_state()

     def add_one_req(self, req: Req):
         if req.sampling_params.ignore_eos and self.tree_cache.disable:
@@ -271,14 +281,14 @@ class PrefillAdder:
         prefix_len = len(req.prefix_indices)

         if total_tokens >= self.rem_total_tokens:
-            return False
+            return AddReqResult.NO_TOKEN

         if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
-            return False
+            return AddReqResult.OTHER

         with self._lock_node(req.last_node):
             if total_tokens > self.rem_total_tokens:
-                return False
+                return AddReqResult.NO_TOKEN

             if (
                 self.rem_chunk_tokens is None
@@ -297,7 +307,7 @@ class PrefillAdder:
                 # Chunked prefill
                 trunc_len = self.rem_chunk_tokens
                 if trunc_len == 0:
-                    return False
+                    return AddReqResult.OTHER

                 req.extend_input_len = trunc_len
                 req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
@@ -306,4 +316,4 @@ class PrefillAdder:
                 self.tree_cache.inc_lock_ref(req.last_node)
                 self._prefill_one_req(prefix_len, trunc_len, 0)

-        return True and not self.no_remaining_tokens()
+        return self.budget_state()
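With this change, `PrefillAdder.add_one_req` reports a three-way result instead of a bare bool, so the caller can tell "out of KV-cache tokens" apart from other reasons to stop admitting prefill requests. A standalone sketch of that pattern follows; it is not the sglang implementation, and the `Budget` class, its fields, and the example values are hypothetical.

# Standalone sketch of the AddReqResult / budget_state() pattern above.
from enum import Enum, auto
from typing import Optional


class AddReqResult(Enum):
    CONTINUE = auto()  # keep adding requests
    NO_TOKEN = auto()  # no KV-cache tokens left
    OTHER = auto()  # stop for another reason (input or chunk budget exhausted)


class Budget:
    def __init__(self, rem_total_tokens: int, rem_input_tokens: int,
                 rem_chunk_tokens: Optional[int] = None):
        self.rem_total_tokens = rem_total_tokens  # remaining KV-cache token budget
        self.rem_input_tokens = rem_input_tokens  # remaining prefill input budget
        self.rem_chunk_tokens = rem_chunk_tokens  # remaining chunked-prefill budget, if enabled

    def state(self) -> AddReqResult:
        # Running out of KV-cache tokens is the hard stop.
        if self.rem_total_tokens <= 0:
            return AddReqResult.NO_TOKEN
        # Exhausting the input or chunk budget stops batching for a softer reason.
        if self.rem_input_tokens <= 0 or (
            self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
        ):
            return AddReqResult.OTHER
        return AddReqResult.CONTINUE


print(Budget(rem_total_tokens=4096, rem_input_tokens=0).state())  # AddReqResult.OTHER

The scheduler loop can then keep calling add_one_req while the result is CONTINUE and branch on NO_TOKEN versus OTHER when it stops, which a plain True/False return could not express.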