sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
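To pick up these changes locally, one would typically install the post release and confirm the reported version. A minimal check, assuming `sglang.__version__` is re-exported from `sglang/version.py` as in these wheels:

```python
# Hypothetical upgrade step, e.g.:  pip install --upgrade "sglang==0.4.0.post2"
import sglang

# version.py is bumped in this release (the +1 -1 change listed above),
# so the installed build should report the post2 tag.
print(sglang.__version__)  # expected: 0.4.0.post2
```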
sglang/srt/managers/scheduler.py
CHANGED
@@ -25,6 +25,7 @@ from types import SimpleNamespace
 from typing import List, Optional

 import psutil
+import setproctitle
 import torch
 import zmq

@@ -114,9 +115,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics

-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)

@@ -259,6 +257,10 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
+
+        # Session info
+        self.sessions = {}

         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -356,6 +358,7 @@ class Scheduler:
         )

     def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()

@@ -433,61 +436,6 @@ class Scheduler:

             self.last_batch = batch

-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -567,6 +515,9 @@ class Scheduler:
             recv_req.input_text,
             recv_req.input_ids,
             recv_req.sampling_params,
+            return_logprob=recv_req.return_logprob,
+            top_logprobs_num=recv_req.top_logprobs_num,
+            stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
         )
@@ -610,9 +561,6 @@ class Scheduler:
             return

         # Copy more attributes
-        req.return_logprob = recv_req.return_logprob
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
         req.logprob_start_len = recv_req.logprob_start_len

         if req.logprob_start_len == -1:
@@ -765,7 +713,7 @@ class Scheduler:
             if crash_on_warnings():
                 raise ValueError(msg)

-    def get_next_batch_to_run(self):
+    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
@@ -993,10 +941,11 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+        skip_stream_req = None

         if self.is_generation:
             logits_output, next_token_ids, bid = result
@@ -1033,7 +982,6 @@ class Scheduler:
                     continue

                 if req.is_being_chunked <= 0:
-                    req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_id)
                     req.check_finished()

@@ -1049,13 +997,18 @@ class Scheduler:

                     if req.grammar is not None:
                         req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1
+                    # There is only at most one request being currently chunked.
+                    # Because this request does not finish prefill,
+                    # we don't want to stream the request currently being chunked.
+                    skip_stream_req = req

             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-
+                self.current_stream.synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model
@@ -1081,7 +1034,7 @@ class Scheduler:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1

-        self.stream_output(batch.reqs)
+        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids, bid = result
@@ -1111,7 +1064,6 @@ class Scheduler:
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue

-            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()

@@ -1119,21 +1071,26 @@ class Scheduler:
                 self.tree_cache.cache_finished_req(req)

             if req.return_logprob:
-                req.
-
-                )
+                req.output_token_logprobs_val.append(next_token_logprobs[i])
+                req.output_token_logprobs_idx.append(next_token_id)
                 if req.top_logprobs_num > 0:
-                    req.
+                    req.output_top_logprobs_val.append(
+                        logits_output.output_top_logprobs_val[i]
+                    )
+                    req.output_top_logprobs_idx.append(
+                        logits_output.output_top_logprobs_idx[i]
+                    )

             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()

         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

-        self.stream_output(batch.reqs)
+        self.stream_output(batch.reqs, batch.return_logprob)

         self.token_to_kv_pool.free_group_end()

@@ -1153,9 +1110,8 @@ class Scheduler:
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-        req.
-
-        )
+        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+        req.output_token_logprobs_idx.append(next_token_ids[i])

         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
@@ -1163,170 +1119,251 @@ class Scheduler:
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-        if req.
-
+        if req.input_token_logprobs_val is None:
+            input_token_logprobs_val = output.input_token_logprobs[
                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
             ]
-
+
+            input_token_logprobs_idx = req.fill_ids[
                 len(req.fill_ids)
                 - num_input_logprobs
                 + 1 : len(req.fill_ids)
                 - req.last_update_decode_tokens
             ]
-
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
-
+            input_token_logprobs_idx = [
                x if x < self.model_config.vocab_size - 1 else 0
-                for x in
+                for x in input_token_logprobs_idx
            ]

-            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
-
            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
-
-
-
+                input_token_logprobs_val = [None] + input_token_logprobs_val
+                input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
+
+            req.input_token_logprobs_val = input_token_logprobs_val
+            req.input_token_logprobs_idx = input_token_logprobs_idx

         if req.last_update_decode_tokens != 0:
             # Some decode tokens are re-computed in an extend batch
-            req.
-
-
-
-
-
-
-
-
-
-
-
-
-
-                )
-            )
+            req.output_token_logprobs_val.extend(
+                output.input_token_logprobs[
+                    pt
+                    + num_input_logprobs
+                    - 1
+                    - req.last_update_decode_tokens : pt
+                    + num_input_logprobs
+                    - 1
+                ],
+            )
+            req.output_token_logprobs_idx.extend(
+                req.fill_ids[
+                    len(req.fill_ids)
+                    - req.last_update_decode_tokens : len(req.fill_ids)
+                ]
             )

         if req.top_logprobs_num > 0:
-            if req.
-                req.
+            if req.input_top_logprobs_val is None:
+                req.input_top_logprobs_val = output.input_top_logprobs_val[i]
+                req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
             if req.logprob_start_len == 0:
-                req.
+                req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
+                req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx

             if req.last_update_decode_tokens != 0:
-                req.
-                    output.
+                req.output_top_logprobs_val.extend(
+                    output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
                 )
-
+                req.output_top_logprobs_idx.extend(
+                    output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
+                )
+                req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
+                req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])

         return num_input_logprobs

-    def stream_output(
+    def stream_output(
+        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+    ):
         """Stream the output to detokenizer."""
-
-
-
+        rids = []
+        finished_reasons: List[BaseFinishReason] = []
+
         if self.is_generation:
-
+            vids = []
             decoded_texts = []
-
-
+            decode_ids_list = []
+            read_offsets = []
             output_ids = []
-            output_skip_special_tokens = []
-            output_spaces_between_special_tokens = []
-            output_no_stop_trim = []
-        else:  # embedding or reward model
-            output_embeddings = []

-
+            skip_special_tokens = []
+            spaces_between_special_tokens = []
+            no_stop_trim = []
+            prompt_tokens = []
+            completion_tokens = []
+            cached_tokens = []
+
+            if return_logprob:
+                input_token_logprobs_val = []
+                input_token_logprobs_idx = []
+                output_token_logprobs_val = []
+                output_token_logprobs_idx = []
+                input_top_logprobs_val = []
+                input_top_logprobs_idx = []
+                output_top_logprobs_val = []
+                output_top_logprobs_idx = []
+                normalized_prompt_logprob = []
+            else:
+                input_token_logprobs_val = input_token_logprobs_idx = (
+                    output_token_logprobs_val
+                ) = output_token_logprobs_idx = input_top_logprobs_val = (
+                    input_top_logprobs_idx
+                ) = output_top_logprobs_val = output_top_logprobs_idx = (
+                    normalized_prompt_logprob
+                ) = None
+
+            for req in reqs:
+                if req is skip_req:
+                    continue

-
-
-
-
-
-
-
-
-
+                # TODO(lianmin): revisit this for overlap + retract + stream
+                if (
+                    req.finished()
+                    # If stream, follow the given stream_interval
+                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
+                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
+                    or (not req.stream and len(req.output_ids) % 50 == 0)
+                ):
+                    rids.append(req.rid)
+                    finished_reasons.append(
+                        req.finished_reason.to_json() if req.finished_reason else None
+                    )
+                    vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
-
-
-
+                    decode_ids, read_offset = req.init_incremental_detokenize()
+                    decode_ids_list.append(decode_ids)
+                    read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
-
-
-                    )
-                    output_spaces_between_special_tokens.append(
+                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
+                    spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
-
-
-
-
-
-
-
-
-
-
-
-                    )
-
-
-                        (
-
-
-
-
-                            meta_info["normalized_prompt_logprob"],
-                        ) = (
-                            req.input_token_logprobs,
-                            req.output_token_logprobs,
-                            req.input_top_logprobs,
-                            req.output_top_logprobs,
-                            req.normalized_prompt_logprob,
-                        )
-                    output_meta_info.append(meta_info)
-                else:  # embedding or reward model
-                    output_embeddings.append(req.embedding)
-                    meta_info = {
-                        "prompt_tokens": len(req.origin_input_ids),
-                    }
-                    output_meta_info.append(meta_info)
-
-        # Send to detokenizer
-        if output_rids:
-            if self.is_generation:
+                    no_stop_trim.append(req.sampling_params.no_stop_trim)
+
+                    prompt_tokens.append(len(req.origin_input_ids))
+                    completion_tokens.append(len(req.output_ids))
+                    cached_tokens.append(req.cached_tokens)
+
+                    if return_logprob:
+                        input_token_logprobs_val.append(req.input_token_logprobs_val)
+                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                        output_token_logprobs_val.append(req.output_token_logprobs_val)
+                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
+                        input_top_logprobs_val.append(req.input_top_logprobs_val)
+                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                        output_top_logprobs_val.append(req.output_top_logprobs_val)
+                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+                        normalized_prompt_logprob.append(req.normalized_prompt_logprob)
+
+            # Send to detokenizer
+            if rids:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
-
-
+                        rids,
+                        finished_reasons,
+                        vids,
                        decoded_texts,
-
-
+                        decode_ids_list,
+                        read_offsets,
                        output_ids,
-
-
-
-
-
+                        skip_special_tokens,
+                        spaces_between_special_tokens,
+                        no_stop_trim,
+                        prompt_tokens,
+                        completion_tokens,
+                        cached_tokens,
+                        input_token_logprobs_val,
+                        input_token_logprobs_idx,
+                        output_token_logprobs_val,
+                        output_token_logprobs_idx,
+                        input_top_logprobs_val,
+                        input_top_logprobs_idx,
+                        output_top_logprobs_val,
+                        output_top_logprobs_idx,
+                        normalized_prompt_logprob,
                    )
                )
-
-
-
-
-
-
-
-
+        else:  # embedding or reward model
+            embeddings = []
+            prompt_tokens = []
+            for req in reqs:
+                assert req.finished()
+                rids.append(req.rid)
+                finished_reasons.append(req.finished_reason.to_json())
+                embeddings.append(req.embedding)
+                prompt_tokens.append(len(req.origin_input_ids))
+            self.send_to_detokenizer.send_pyobj(
+                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
+            )
+
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch

     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
@@ -1469,9 +1506,7 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    setproctitle.setproctitle("sglang::scheduler")

     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
@@ -1482,6 +1517,10 @@ def run_scheduler_process(
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()
     parent_process = psutil.Process().parent()
