sglang 0.3.6__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_one_batch.py +2 -4
- sglang/bench_serving.py +75 -26
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +13 -15
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +38 -57
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +13 -13
- sglang/srt/layers/attention/flashinfer_backend.py +13 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +13 -14
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +13 -15
- sglang/srt/layers/logits_processor.py +13 -15
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +25 -19
- sglang/srt/managers/detokenizer_manager.py +13 -16
- sglang/srt/managers/io_struct.py +43 -28
- sglang/srt/managers/schedule_batch.py +55 -26
- sglang/srt/managers/schedule_policy.py +13 -15
- sglang/srt/managers/scheduler.py +89 -70
- sglang/srt/managers/session_controller.py +14 -15
- sglang/srt/managers/tokenizer_manager.py +29 -22
- sglang/srt/managers/tp_worker.py +13 -15
- sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +20 -19
- sglang/srt/model_executor/forward_batch_info.py +19 -17
- sglang/srt/model_executor/model_runner.py +42 -30
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +15 -15
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +24 -19
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +13 -15
- sglang/srt/models/llavavid.py +13 -15
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +21 -19
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +15 -17
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +13 -15
- sglang/srt/openai_api/protocol.py +13 -15
- sglang/srt/sampling/sampling_batch_info.py +4 -1
- sglang/srt/sampling/sampling_params.py +13 -15
- sglang/srt/server.py +59 -34
- sglang/srt/server_args.py +22 -22
- sglang/srt/utils.py +196 -17
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +13 -14
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +24 -15
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.6.dist-info/RECORD +0 -161
- /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -1,18 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Store information about requests and batches.
 
@@ -180,6 +178,7 @@ class Req:
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
+        input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
     ):
         # Input and output info
@@ -193,6 +192,7 @@ class Req:
 
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
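The two hunks above thread an optional input_embeds field through Req. Below is a minimal sketch of how a request could be built with precomputed embeddings under the new signature; it mirrors the call site shown later in the scheduler.py hunks, and seq_len, hidden_size, and the literal values are illustrative only, not taken from the package.

    # Illustrative only: positional arguments follow the Req(...) call in scheduler.py.
    seq_len, hidden_size = 8, 4096                           # hypothetical sizes
    embeds = [[0.0] * hidden_size for _ in range(seq_len)]   # one vector per input position
    req = Req(
        "example-rid",          # rid (made up)
        "",                     # input text, unused when embeddings are supplied
        [1] * seq_len,          # fake input ids, as the scheduler generates for embeds
        SamplingParams(),
        lora_path=None,
        input_embeds=embeds,    # new optional argument in 0.3.6.post1
    )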
@@ -439,14 +439,18 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None
 
-    #
+    # Batch configs
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
+    enable_overlap: bool = False
+
+    # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache
@@ -469,6 +473,7 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     decoding_reqs: List[Req] = None
+    extend_logprob_start_lens: List[int] = None
 
     # For encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -489,10 +494,11 @@ class ScheduleBatch:
     def init_new(
         cls,
         reqs: List[Req],
-        req_to_token_pool,
-        token_to_kv_pool,
-        tree_cache,
-        model_config,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: ReqToTokenPool,
+        tree_cache: BasePrefixCache,
+        model_config: ModelConfig,
+        enable_overlap: bool,
     ):
         return cls(
             reqs=reqs,
@@ -500,6 +506,7 @@ class ScheduleBatch:
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             model_config=model_config,
+            enable_overlap=enable_overlap,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
@@ -613,7 +620,7 @@ class ScheduleBatch:
 
         assert len(self.out_cache_loc) == self.extend_num_tokens
 
-    def prepare_for_extend(self
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
         bs = len(self.reqs)
@@ -627,6 +634,9 @@ class ScheduleBatch:
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
+        input_embeds = []
+
+        pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -645,6 +655,11 @@ class ScheduleBatch:
                     (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                 )
 
+            # If input_embeds are available, store them
+            if req.input_embeds is not None:
+                # If req.input_embeds is already a list, append its content directly
+                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 extend_logprob_start_len = min(
@@ -667,6 +682,12 @@ class ScheduleBatch:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.input_embeds = (
+            torch.tensor(input_embeds).to(self.device, non_blocking=True)
+            if input_embeds
+            else None
+        )
+
         self.out_cache_loc = out_cache_loc
 
         self.seq_lens_sum = sum(seq_lens)
@@ -707,7 +728,7 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=
+            enable_overlap_schedule=self.enable_overlap,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -724,16 +745,20 @@ class ScheduleBatch:
         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-
+
+        # For overlap scheduler, the output_ids has one step delay
+        delta = 0 if self.enable_overlap else -1
 
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
             [
-                len(r.origin_input_ids) + len(r.output_ids)
+                len(r.origin_input_ids) + len(r.output_ids) + delta
                 for r in running_batch.reqs
             ]
         )
         self.extend_lens.extend([1] * running_bs)
+        self.extend_num_tokens += running_bs
+        # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
     def check_decode_mem(self):
@@ -897,7 +922,7 @@ class ScheduleBatch:
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
 
-    def prepare_for_decode(self
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
         self.input_ids = self.output_ids
@@ -914,7 +939,7 @@ class ScheduleBatch:
         else:
             locs = self.seq_lens
 
-        if enable_overlap:
+        if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
                 (self.req_pool_indices, locs), self.out_cache_loc
@@ -1045,6 +1070,7 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            input_embeds=self.input_embeds,
         )
 
     def copy(self):
@@ -1115,6 +1141,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The input Embeds
+    input_embeds: Optional[torch.tensor] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
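Taken together, these ScheduleBatch hunks move the overlap flag from per-call arguments into the batch object: init_new now receives enable_overlap, and prepare_for_extend() / prepare_for_decode() read it from self. A rough sketch of the new calling convention, following the scheduler-side hunks later in this diff; the self.* attributes are the scheduler's own fields and the variable names are illustrative.

    # Sketch of the 0.3.6.post1 calling convention (not a verbatim copy of the file).
    new_batch = ScheduleBatch.init_new(
        can_run_list,            # requests chosen for this prefill step (name illustrative)
        self.req_to_token_pool,
        self.token_to_kv_pool,
        self.tree_cache,
        self.model_config,
        self.enable_overlap,     # new positional argument
    )
    new_batch.prepare_for_extend()   # no longer takes the overlap flag
    # later, for the running decode batch:
    batch.prepare_for_decode()       # also reads self.enable_overlap internally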
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -1,18 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Request scheduler policy"""
 
 import os
sglang/srt/managers/scheduler.py
CHANGED
@@ -1,21 +1,18 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import dataclasses
 import logging
 import os
 import threading
@@ -30,7 +27,7 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -75,6 +72,7 @@ from sglang.srt.utils import (
     configure_logger,
     crash_on_warnings,
     get_zmq_socket,
+    gpu_proc_affinity,
     kill_parent_process,
     set_random_seed,
     suppress_other_loggers,
@@ -84,7 +82,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 # Test retract decode
-test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"
 
 
 class Scheduler:
@@ -304,6 +302,9 @@ class Scheduler:
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio
 
+        # Tells whether the current running batch is full so that we can skip
+        # the check of whether to prefill new requests.
+        # This is an optimization to reduce the overhead of the prefill check.
         self.batch_is_full = False
 
         # Init watchdog thread
@@ -466,6 +467,7 @@ class Scheduler:
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -524,14 +526,23 @@
         recv_req: TokenizedGenerateReqInput,
     ):
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # Create a new request
+            if recv_req.input_embeds is not None:
+                # Generate fake input_ids based on the length of input_embeds
+                seq_length = len(recv_req.input_embeds)
+                fake_input_ids = [1] * seq_length
+                recv_req.input_ids = fake_input_ids
+
             req = Req(
                 recv_req.rid,
                 recv_req.input_text,
                 recv_req.input_ids,
                 recv_req.sampling_params,
                 lora_path=recv_req.lora_path,
+                input_embeds=recv_req.input_embeds,
             )
             req.tokenizer = self.tokenizer
+
             if recv_req.session_id is not None:
                 req.finished_reason = FINISH_ABORT(
                     f"Invalid request: session id {recv_req.session_id} does not exist"
@@ -539,11 +550,9 @@
                 self.waiting_queue.append(req)
                 return
         else:
-            #
+            # Create a new request from a previsou session
             session = self.sessions[recv_req.session_id]
-            req
-            del self.sessions[recv_req.session_id]
-            self.sessions[new_session_id] = session
+            req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
                 return
@@ -723,40 +732,30 @@
 
     def get_next_batch_to_run(self):
         # Merge the prefill batch into the running batch
-        if (
-            self.last_batch
-            and not self.last_batch.forward_mode.is_decode()
-            and not self.last_batch.is_empty()
-        ):
+        if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
+                # Move the chunked request out of the batch
                 self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx
+                # Inflight request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
+
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = self.last_batch
                 else:
                     self.running_batch.merge_batch(self.last_batch)
 
-        #
+        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
             return new_batch
 
-        # Check memory
-        if self.running_batch is None:
-            return
-
         # Run decode
-
-        self.update_running_batch()
-        if not self.running_batch:
-            self.batch_is_full = False
+        if self.running_batch is None:
             return None
-
-        self.batch_is_full = False
+        self.running_batch = self.update_running_batch(self.running_batch)
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
@@ -852,14 +851,20 @@
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
-        new_batch.prepare_for_extend(
+        new_batch.prepare_for_extend()
 
         # Mixed-style chunked prefill
-        if
+        if (
+            self.is_mixed_chunk
+            and self.running_batch is not None
+            and not (new_batch.return_logprob or self.running_batch.return_logprob)
+        ):
+            # TODO (lianmin): support return_logprob + mixed chunked prefill
             self.running_batch.filter_batch()
             if not self.running_batch.is_empty():
-                self.running_batch.prepare_for_decode(
+                self.running_batch.prepare_for_decode()
                 new_batch.mix_with_running(self.running_batch)
                 new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -868,15 +873,16 @@
 
         return new_batch
 
-    def update_running_batch(self):
+    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
         """Update the current running decoding batch."""
         global test_retract
-
+
+        initial_bs = batch.batch_size()
 
         batch.filter_batch()
         if batch.is_empty():
-            self.
-            return
+            self.batch_is_full = False
+            return None
 
         # Check if decode out of memory
         if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
@@ -902,11 +908,15 @@
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
-                self.
-                return
+                self.batch_is_full = False
+                return None
+
+        if batch.batch_size() < initial_bs:
+            self.batch_is_full = False
 
         # Update batch tensors
-        batch.prepare_for_decode(
+        batch.prepare_for_decode()
+        return batch
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
@@ -981,8 +991,13 @@
             if req.is_retracted:
                 continue
 
+            if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                # Free the one delayed token for the mixed decode batch
+                j = len(batch.out_cache_loc) - len(batch.reqs) + i
+                self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                continue
+
             if req.is_being_chunked <= 0:
-                # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_id)
                 req.check_finished()
@@ -992,14 +1007,15 @@
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
 
-                if req.grammar is not None:
-                    req.grammar.accept_token(next_token_id)
-
                 if req.return_logprob:
                     logprob_pt += self.add_logprob_return_values(
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
+
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_id)
             else:
+                # Inflight reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
         if batch.next_batch_sampling_info:
@@ -1017,18 +1033,18 @@
                 continue
 
             req.embedding = embeddings[i]
-            if req.is_being_chunked
-
-            else:
-                # Inflight reqs' prefill is not finished
-                # dummy output token for embedding models
+            if req.is_being_chunked <= 0:
+                # Dummy output token for embedding models
                 req.output_ids.append(0)
                 req.check_finished()
 
-
-
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
             else:
-
+                # Inflight reqs' prefill is not finished
+                req.is_being_chunked -= 1
 
         self.stream_output(batch.reqs)
 
@@ -1056,6 +1072,7 @@
                 continue
 
             if self.enable_overlap and req.finished():
+                # Free the one delayed token
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -1063,9 +1080,6 @@
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.grammar is not None:
-                req.grammar.accept_token(next_token_id)
-
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
 
@@ -1076,6 +1090,9 @@
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
+
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             torch.cuda.current_stream().synchronize()
@@ -1179,7 +1196,6 @@
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
-            output_session_ids = []
         else:  # embedding or reward model
             output_embeddings = []
 
@@ -1207,7 +1223,6 @@
                         req.sampling_params.spaces_between_special_tokens
                     )
                     output_no_stop_trim.append(req.sampling_params.no_stop_trim)
-                    output_session_ids.append(req.session_id)
 
                     meta_info = {
                         "prompt_tokens": len(req.origin_input_ids),
@@ -1258,7 +1273,6 @@
                         output_meta_info,
                         output_finished_reason,
                         output_no_stop_trim,
-                        output_session_ids,
                     )
                 )
         else:  # embedding or reward model
@@ -1389,9 +1403,12 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+    # set cpu affinity to this gpu process
+    gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None:
-        dp_rank = int(os.
+    if dp_rank is None and "DP_RANK" in os.environ:
+        dp_rank = int(os.environ["DP_RANK"])
 
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -1402,7 +1419,9 @@ def run_scheduler_process(
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
-        pipe_writer.send(
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+        )
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
         else:
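The scheduler hunks change the decode-scheduling contract: update_running_batch now takes the batch explicitly and returns it (or None once it empties), so the caller owns the reassignment. A condensed sketch of the resulting flow in get_next_batch_to_run, assembled from the hunks above rather than copied verbatim from the file:

    # Condensed from the get_next_batch_to_run / update_running_batch hunks above.
    new_batch = self.get_new_batch_prefill()
    if new_batch is not None:
        return new_batch                  # run prefill first if possible

    if self.running_batch is None:
        return None
    self.running_batch = self.update_running_batch(self.running_batch)
    return self.running_batch             # may be None if the running batch emptied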
sglang/srt/managers/session_controller.py
CHANGED
@@ -1,15 +1,14 @@
-
-
-
-
-
-
-
-
-
-
-
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import copy
 import uuid
@@ -27,13 +26,13 @@ class Session:
         self.reqs: List[Req] = []
 
     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-        # renew session id
-        self.session_id = uuid.uuid4().hex
         if req.session_rid is not None:
             while len(self.reqs) > 0:
                 if self.reqs[-1].rid == req.session_rid:
                     break
                 self.reqs = self.reqs[:-1]
+        else:
+            self.reqs = []
         if len(self.reqs) > 0:
             input_ids = (
                 self.reqs[-1].origin_input_ids
@@ -59,4 +58,4 @@ class Session:
             )
         else:
            self.reqs.append(new_req)
-        return new_req
+        return new_req
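The create_req hunk stops renewing the session id on every request and clears the session history when the request names no session_rid. A behavior-equivalent paraphrase of the new truncation logic, written more compactly than the file itself:

    # Paraphrase of the new history handling in Session.create_req (not verbatim).
    if req.session_rid is not None:
        # Drop trailing requests until the one the client wants to continue from.
        while self.reqs and self.reqs[-1].rid != req.session_rid:
            self.reqs.pop()
    else:
        self.reqs = []   # no anchor request given: start the session history fresh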