sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +41 -5
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
- sglang/srt/layers/parameter.py +2 -1
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/fp8.py +6 -3
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +25 -2
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +277 -178
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +206 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +37 -15
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/sampling_batch_info.py +139 -4
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +57 -14
- sglang/srt/utils.py +103 -65
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -59,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
@@ -66,6 +69,8 @@ class GenerateReqInput:

     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
+    # Custom logit processor (serialized function)
+    custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None

     def normalize_batch_and_arguments(self):
         if (
@@ -180,6 +185,13 @@ class GenerateReqInput:
             else:
                 assert self.parallel_sample_num == 1

+            if self.custom_logit_processor is None:
+                self.custom_logit_processor = [None] * num
+            elif not isinstance(self.custom_logit_processor, list):
+                self.custom_logit_processor = [self.custom_logit_processor] * num
+            else:
+                assert self.parallel_sample_num == 1
+
     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
         return self.rid
@@ -196,8 +208,14 @@ class GenerateReqInput:
             top_logprobs_num=self.top_logprobs_num[i],
             return_text_in_logprobs=self.return_text_in_logprobs,
             stream=self.stream,
+            log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            custom_logit_processor=(
+                self.custom_logit_processor[i]
+                if self.custom_logit_processor is not None
+                else None
+            ),
         )


@@ -230,6 +248,10 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None

+    # Custom logit processor (serialized function)
+    # TODO (hpguo): Add an example and update doc string here
+    custom_logit_processor: Optional[str] = None
+

 @dataclass
 class EmbeddingReqInput:
@@ -243,6 +265,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True

     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -340,7 +364,6 @@ class BatchTokenIDOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]


 @dataclass
@@ -366,7 +389,6 @@ class BatchStrOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]


 @dataclass
@@ -491,6 +513,7 @@ class ProfileReq(Enum):
 @dataclass
 class ConfigureLoggingReq:
     log_requests: Optional[bool] = None
+    log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
     dump_requests_threshold: Optional[int] = None

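Note on the new custom_logit_processor field: normalize_batch_and_arguments broadcasts it per request, as the hunk above shows. A minimal standalone sketch of that broadcasting rule (the function name and the explicit parallel_sample_num argument are illustrative; the real logic operates on GenerateReqInput attributes):

from typing import List, Optional, Union


def broadcast_custom_logit_processor(
    value: Optional[Union[List[Optional[str]], str]],
    num: int,
    parallel_sample_num: int = 1,
) -> List[Optional[str]]:
    # No processor given: one None per request in the batch.
    if value is None:
        return [None] * num
    # A single serialized processor string is shared by every request.
    if not isinstance(value, list):
        return [value] * num
    # A caller-supplied list is only accepted when there is no
    # parallel-sampling fan-out, mirroring the assert in the diff above.
    assert parallel_sample_num == 1
    return value


# Example: one serialized processor reused across a batch of 3 prompts.
print(broadcast_custom_logit_processor("<serialized-fn>", 3))
# ['<serialized-fn>', '<serialized-fn>', '<serialized-fn>']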
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
 if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm

-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
@@ -65,9 +64,9 @@ global_server_args_dict = {
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "device": ServerArgs.device,
 }

-
 logger = logging.getLogger(__name__)


@@ -116,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason):


 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error"):
+    def __init__(self, message="Unknown error", status_code=None, err_type=None):
         super().__init__(is_error=True)
         self.message = message
+        self.status_code = status_code
+        self.err_type = err_type

     def to_json(self):
         return {
             "type": "abort",
             "message": self.message,
+            "status_code": self.status_code,
+            "err_type": self.err_type,
         }


@@ -148,6 +151,15 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None

+    # MiniCPMV related
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    im_start_id: Optional[torch.Tensor] = None
+    im_end_id: Optional[torch.Tensor] = None
+    slice_start_id: Optional[torch.Tensor] = None
+    slice_end_id: Optional[torch.Tensor] = None
+    tgt_sizes: Optional[list] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
@@ -167,6 +179,11 @@ class ImageInputs:
             "aspect_ratio_ids",
             "aspect_ratio_mask",
             "image_grid_thws",
+            "im_start_id",
+            "im_end_id",
+            "slice_start_id",
+            "slice_end_id",
+            "tgt_sizes",
         ]
         for arg in optional_args:
             if arg in obj:
@@ -215,6 +232,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        custom_logit_processor: Optional[str] = None,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -226,14 +244,16 @@ class Req:
             else origin_input_ids  # Before image padding
         )
         self.origin_input_ids = origin_input_ids
-
-        self.
+        # Each decode stage's output ids
+        self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
         self.session_id = session_id
         self.input_embeds = input_embeds

         # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.custom_logit_processor = custom_logit_processor

         # Memory pool info
         self.req_pool_idx = None
@@ -265,6 +285,7 @@ class Req:
         # Prefix info
         self.prefix_indices = []
         # Tokens to run prefill. input_tokens - shared_prefix_tokens.
+        # Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None

@@ -280,11 +301,10 @@ class Req:
         self.top_logprobs_num = top_logprobs_num

         # Logprobs (return value)
-        self.
-        self.
-        self.
-        self.
-        self.input_top_logprobs_idx = None
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
             self.output_token_logprobs_val = []
@@ -344,9 +364,6 @@ class Req:
             max_prefix_len = min(max_prefix_len, input_len - 1)

         if self.return_logprob:
-            if self.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                max_prefix_len = min(max_prefix_len, input_len - 2)
             max_prefix_len = min(max_prefix_len, self.logprob_start_len)

         max_prefix_len = max(max_prefix_len, 0)
@@ -578,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None

+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -588,6 +608,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -601,6 +622,7 @@ class ScheduleBatch:
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )

     def batch_size(self):
@@ -656,7 +678,7 @@ class ScheduleBatch:
                 or len(req.prefix_indices) >= im.num_image_tokens
             )

-        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
             self.device, non_blocking=True
         )

@@ -690,7 +712,7 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )

@@ -766,10 +788,10 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.input_embeds = (
@@ -1002,11 +1024,16 @@ class ScheduleBatch:
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.seq_lens = torch.empty(0, dtype=torch.
+        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.
+        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+            enable_overlap_schedule=self.enable_overlap,
+        )

     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
@@ -1067,7 +1094,7 @@ class ScheduleBatch:
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(keep_indices, dtype=torch.
+        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
@@ -1121,7 +1148,7 @@ class ScheduleBatch:
             self.spec_info.merge_batch(other.spec_info)

     def get_model_worker_batch(self):
-        if self.forward_mode.
+        if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
@@ -1136,7 +1163,6 @@ class ScheduleBatch:

         global bid
         bid += 1
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1180,6 +1206,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )

     def __str__(self):
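Note on the FINISH_ABORT change above: abort reasons now carry a status code and an error type into their JSON payload. A minimal stand-in sketch of the resulting payload shape (BaseFinishReason and the scheduler plumbing are omitted; the example values are hypothetical):

from typing import Optional


class FinishAbortSketch:
    def __init__(
        self,
        message: str = "Unknown error",
        status_code: Optional[int] = None,
        err_type: Optional[str] = None,
    ):
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self) -> dict:
        # Same keys as the to_json body added in the diff above.
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


print(
    FinishAbortSketch(
        "prompt too long", status_code=400, err_type="BadRequestError"
    ).to_json()
)
# {'type': 'abort', 'message': 'prompt too long', 'status_code': 400, 'err_type': 'BadRequestError'}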
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -24,6 +24,7 @@ import torch

 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
@@ -250,23 +251,24 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
+        token_to_kv_pool: BaseTokenToKVPool,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
-        rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
+        self.token_to_kv_pool = token_to_kv_pool
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
-        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens

-        self.
+        self.rem_total_token_offset = mixed_with_decode_tokens
+        self.cur_rem_token_offset = mixed_with_decode_tokens

         self.req_states = None
         self.can_run_list = []
@@ -275,8 +277,7 @@ class PrefillAdder:
         self.log_input_tokens = 0

         if running_batch is not None:
-
-            self.rem_total_tokens -= sum(
+            self.rem_total_token_offset += sum(
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
@@ -287,6 +288,22 @@ class PrefillAdder:
                 ]
             )

+    @property
+    def rem_total_tokens(self):
+        return (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+            - self.rem_total_token_offset
+        )
+
+    @property
+    def cur_rem_tokens(self):
+        return (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+            - self.cur_rem_token_offset
+        )
+
     def budget_state(self):
         if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
             return AddReqResult.NO_TOKEN
@@ -301,8 +318,8 @@ class PrefillAdder:
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
-        self.
-        self.
+        self.rem_total_token_offset += extend_input_len + max_new_tokens
+        self.cur_rem_token_offset += extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -332,12 +349,10 @@ class PrefillAdder:
     @contextmanager
     def _lock_node(self, last_node: TreeNode):
         try:
-
-            self.rem_total_tokens += delta
+            self.tree_cache.inc_lock_ref(last_node)
             yield None
         finally:
-
-            self.rem_total_tokens += delta
+            self.tree_cache.dec_lock_ref(last_node)

     def add_one_req_ignore_eos(self, req: Req):
         def add_req_state(r, insert_sort=False):
@@ -433,7 +448,6 @@ class PrefillAdder:
                 or input_tokens <= self.rem_chunk_tokens
                 or (
                     req.return_logprob
-                    and req.normalized_prompt_logprob is None
                     and req.logprob_start_len != len(req.origin_input_ids) - 1
                 )
             ):
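Note on the PrefillAdder change above: the mutable rem_total_tokens counter is replaced by properties computed from the live KV pool and tree cache sizes minus offsets that only grow as requests are admitted. A minimal sketch of that accounting with hypothetical stub classes (names and numbers are illustrative):

class _StubKVPool:
    def __init__(self, free_slots: int):
        self.free_slots = free_slots

    def available_size(self) -> int:
        return self.free_slots


class _StubTreeCache:
    def __init__(self, evictable: int):
        self.evictable = evictable

    def evictable_size(self) -> int:
        return self.evictable


class BudgetSketch:
    def __init__(self, token_to_kv_pool, tree_cache, mixed_with_decode_tokens: int = 0):
        self.token_to_kv_pool = token_to_kv_pool
        self.tree_cache = tree_cache
        self.rem_total_token_offset = mixed_with_decode_tokens

    @property
    def rem_total_tokens(self) -> int:
        # Recomputed on every read: free KV slots + evictable cache - admitted work.
        return (
            self.token_to_kv_pool.available_size()
            + self.tree_cache.evictable_size()
            - self.rem_total_token_offset
        )

    def admit(self, extend_input_len: int, max_new_tokens: int) -> None:
        # Mirrors _prefill_one_req: only the offset grows.
        self.rem_total_token_offset += extend_input_len + max_new_tokens


budget = BudgetSketch(_StubKVPool(4096), _StubTreeCache(1024))
budget.admit(extend_input_len=512, max_new_tokens=128)
print(budget.rem_total_tokens)  # 4480

Because the remaining budget is recomputed from the current available_size() and evictable_size() on every read, it stays consistent when the KV pool frees slots or the radix cache evicts entries between admissions, which the old decrement-in-place counter could not reflect.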