sglang 0.3.6__py3-none-any.whl → 0.3.6.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_one_batch.py +4 -7
- sglang/bench_one_batch_server.py +2 -2
- sglang/bench_serving.py +75 -26
- sglang/check_env.py +7 -1
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +0 -3
- sglang/srt/configs/model_config.py +15 -20
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +13 -15
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +38 -57
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +13 -13
- sglang/srt/layers/attention/flashinfer_backend.py +14 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +13 -14
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +13 -15
- sglang/srt/layers/logits_processor.py +13 -15
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +25 -19
- sglang/srt/managers/detokenizer_manager.py +13 -18
- sglang/srt/managers/image_processor.py +6 -9
- sglang/srt/managers/io_struct.py +43 -28
- sglang/srt/managers/schedule_batch.py +92 -27
- sglang/srt/managers/schedule_policy.py +13 -15
- sglang/srt/managers/scheduler.py +94 -72
- sglang/srt/managers/session_controller.py +29 -19
- sglang/srt/managers/tokenizer_manager.py +29 -22
- sglang/srt/managers/tp_worker.py +13 -15
- sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +20 -19
- sglang/srt/model_executor/forward_batch_info.py +19 -17
- sglang/srt/model_executor/model_runner.py +42 -30
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +15 -15
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +24 -19
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +20 -16
- sglang/srt/models/llavavid.py +13 -15
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +21 -19
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +15 -17
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +13 -15
- sglang/srt/openai_api/protocol.py +13 -15
- sglang/srt/sampling/sampling_batch_info.py +4 -1
- sglang/srt/sampling/sampling_params.py +13 -15
- sglang/srt/server.py +60 -34
- sglang/srt/server_args.py +22 -22
- sglang/srt/utils.py +208 -19
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +13 -14
- sglang/test/test_utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/LICENSE +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/METADATA +25 -15
- sglang-0.3.6.post2.dist-info/RECORD +164 -0
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.6.dist-info/RECORD +0 -161
- /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py CHANGED
@@ -1,18 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Store information about requests and batches.
 
@@ -33,6 +31,7 @@ import dataclasses
 import logging
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -169,6 +168,30 @@ class ImageInputs:
 
         return ret
 
+    def merge(self, other, vocab_size):
+        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
+        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
+        self.image_hashes += other.image_hashes
+
+        self.pad_values = [
+            (self.image_hashes) % vocab_size,
+            (self.image_hashes >> 16) % vocab_size,
+            (self.image_hashes >> 32) % vocab_size,
+            (self.image_hashes >> 64) % vocab_size,
+        ]
+
+        optional_args = [
+            "image_sizes",
+            "image_offsets",
+            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if getattr(self, arg, None) is not None:
+                setattr(self, arg, getattr(self, arg) + getattr(other, arg))
+
 
 class Req:
     """The input and output status of a request."""
@@ -179,13 +202,19 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
+        input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
     ):
         # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
-        self.origin_input_ids_unpadded =
+        self.origin_input_ids_unpadded = (
+            origin_input_ids_unpadded
+            if origin_input_ids_unpadded
+            else origin_input_ids # Before image padding
+        )
         self.origin_input_ids = origin_input_ids
         self.output_ids = [] # Each decode stage's output ids
         self.fill_ids = None # fill_ids = origin_input_ids + output_ids
@@ -193,6 +222,7 @@
 
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
@@ -260,6 +290,12 @@
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
+    def extend_image_inputs(self, image_inputs, vocab_size):
+        if self.image_inputs is None:
+            self.image_inputs = image_inputs
+        else:
+            self.image_inputs.merge(image_inputs, vocab_size)
+
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
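
The two hunks above are the request-side half of multi-image handling: the first ImageInputs attached to a Req is stored as-is, and every later one is folded in through ImageInputs.merge. A minimal sketch of that accumulation, not taken from the package (the helper name and the pre-built ImageInputs objects are assumptions):

from typing import List

def add_images_to_request(req, image_inputs_list: List["ImageInputs"], vocab_size: int) -> None:
    """Hypothetical helper: attach several per-image ImageInputs to one Req.

    The first call stores the object on the request; later calls go through
    ImageInputs.merge, which concatenates pixel_values, appends image_hashes,
    and recomputes pad_values from the merged hashes."""
    for image_inputs in image_inputs_list:
        req.extend_image_inputs(image_inputs, vocab_size)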
@@ -439,14 +475,18 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None
 
-    #
+    # Batch configs
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
+    enable_overlap: bool = False
+
+    # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache
@@ -469,6 +509,7 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None
+    extend_logprob_start_lens: List[int] = None
 
     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -489,10 +530,11 @@
     def init_new(
         cls,
         reqs: List[Req],
-        req_to_token_pool,
-        token_to_kv_pool,
-        tree_cache,
-        model_config,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: ReqToTokenPool,
+        tree_cache: BasePrefixCache,
+        model_config: ModelConfig,
+        enable_overlap: bool,
     ):
         return cls(
             reqs=reqs,
@@ -500,6 +542,7 @@
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             model_config=model_config,
+            enable_overlap=enable_overlap,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
@@ -613,7 +656,7 @@
 
         assert len(self.out_cache_loc) == self.extend_num_tokens
 
-    def prepare_for_extend(self
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
         bs = len(self.reqs)
@@ -627,6 +670,9 @@
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+        input_embeds = []
+
+        pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -645,6 +691,11 @@
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
 
+            # If input_embeds are available, store them
+            if req.input_embeds is not None:
+                # If req.input_embeds is already a list, append its content directly
+                input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 extend_logprob_start_len = min(
@@ -667,6 +718,12 @@
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.input_embeds = (
+            torch.tensor(input_embeds).to(self.device, non_blocking=True)
+            if input_embeds
+            else None
+        )
+
         self.out_cache_loc = out_cache_loc
 
         self.seq_lens_sum = sum(seq_lens)
@@ -707,7 +764,7 @@
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=
+            enable_overlap_schedule=self.enable_overlap,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -724,16 +781,20 @@
         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-
+
+        # For overlap scheduler, the output_ids has one step delay
+        delta = 0 if self.enable_overlap else -1
 
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
             [
-                len(r.origin_input_ids) + len(r.output_ids)
+                len(r.origin_input_ids) + len(r.output_ids) + delta
                 for r in running_batch.reqs
             ]
         )
         self.extend_lens.extend([1] * running_bs)
+        self.extend_num_tokens += running_bs
+        # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
     def check_decode_mem(self):
@@ -897,7 +958,7 @@
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
 
-    def prepare_for_decode(self
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
         self.input_ids = self.output_ids
@@ -914,7 +975,7 @@
         else:
             locs = self.seq_lens
 
-        if enable_overlap:
+        if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
                 (self.req_pool_indices, locs), self.out_cache_loc
@@ -1045,6 +1106,7 @@
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            input_embeds=self.input_embeds,
         )
 
     def copy(self):
@@ -1115,6 +1177,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The input Embeds
+    input_embeds: Optional[torch.tensor] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
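
Taken together, the schedule_batch.py hunks plumb precomputed embeddings from Req.input_embeds through ScheduleBatch.prepare_for_extend() and into ModelWorkerBatch.input_embeds. A rough, self-contained sketch of the batching step only (the function name and the default device are assumptions, not code from the package):

import torch

def batch_input_embeds(per_req_embeds, device="cpu"):
    """Illustrative stand-in for the prepare_for_extend logic above: flatten the
    per-request embedding rows into one list, then move them to the scheduler's
    device as a single tensor, or return None when no request carries embeddings."""
    flat = []
    for embeds in per_req_embeds:
        if embeds is not None:
            flat.extend(embeds)  # extend, not append, to avoid nested lists
    return torch.tensor(flat).to(device, non_blocking=True) if flat else None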
sglang/srt/managers/schedule_policy.py CHANGED
@@ -1,18 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Request scheduler policy"""
 
 import os
sglang/srt/managers/scheduler.py CHANGED
@@ -1,21 +1,18 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import dataclasses
 import logging
 import os
 import threading
@@ -30,7 +27,7 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -74,8 +71,10 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    get_bool_env_var,
     get_zmq_socket,
     kill_parent_process,
+    set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
 )
@@ -84,7 +83,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 # Test retract decode
-test_retract =
+test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
 class Scheduler:
@@ -304,6 +303,9 @@
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio
 
+        # Tells whether the current running batch is full so that we can skip
+        # the check of whether to prefill new requests.
+        # This is an optimization to reduce the overhead of the prefill check.
         self.batch_is_full = False
 
         # Init watchdog thread
@@ -466,6 +468,7 @@
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -524,14 +527,23 @@
         recv_req: TokenizedGenerateReqInput,
     ):
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # Create a new request
+            if recv_req.input_embeds is not None:
+                # Generate fake input_ids based on the length of input_embeds
+                seq_length = len(recv_req.input_embeds)
+                fake_input_ids = [1] * seq_length
+                recv_req.input_ids = fake_input_ids
+
             req = Req(
                 recv_req.rid,
                 recv_req.input_text,
                 recv_req.input_ids,
                 recv_req.sampling_params,
                 lora_path=recv_req.lora_path,
+                input_embeds=recv_req.input_embeds,
             )
             req.tokenizer = self.tokenizer
+
             if recv_req.session_id is not None:
                 req.finished_reason = FINISH_ABORT(
                     f"Invalid request: session id {recv_req.session_id} does not exist"
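
The hunk above is where embedding-only requests enter the scheduler: a request that carries input_embeds but no real token ids gets one placeholder id per embedding row so the usual length bookkeeping (sequence lengths, KV-cache allocation) still works. A minimal sketch of just that placeholder step, written as a standalone function for illustration (the function name is an assumption, not part of the package):

from typing import List, Optional

def fake_ids_for_embeds(input_embeds: Optional[List[List[float]]]) -> Optional[List[int]]:
    """Return one dummy token id per embedding row, mirroring the logic above."""
    if input_embeds is None:
        return None
    return [1] * len(input_embeds)

# e.g. fake_ids_for_embeds([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) == [1, 1, 1]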
@@ -539,23 +551,22 @@
                 self.waiting_queue.append(req)
                 return
         else:
-            #
+            # Create a new request from a previsou session
             session = self.sessions[recv_req.session_id]
-            req
-            del self.sessions[recv_req.session_id]
-            self.sessions[new_session_id] = session
+            req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
                 return
 
         # Image inputs
         if recv_req.image_inputs is not None:
-
+            image_inputs = ImageInputs.from_dict(
                 recv_req.image_inputs, self.model_config.vocab_size
             )
             req.origin_input_ids = self.pad_input_ids_func(
-                req.
+                req.origin_input_ids, image_inputs
             )
+            req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
 
         if len(req.origin_input_ids) > self.max_req_input_len:
             req.finished_reason = FINISH_ABORT(
@@ -723,40 +734,30 @@
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
-        if (
-            self.last_batch
-            and not self.last_batch.forward_mode.is_decode()
-            and not self.last_batch.is_empty()
-        ):
+        if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
+                # Move the chunked request out of the batch
                 self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx
+                # Inflight request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
+
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = self.last_batch
                 else:
                     self.running_batch.merge_batch(self.last_batch)
 
-        #
+        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch
 
-        # Check memory
-        if self.running_batch is None:
-            return
-
         # Run decode
-
-        self.update_running_batch()
-        if not self.running_batch:
-            self.batch_is_full = False
+        if self.running_batch is None:
             return None
-
-        self.batch_is_full = False
+        self.running_batch = self.update_running_batch(self.running_batch)
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
@@ -852,14 +853,20 @@
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
-        new_batch.prepare_for_extend(
+        new_batch.prepare_for_extend()
 
         # Mixed-style chunked prefill
-        if
+        if (
+            self.is_mixed_chunk
+            and self.running_batch is not None
+            and not (new_batch.return_logprob or self.running_batch.return_logprob)
+        ):
+            # TODO (lianmin): support return_logprob + mixed chunked prefill
             self.running_batch.filter_batch()
             if not self.running_batch.is_empty():
-                self.running_batch.prepare_for_decode(
+                self.running_batch.prepare_for_decode()
                 new_batch.mix_with_running(self.running_batch)
                 new_batch.decoding_reqs = self.running_batch.reqs
                 self.running_batch = None
@@ -868,15 +875,16 @@
 
         return new_batch
 
-    def update_running_batch(self):
+    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
         """Update the current running decoding batch."""
         global test_retract
-
+
+        initial_bs = batch.batch_size()
 
         batch.filter_batch()
         if batch.is_empty():
-            self.
-            return
+            self.batch_is_full = False
+            return None
 
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
@@ -902,11 +910,15 @@
         jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
         self.waiting_queue.extend(jump_forward_reqs)
         if batch.is_empty():
-            self.
-            return
+            self.batch_is_full = False
+            return None
+
+        if batch.batch_size() < initial_bs:
+            self.batch_is_full = False
 
         # Update batch tensors
-        batch.prepare_for_decode(
+        batch.prepare_for_decode()
+        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
@@ -981,8 +993,13 @@
             if req.is_retracted:
                 continue
 
+            if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                # Free the one delayed token for the mixed decode batch
+                j = len(batch.out_cache_loc) - len(batch.reqs) + i
+                self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                continue
+
             if req.is_being_chunked <= 0:
-                # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_id)
                 req.check_finished()
@@ -992,14 +1009,15 @@
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
 
-                if req.grammar is not None:
-                    req.grammar.accept_token(next_token_id)
-
                 if req.return_logprob:
                     logprob_pt += self.add_logprob_return_values(
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
+
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_id)
             else:
+                # Inflight reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
         if batch.next_batch_sampling_info:
@@ -1017,18 +1035,18 @@
                 continue
 
             req.embedding = embeddings[i]
-            if req.is_being_chunked
-
-            else:
-                # Inflight reqs' prefill is not finished
-                # dummy output token for embedding models
+            if req.is_being_chunked <= 0:
+                # Dummy output token for embedding models
                 req.output_ids.append(0)
                 req.check_finished()
 
-
-
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
             else:
-
+                # Inflight reqs' prefill is not finished
+                req.is_being_chunked -= 1
 
         self.stream_output(batch.reqs)
@@ -1056,6 +1074,7 @@
                 continue
 
             if self.enable_overlap and req.finished():
+                # Free the one delayed token
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -1063,9 +1082,6 @@
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.grammar is not None:
-                req.grammar.accept_token(next_token_id)
-
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
 
@@ -1076,6 +1092,9 @@
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
+
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             torch.cuda.current_stream().synchronize()
@@ -1179,7 +1198,6 @@
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
-            output_session_ids = []
         else: # embedding or reward model
             output_embeddings = []
 
@@ -1207,7 +1225,6 @@
                     req.sampling_params.spaces_between_special_tokens
                 )
                 output_no_stop_trim.append(req.sampling_params.no_stop_trim)
-                output_session_ids.append(req.session_id)
 
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -1258,7 +1275,6 @@
                     output_meta_info,
                     output_finished_reason,
                     output_no_stop_trim,
-                    output_session_ids,
                 )
             )
         else: # embedding or reward model
@@ -1389,9 +1405,13 @@
     dp_rank: Optional[int],
     pipe_writer,
 ):
+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None:
-        dp_rank = int(os.
+    if dp_rank is None and "DP_RANK" in os.environ:
+        dp_rank = int(os.environ["DP_RANK"])
 
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
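
The hunk above gates the new CPU-affinity pinning behind the SGLANG_SET_CPU_AFFINITY environment variable, so it is off unless you opt in. A hedged usage sketch (the model path is a placeholder and the launch command is only illustrative, not a required invocation):

import os
import subprocess

# Opt in to CPU-affinity pinning for the scheduler processes, then launch the server.
env = dict(os.environ, SGLANG_SET_CPU_AFFINITY="1")
subprocess.run(
    ["python", "-m", "sglang.launch_server", "--model-path", "<your-model>"],
    env=env,
    check=True,
)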
@@ -1402,7 +1422,9 @@ def run_scheduler_process(
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
-        pipe_writer.send(
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+        )
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
         else: