sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""
+"""
+Store information about requests and batches.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""
 
 import logging
 from dataclasses import dataclass
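The new module docstring describes the hand-off used by the new `scheduler.py`-based pipeline (the old `controller_single.py`/`controller_multi.py` are removed in this release): the scheduler keeps a `ScheduleBatch` of mostly CPU-side metadata, converts it to a `ModelWorkerBatch` for the TP worker, and the model runner turns that into a GPU-side `ForwardBatch`. A rough sketch of the first two steps, using only calls that appear in this diff (`init_new`, `prepare_for_extend`, `get_model_worker_batch`); the request list, memory pools, and vocab size are assumed to come from a running server:

```python
from sglang.srt.managers.schedule_batch import ScheduleBatch

def run_one_extend_step(reqs, req_to_token_pool, token_to_kv_pool, tree_cache, vocab_size):
    """Illustrative sketch of the ScheduleBatch -> ModelWorkerBatch hand-off."""
    batch = ScheduleBatch.init_new(reqs, req_to_token_pool, token_to_kv_pool, tree_cache)
    batch.prepare_for_extend(vocab_size)           # allocate slots, build input_ids / seq_lens
    worker_batch = batch.get_model_worker_batch()  # metadata + tensors handed to the TP worker
    # The TP worker / ModelRunner then builds the low-level ForwardBatch
    # (GPU tensors) from `worker_batch`; that step lives outside this file.
    return worker_batch
```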
@@ -31,6 +41,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -102,14 +113,50 @@ class FINISH_ABORT(BaseFinishReason):
         }
 
 
+@dataclass
+class ImageInputs:
+    """The image related inputs."""
+
+    pixel_values: torch.Tensor
+    image_hash: int
+    image_sizes: Optional[list] = None
+    image_offsets: Optional[list] = None
+    pad_values: Optional[list] = None
+    modalities: Optional[list] = None
+
+    image_embeds: Optional[List[torch.Tensor]] = None
+    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def from_dict(obj, vocab_size):
+        # Use image hash as fake token_ids, which is then used for prefix matching
+        ret = ImageInputs(
+            pixel_values=obj["pixel_values"],
+            image_hash=hash(tuple(obj["image_hashes"])),
+        )
+        image_hash = ret.image_hash
+        ret.pad_values = [
+            (image_hash) % vocab_size,
+            (image_hash >> 16) % vocab_size,
+            (image_hash >> 32) % vocab_size,
+            (image_hash >> 64) % vocab_size,
+        ]
+        ret.image_sizes = obj["image_sizes"]
+        # Only when pixel values is not None we have modalities
+        ret.modalities = obj["modalities"] or ["image"]
+        return ret
+
+
 class Req:
-    """
+    """The input and output status of a request."""
 
     def __init__(
         self,
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info
@@ -119,6 +166,8 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+
+        self.sampling_params = sampling_params
         self.lora_path = lora_path
 
         # Memory info
@@ -127,6 +176,7 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -147,21 +197,13 @@ class Req:
         self.completion_tokens_wo_jump_forward = 0
 
         # For vision inputs
-        self.
-        self.image_sizes = None
-        self.image_offsets = None
-        self.pad_value = None
-        self.modalities = None
+        self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
 
-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -363,28 +405,32 @@ class ScheduleBatch:
     sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids:
-    req_pool_indices:
-    seq_lens:
-    position_ids_offsets: torch.Tensor = None
+    input_ids: List[int] = None
+    req_pool_indices: List[int] = None
+    seq_lens: List[int] = None
     out_cache_loc: torch.Tensor = None
-    extend_num_tokens: int = None
-
-    # For mixed chunekd prefill
-    prefix_lens_cpu: List[int] = None
-    running_bs: int = None
 
     # For processing logprobs
     return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
+    top_logprobs_nums: Optional[List[int]] = None
+
+    # For extend and mixed chunekd prefill
+    prefix_lens: List[int] = None
+    extend_lens: List[int] = None
+    extend_num_tokens: int = None
+    running_bs: int = None
 
     # Stream
     has_stream: bool = False
 
+    # Has regex
+    has_regex: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
         has_stream = any(req.stream for req in reqs)
+        has_regex = any(req.regex_fsm for req in reqs)
 
         return cls(
             reqs=reqs,
@@ -393,6 +439,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            has_regex=has_regex,
         )
 
     def batch_size(self):
@@ -429,19 +476,19 @@ class ScheduleBatch:
     def prepare_for_extend(self, vocab_size: int):
         self.forward_mode = ForwardMode.EXTEND
 
-        bs = self.
+        bs = len(self.reqs)
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
 
         # Allocate memory
-
+        req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
         pt = 0
         for i, req in enumerate(reqs):
-            req.req_pool_idx =
+            req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
@@ -467,18 +514,19 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Set fields
-        with
+        with out_cache_loc.device:
             self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(
-            self.seq_lens = torch.tensor(seq_lens
-            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+            self.req_pool_indices = torch.tensor(req_pool_indices)
+            self.seq_lens = torch.tensor(seq_lens)
 
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
-        self.
-
-        self.
-        self.
+        if self.return_logprob:
+            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
+        self.extend_lens = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
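One detail in the rewritten `prepare_for_extend`: the tensors are now built inside `with out_cache_loc.device:`, relying on `torch.device` acting as a context manager (PyTorch 2.0+) so `input_ids`, `req_pool_indices`, and `seq_lens` land on the same device as the freshly allocated cache slots. A minimal standalone illustration of that behavior:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
out_cache_loc = torch.arange(4, device=device)  # stand-in for freshly allocated KV slots
with out_cache_loc.device:  # torch.device as a context manager sets the default device
    seq_lens = torch.tensor([3, 1], dtype=torch.int32)
assert seq_lens.device == out_cache_loc.device
```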
@@ -493,23 +541,23 @@ class ScheduleBatch:
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
         extend_num_tokens = self.extend_num_tokens + running_bs
 
-        self.
+        self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens = extend_num_tokens
 
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
-        self.
+        self.prefix_lens.extend(
             [
                 len(r.origin_input_ids) + len(r.output_ids) - 1
                 for r in running_batch.reqs
             ]
         )
-        self.
-        self.
+        self.extend_lens.extend([1] * running_bs)
+        self.extend_logprob_start_lens.extend([0] * running_bs)
 
     def check_decode_mem(self):
-        bs = self.
+        bs = len(self.reqs)
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
@@ -598,7 +646,7 @@ class ScheduleBatch:
 
         return retracted_reqs, new_estimate_ratio
 
-    def check_for_jump_forward(self,
+    def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
 
@@ -654,15 +702,9 @@ class ScheduleBatch:
                 self.tree_cache.cache_finished_req(req, cur_all_ids)
 
                 # re-applying image padding
-                if req.
-                    (
-                        req.
-                        req.image_offsets,
-                    ) = model_runner.model.pad_input_ids(
-                        req.origin_input_ids_unpadded,
-                        req.pad_value,
-                        req.pixel_values,
-                        req.image_sizes,
+                if req.image_inputs is not None:
+                    req.origin_input_ids = pad_input_ids_func(
+                        req.origin_input_ids_unpadded, req.image_inputs
                     )
 
                 jump_forward_reqs.append(req)
@@ -680,14 +722,14 @@ class ScheduleBatch:
             r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
             for r in self.reqs
         ]
-        else:
-            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
 
-        self.input_ids = torch.tensor(
+        self.input_ids = torch.tensor(
+            input_ids, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.seq_lens.add_(1)
 
         # Alloc mem
-        bs = self.
+        bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
 
         self.req_to_token_pool.req_to_token[
@@ -705,33 +747,110 @@ class ScheduleBatch:
             return
 
         self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(
-
-
+        new_indices = torch.tensor(
+            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.
+        self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
-        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.return_logprob:
+            self.top_logprobs_nums = [
+                self.top_logprobs_nums[i] for i in unfinished_indices
+            ]
+        else:
+            self.top_logprobs_nums = None
+
         self.has_stream = any(req.stream for req in self.reqs)
+        self.has_regex = any(req.regex_fsm for req in self.reqs)
 
-        self.sampling_info.
+        self.sampling_info.filter_batch(unfinished_indices, new_indices)
 
-    def
+    def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.sampling_info.
+        self.sampling_info.merge_batch(other.sampling_info)
 
-        self.reqs.extend(other.reqs)
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.position_ids_offsets = torch.concat(
-            [self.position_ids_offsets, other.position_ids_offsets]
-        )
         self.out_cache_loc = None
-        self.
-
-
+        if self.return_logprob and other.return_logprob:
+            self.top_logprobs_nums.extend(other.top_logprobs_nums)
+        elif self.return_logprob:
+            self.top_logprobs_nums.extend([0] * len(other.reqs))
+        elif other.return_logprob:
+            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
+        self.reqs.extend(other.reqs)
+
+        self.return_logprob = self.return_logprob or other.return_logprob
+        self.has_stream = self.has_stream or other.has_stream
+        self.has_regex = self.has_regex or other.has_regex
+
+    def get_model_worker_batch(self):
+        if self.forward_mode.is_decode():
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
+                image_inputs
+            ) = None
+        else:
+            extend_seq_lens = self.extend_lens
+            extend_prefix_lens = self.prefix_lens
+            extend_logprob_start_lens = self.extend_logprob_start_lens
+            image_inputs = [r.image_inputs for r in self.reqs]
+
+        lora_paths = [req.lora_path for req in self.reqs]
+        if self.has_regex:
+            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
+            self.sampling_info.regex_fsm_states = [
+                req.regex_fsm_state for req in self.reqs
+            ]
+
+        return ModelWorkerBatch(
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids,
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=extend_seq_lens,
+            extend_prefix_lens=extend_prefix_lens,
+            extend_logprob_start_lens=extend_logprob_start_lens,
+            image_inputs=image_inputs,
+            lora_paths=lora_paths,
+            sampling_info=self.sampling_info,
+        )
+
+
+@dataclass
+class ModelWorkerBatch:
+    # The forward mode
+    forward_mode: ForwardMode
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
+    req_pool_indices: torch.Tensor
+    # The sequence length
+    seq_lens: torch.Tensor
+    # The indices of output tokens in the token_to_kv_pool
+    out_cache_loc: torch.Tensor
+
+    # For logprob
+    return_logprob: bool
+    top_logprobs_nums: Optional[List[int]]
+
+    # For extend
+    extend_seq_lens: Optional[List[int]]
+    extend_prefix_lens: Optional[List[int]]
+    extend_logprob_start_lens: Optional[List[int]]
+
+    # For multimodal
+    image_inputs: Optional[List[ImageInputs]]
+
+    # For LoRA
+    lora_paths: Optional[List[str]]
+
+    # Sampling info
    sampling_info: SamplingBatchInfo
sglang/srt/managers/{policy_scheduler.py → schedule_policy.py}

@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Request policy
+"""Request scheduler policy"""
 
 import os
 import random
 from collections import defaultdict
 from contextlib import contextmanager
+from enum import Enum, auto
 from typing import Dict, List, Optional
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -32,7 +33,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
 
 
-class
+class SchedulePolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
@@ -104,6 +105,12 @@ class PolicyScheduler:
             q.extend(last_node_to_reqs[cur_node])
 
 
+class AddReqResult(Enum):
+    CONTINUE = auto()  # Continue to add requests
+    NO_TOKEN = auto()  # No token left
+    OTHER = auto()  # Other reasons to stop adding requests
+
+
 class PrefillAdder:
     def __init__(
         self,
@@ -145,17 +152,16 @@ class PrefillAdder:
                 ]
             )
 
-    def
-
-
-
-
-
-
-
-
-
-    )
+    def budget_state(self):
+        if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
+            return AddReqResult.NO_TOKEN
+
+        if self.rem_input_tokens <= 0 or (
+            self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
+        ):
+            return AddReqResult.OTHER
+
+        return AddReqResult.CONTINUE
 
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
@@ -212,6 +218,7 @@ class PrefillAdder:
         if not insert_sort:
             self.req_states.append((tokens_left, tokens_occupied))
         else:
+            i = 0
             for i in range(len(self.req_states)):
                 if tokens_left <= self.req_states[i][0]:
                     break
@@ -239,10 +246,13 @@ class PrefillAdder:
            )
            bs = len(self.req_states) - i
            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
-                return
+                return AddReqResult.NO_TOKEN
            tokens_freed += tokens_occupied
 
-        if
+        if (
+            self.rem_chunk_tokens is None
+            or req.extend_input_len <= self.rem_chunk_tokens
+        ):
             self.can_run_list.append(req)
             self._prefill_one_req(
                 0,
@@ -258,7 +268,7 @@ class PrefillAdder:
             self.new_inflight_req = req
             self._prefill_one_req(0, trunc_len, 0)
 
-        return
+        return self.budget_state()
 
     def add_one_req(self, req: Req):
         if req.sampling_params.ignore_eos and self.tree_cache.disable:
@@ -271,14 +281,14 @@ class PrefillAdder:
         prefix_len = len(req.prefix_indices)
 
         if total_tokens >= self.rem_total_tokens:
-            return
+            return AddReqResult.NO_TOKEN
 
         if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
-            return
+            return AddReqResult.OTHER
 
         with self._lock_node(req.last_node):
             if total_tokens > self.rem_total_tokens:
-                return
+                return AddReqResult.NO_TOKEN
 
             if (
                 self.rem_chunk_tokens is None
@@ -297,7 +307,7 @@ class PrefillAdder:
                 # Chunked prefill
                 trunc_len = self.rem_chunk_tokens
                 if trunc_len == 0:
-                    return
+                    return AddReqResult.OTHER
 
                 req.extend_input_len = trunc_len
                 req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
@@ -306,4 +316,4 @@ class PrefillAdder:
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(prefix_len, trunc_len, 0)
 
-        return
+        return self.budget_state()
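With these changes, `PrefillAdder.add_one_req` returns an `AddReqResult` instead of a bare truthy value, so the caller can tell "out of tokens" apart from other stop conditions. A hypothetical admission loop on the caller side (the real loop lives in the new `scheduler.py`, which is not shown in this excerpt; only `add_one_req`, `can_run_list`, and the enum values come from this diff):

```python
from sglang.srt.managers.schedule_policy import AddReqResult, PrefillAdder

def admit_prefill_reqs(adder: PrefillAdder, waiting_queue):
    """Hypothetical caller-side sketch of consuming AddReqResult."""
    for req in list(waiting_queue):
        result = adder.add_one_req(req)
        if result != AddReqResult.CONTINUE:
            # NO_TOKEN: KV/token budget exhausted; OTHER: input or chunk budget hit.
            break
    return adder.can_run_list
```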