sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -20,7 +20,6 @@ import signal
 import sys
 import threading
 import time
-import warnings
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
@@ -121,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -135,6 +130,7 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
@@ -153,6 +149,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


 @dataclass
@@ -163,6 +160,7 @@ class GenerationBatchResult:
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
+    can_run_cuda_graph: bool


 @dataclass
@@ -209,7 +207,8 @@ class Scheduler(
         self.page_size = server_args.page_size

         # Distributed rank info
-        self.
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
                 self.tp_rank,
@@ -326,13 +325,14 @@ class Scheduler(
         set_random_seed(self.random_seed)

         # Print debug info
-
-
-
-
-
-
-
+        if tp_rank == 0:
+            logger.info(
+                f"max_total_num_tokens={self.max_total_num_tokens}, "
+                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
+                f"max_prefill_tokens={self.max_prefill_tokens}, "
+                f"max_running_requests={self.max_running_requests}, "
+                f"context_len={self.model_config.context_len}"
+            )

         # Init memory pool and cache
         self.init_memory_pool_and_cache()
@@ -531,10 +531,6 @@ class Scheduler(
         )

     def init_metrics(self):
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -720,7 +716,7 @@ class Scheduler(
                     server_is_idle = False
                     result = self.run_batch(self.cur_batch)

-                # send the outputs to the next step
+                # (last rank) send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
                         next_token_ids, bids[mb_id] = (
@@ -755,24 +751,25 @@ class Scheduler(
                             extend_input_len_per_req=None,
                             extend_logprob_start_len_per_req=None,
                             bid=bids[next_mb_id],
+                            can_run_cuda_graph=result.can_run_cuda_graph,
                         )
                         self.process_batch_result(mbs[next_mb_id], output_result)
                         last_mbs[next_mb_id] = mbs[next_mb_id]

-                #
+                # (not last rank)
                 if not self.pp_group.is_last_rank:
                     if self.cur_batch:
                         bids[mb_id] = result.bid
+                    # carry the outputs to the next stage
+                    # send the outputs from the last round to let the next stage worker run post processing
                     if pp_outputs:
-                        # send the outputs from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
                             pp_outputs.tensors,
                             all_gather_group=self.attn_tp_group,
                         )

-                if not self.pp_group.is_last_rank:
                     # send out reqs to the next stage
-                    dp_offset = self.
+                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                     if self.attn_tp_rank == 0:
                         point_to_point_pyobj(
                             recv_reqs,
@@ -819,7 +816,7 @@ class Scheduler(
                 recv_reqs = None
         else:
             if self.attn_tp_rank == 0:
-                dp_offset = self.
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 recv_reqs = point_to_point_pyobj(
                     [],
                     self.pp_rank * self.tp_size + dp_offset,
@@ -907,19 +904,6 @@ class Scheduler(
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids

-        # Handle custom logit processor passed to the request
-        custom_logit_processor = recv_req.custom_logit_processor
-        if (
-            not self.server_args.enable_custom_logit_processor
-            and custom_logit_processor is not None
-        ):
-            logger.warning(
-                "The SGLang server is not configured to enable custom logit processor."
-                "The custom logit processor passed in will be ignored."
-                "Please set --enable-custom-logits-processor to enable this feature."
-            )
-            custom_logit_processor = None
-
         if recv_req.bootstrap_port is None:
             # Use default bootstrap port
             recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
@@ -935,7 +919,7 @@ class Scheduler(
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
-            custom_logit_processor=custom_logit_processor,
+            custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
@@ -1041,9 +1025,11 @@ class Scheduler(
             elif req.sampling_params.structural_tag:
                 key = ("structural_tag", req.sampling_params.structural_tag)

-
-
-
+            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+            req.grammar = value
+
+            if not cache_hit:
+                req.grammar_key = key
                 add_to_grammar_queue = True

         if add_to_grammar_queue:
@@ -1133,9 +1119,6 @@ class Scheduler(
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
         )
-        self._largest_prefill_len = max(
-            self._largest_prefill_len, adder.log_input_tokens
-        )

         num_new_seq = len(can_run_list)
         f = (
@@ -1173,7 +1156,9 @@ class Scheduler(

         self.metrics_collector.log_stats(self.stats)

-    def log_decode_stats(
+    def log_decode_stats(
+        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+    ):
         batch = running_batch or self.running_batch

         gap_latency = time.time() - self.last_decode_stats_tic
@@ -1213,6 +1198,7 @@ class Scheduler(
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

         msg += (
+            f"cuda graph: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}"
         )
@@ -1225,6 +1211,7 @@ class Scheduler(
             self.stats.cache_hit_rate = 0.0
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)

@@ -1246,9 +1233,7 @@ class Scheduler(
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
-
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
@@ -1256,9 +1241,7 @@ class Scheduler(
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
-
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if (
             self.enable_metrics
@@ -1276,6 +1259,7 @@ class Scheduler(
             self.stats.token_usage = num_used / self.max_total_num_tokens
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.metrics_collector.log_stats(self.stats)

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
@@ -1346,7 +1330,7 @@ class Scheduler(
             return None

         running_bs = len(self.running_batch.reqs)
-        #
+        # Ignore the check if self.chunked_req is not None.
         # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
         # as the space for the chunked request has just been released.
         # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
@@ -1540,11 +1524,11 @@ class Scheduler(
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
                 if self.pp_group.is_last_rank:
-                    logits_output, next_token_ids = (
+                    logits_output, next_token_ids, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
                     )
                 else:
-                    pp_hidden_states_proxy_tensors, _ = (
+                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
                     )
                 bid = model_worker_batch.bid
@@ -1554,6 +1538,7 @@ class Scheduler(
                     next_token_ids,
                     bid,
                     num_accepted_tokens,
+                    can_run_cuda_graph,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
                 self.spec_num_total_accepted_tokens += (
                     num_accepted_tokens + batch.batch_size()
@@ -1587,6 +1572,7 @@ class Scheduler(
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
+                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
@@ -1609,14 +1595,9 @@ class Scheduler(
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
-
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
+                self.set_next_batch_sampling_info_done(batch)
         elif batch.forward_mode.is_dummy_first():
-
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)

         if self.return_health_check_ct:
             # Return some signal for the health check.
@@ -1630,6 +1611,7 @@ class Scheduler(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             tp_cpu_group=self.tp_cpu_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1642,6 +1624,7 @@ class Scheduler(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
+        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,
@@ -1651,15 +1634,15 @@ class Scheduler(
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
-
+            num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
             if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                 num_tokens = num_tokens * speculative_num_draft_tokens
-
+            num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-
+            num_tokens_for_logprob = sum(
                 [
                     # We should have at least 1 token for sample in every case.
                     max(extend_len - logprob_start_len, 1)
|
|
1686
1669
|
[
|
1687
1670
|
num_tokens,
|
1688
1671
|
can_cuda_graph,
|
1689
|
-
|
1672
|
+
num_tokens_for_logprob,
|
1690
1673
|
is_extend_in_batch,
|
1691
1674
|
],
|
1692
1675
|
dtype=torch.int64,
|
@@ -1709,8 +1692,15 @@ class Scheduler(
             local_batch = get_idle_batch()

         if local_batch is not None:
-
-
+            # TODO: handle the case when moe_dense_tp_size != 1
+            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+                local_batch.global_num_tokens = [num_tokens]
+                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+            else:
+                local_batch.global_num_tokens = global_num_tokens
+                local_batch.global_num_tokens_for_logprob = (
+                    global_num_tokens_for_logprob
+                )

             # Check forward mode for cuda graph
             if not disable_cuda_graph:
@@ -1736,11 +1726,17 @@ class Scheduler(
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

         num_ready_reqs = 0
+        num_abort_reqs = 0
         for req in self.grammar_queue:
             try:
-                req.grammar = req.grammar.result(timeout=0.
+                req.grammar = req.grammar.result(timeout=0.03)
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
+                req.grammar_wait_ct += 1
+                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+                    num_abort_reqs = 1
                 break

         if self.server_args.enable_dp_attention:
@@ -1752,18 +1748,39 @@ class Scheduler(

         if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
             torch.distributed.all_reduce(
                 tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
-            num_ready_reqs_max = tensor.
+            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
             for i in range(num_ready_reqs, num_ready_reqs_max):
-
-
+                req = self.grammar_queue[i]
+                req.grammar = req.grammar.result()
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+            for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+                req = self.grammar_queue[i]
+                req.grammar.cancel()
+                req.grammar = None
+                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+                logger.error(error_msg)
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
+        if batch.next_batch_sampling_info:
+            if batch.next_batch_sampling_info.grammars is not None:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+            self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
@@ -1774,24 +1791,27 @@ class Scheduler(
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
                     self.watchdog_last_time = current
             time.sleep(self.watchdog_timeout // 2)

-
-
-
-
-
-
-
-
+        if not disable_request_logging():
+            # Print batch size and memory pool info to check whether there are de-sync issues.
+            logger.error(
+                f"{self.cur_batch.batch_size()=}, "
+                f"{self.cur_batch.reqs=}, "
+                f"{self.token_to_kv_pool_allocator.available_size()=}, "
+                f"{self.tree_cache.evictable_size()=}, "
+            )
+
         pyspy_dump_schedulers()
+        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
         print(file=sys.stderr, flush=True)
         print(file=sys.stdout, flush=True)
+
+        # Wait for some time so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)

@@ -1923,25 +1943,30 @@ class Scheduler(
         )

     def abort_request(self, recv_req: AbortReq):
+        # TODO(lmzheng): abort the requests in the grammar queue.
+
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
             if req.rid.startswith(recv_req.rid):
                 to_del.append(i)
-                break

         # Sort in reverse order to avoid index issues when deleting
-        for i in
+        for i in reversed(to_del):
             req = self.waiting_queue.pop(i)
+            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             logger.debug(f"Abort queued request. {req.rid=}")
-            return

         # Delete requests in the running batch
-
+        if self.cur_batch is self.running_batch or self.cur_batch is None:
+            reqs = self.running_batch.reqs
+        else:
+            reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+        for req in reqs:
             if req.rid.startswith(recv_req.rid) and not req.finished():
                 logger.debug(f"Abort running request. {req.rid=}")
                 req.to_abort = True
-                return

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
@@ -2162,8 +2187,8 @@ class Scheduler(

     def get_print_prefix(self):
         prefix = ""
-        if self.
-            prefix += f" DP{self.
+        if self.attn_dp_rank is not None:
+            prefix += f" DP{self.attn_dp_rank}"
         if self.server_args.tp_size > 1:
             prefix += f" TP{self.tp_rank}"
         if self.pp_size > 1: