sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -25,6 +25,7 @@ from types import SimpleNamespace
 from typing import List, Optional
 
 import psutil
+import setproctitle
 import torch
 import zmq
 
@@ -114,9 +115,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)
 
@@ -259,6 +257,10 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
+
+        # Session info
+        self.sessions = {}
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
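Note: the scheduler now caches `self.current_stream` via `torch.get_device_module(self.device).current_stream()` and reuses it wherever it previously called `torch.cuda.current_stream()`, which keeps the synchronization path device-agnostic. A minimal standalone sketch of that pattern (illustrative only, assuming a PyTorch build that ships `torch.get_device_module`):

import torch

# Pick a device type and resolve its backend module ("cuda" -> torch.cuda, ...),
# then cache the current stream once instead of looking it up on every call.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    current_stream = torch.get_device_module(device).current_stream()
    current_stream.synchronize()  # same effect as torch.cuda.current_stream().synchronize()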
@@ -356,6 +358,7 @@ class Scheduler:
             )
 
     def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
 
@@ -433,61 +436,6 @@ class Scheduler:
 
             self.last_batch = batch
 
-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -567,6 +515,9 @@ class Scheduler:
             recv_req.input_text,
             recv_req.input_ids,
             recv_req.sampling_params,
+            return_logprob=recv_req.return_logprob,
+            top_logprobs_num=recv_req.top_logprobs_num,
+            stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
         )
@@ -610,9 +561,6 @@ class Scheduler:
             return
 
         # Copy more attributes
-        req.return_logprob = recv_req.return_logprob
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
         req.logprob_start_len = recv_req.logprob_start_len
 
         if req.logprob_start_len == -1:
@@ -765,7 +713,7 @@ class Scheduler:
             if crash_on_warnings():
                 raise ValueError(msg)
 
-    def get_next_batch_to_run(self):
+    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.being_chunked_req:
@@ -993,10 +941,11 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+        skip_stream_req = None
 
         if self.is_generation:
             logits_output, next_token_ids, bid = result
@@ -1033,7 +982,6 @@ class Scheduler:
                     continue
 
                 if req.is_being_chunked <= 0:
-                    req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_id)
                     req.check_finished()
 
@@ -1049,13 +997,18 @@ class Scheduler:
 
                     if req.grammar is not None:
                         req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1
+                    # There is only at most one request being currently chunked.
+                    # Because this request does not finish prefill,
+                    # we don't want to stream the request currently being chunked.
+                    skip_stream_req = req
 
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                self.current_stream.synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
 
         else: # embedding or reward model
@@ -1081,7 +1034,7 @@ class Scheduler:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1
 
-        self.stream_output(batch.reqs)
+        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids, bid = result
@@ -1111,7 +1064,6 @@ class Scheduler:
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
-            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
 
@@ -1119,21 +1071,26 @@ class Scheduler:
                 self.tree_cache.cache_finished_req(req)
 
             if req.return_logprob:
-                req.output_token_logprobs.append(
-                    (next_token_logprobs[i], next_token_id)
-                )
+                req.output_token_logprobs_val.append(next_token_logprobs[i])
+                req.output_token_logprobs_idx.append(next_token_id)
                 if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+                    req.output_top_logprobs_val.append(
+                        logits_output.output_top_logprobs_val[i]
+                    )
+                    req.output_top_logprobs_idx.append(
+                        logits_output.output_top_logprobs_idx[i]
+                    )
 
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()
 
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
-        self.stream_output(batch.reqs)
+        self.stream_output(batch.reqs, batch.return_logprob)
 
         self.token_to_kv_pool.free_group_end()
 
@@ -1153,9 +1110,8 @@ class Scheduler:
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-        req.output_token_logprobs.append(
-            (output.next_token_logprobs[i], next_token_ids[i])
-        )
+        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+        req.output_token_logprobs_idx.append(next_token_ids[i])
 
         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
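Note: across these hunks the per-request logprob bookkeeping moves from single lists of `(value, token_id)` tuples to parallel `*_val` / `*_idx` lists, which is also how the rewritten `stream_output` below passes them to `BatchTokenIDOut` as flat per-request fields. A toy illustration of the two layouts (hypothetical values, not the sglang API):

# Old layout: one list of (logprob, token_id) pairs per request.
output_token_logprobs = [(-0.12, 42), (-1.30, 7)]

# New layout: two parallel lists kept in lockstep.
output_token_logprobs_val = [-0.12, -1.30]
output_token_logprobs_idx = [42, 7]

# The old pairs are still recoverable when needed:
assert list(zip(output_token_logprobs_val, output_token_logprobs_idx)) == output_token_logprobs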
@@ -1163,170 +1119,251 @@ class Scheduler:
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
-        if req.input_token_logprobs is None:
-            input_token_logprobs = output.input_token_logprobs[
+        if req.input_token_logprobs_val is None:
+            input_token_logprobs_val = output.input_token_logprobs[
                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
             ]
-            input_token_ids = req.fill_ids[
+
+            input_token_logprobs_idx = req.fill_ids[
                 len(req.fill_ids)
                 - num_input_logprobs
                 + 1 : len(req.fill_ids)
                 - req.last_update_decode_tokens
             ]
-
             # Clip the padded hash values from image tokens.
             # Otherwise, it will lead to detokenization errors.
-            input_token_ids = [
+            input_token_logprobs_idx = [
                 x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_ids
+                for x in input_token_logprobs_idx
             ]
 
-            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
-
             if (
                 req.logprob_start_len == 0
             ): # The first token does not have logprob, pad it.
-                req.input_token_logprobs = [
-                    (None, req.fill_ids[0])
-                ] + req.input_token_logprobs
+                input_token_logprobs_val = [None] + input_token_logprobs_val
+                input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
+
+            req.input_token_logprobs_val = input_token_logprobs_val
+            req.input_token_logprobs_idx = input_token_logprobs_idx
 
         if req.last_update_decode_tokens != 0:
             # Some decode tokens are re-computed in an extend batch
-            req.output_token_logprobs.extend(
-                list(
-                    zip(
-                        output.input_token_logprobs[
-                            pt
-                            + num_input_logprobs
-                            - 1
-                            - req.last_update_decode_tokens : pt
-                            + num_input_logprobs
-                            - 1
-                        ],
-                        req.fill_ids[
-                            len(req.fill_ids)
-                            - req.last_update_decode_tokens : len(req.fill_ids)
-                        ],
-                    )
-                )
+            req.output_token_logprobs_val.extend(
+                output.input_token_logprobs[
+                    pt
+                    + num_input_logprobs
+                    - 1
+                    - req.last_update_decode_tokens : pt
+                    + num_input_logprobs
+                    - 1
+                ],
+            )
+            req.output_token_logprobs_idx.extend(
+                req.fill_ids[
+                    len(req.fill_ids)
+                    - req.last_update_decode_tokens : len(req.fill_ids)
+                ]
             )
 
         if req.top_logprobs_num > 0:
-            if req.input_top_logprobs is None:
-                req.input_top_logprobs = output.input_top_logprobs[i]
+            if req.input_top_logprobs_val is None:
+                req.input_top_logprobs_val = output.input_top_logprobs_val[i]
+                req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
                 if req.logprob_start_len == 0:
-                    req.input_top_logprobs = [None] + req.input_top_logprobs
+                    req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
+                    req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx
 
             if req.last_update_decode_tokens != 0:
-                req.output_top_logprobs.extend(
-                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
+                req.output_top_logprobs_val.extend(
+                    output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
                 )
-            req.output_top_logprobs.append(output.output_top_logprobs[i])
+                req.output_top_logprobs_idx.extend(
+                    output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
+                )
+            req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
+            req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
 
         return num_input_logprobs
 
-    def stream_output(self, reqs: List[Req]):
+    def stream_output(
+        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+    ):
         """Stream the output to detokenizer."""
-        output_rids = []
-        output_meta_info: List[dict] = []
-        output_finished_reason: List[BaseFinishReason] = []
+        rids = []
+        finished_reasons: List[BaseFinishReason] = []
+
         if self.is_generation:
-            output_vids = []
+            vids = []
             decoded_texts = []
-            output_read_ids = []
-            output_read_offsets = []
+            decode_ids_list = []
+            read_offsets = []
             output_ids = []
-            output_skip_special_tokens = []
-            output_spaces_between_special_tokens = []
-            output_no_stop_trim = []
-        else: # embedding or reward model
-            output_embeddings = []
 
-        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
+            skip_special_tokens = []
+            spaces_between_special_tokens = []
+            no_stop_trim = []
+            prompt_tokens = []
+            completion_tokens = []
+            cached_tokens = []
+
+            if return_logprob:
+                input_token_logprobs_val = []
+                input_token_logprobs_idx = []
+                output_token_logprobs_val = []
+                output_token_logprobs_idx = []
+                input_top_logprobs_val = []
+                input_top_logprobs_idx = []
+                output_top_logprobs_val = []
+                output_top_logprobs_idx = []
+                normalized_prompt_logprob = []
+            else:
+                input_token_logprobs_val = input_token_logprobs_idx = (
+                    output_token_logprobs_val
+                ) = output_token_logprobs_idx = input_top_logprobs_val = (
+                    input_top_logprobs_idx
+                ) = output_top_logprobs_val = output_top_logprobs_idx = (
+                    normalized_prompt_logprob
+                ) = None
+
+            for req in reqs:
+                if req is skip_req:
+                    continue
 
-        for req in reqs:
-            # TODO(lianmin): revisit this for overlap + retract + stream
-            if req.finished() or (
-                req.stream and (is_stream_iter or len(req.output_ids) == 1)
-            ):
-                output_rids.append(req.rid)
-                output_finished_reason.append(req.finished_reason)
-                if self.is_generation:
-                    output_vids.append(req.vid)
+                # TODO(lianmin): revisit this for overlap + retract + stream
+                if (
+                    req.finished()
+                    # If stream, follow the given stream_interval
+                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
+                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
+                    or (not req.stream and len(req.output_ids) % 50 == 0)
+                ):
+                    rids.append(req.rid)
+                    finished_reasons.append(
+                        req.finished_reason.to_json() if req.finished_reason else None
+                    )
+                    vids.append(req.vid)
                     decoded_texts.append(req.decoded_text)
-                    read_ids, read_offset = req.init_incremental_detokenize()
-                    output_read_ids.append(read_ids)
-                    output_read_offsets.append(read_offset)
+                    decode_ids, read_offset = req.init_incremental_detokenize()
+                    decode_ids_list.append(decode_ids)
+                    read_offsets.append(read_offset)
                     if self.skip_tokenizer_init:
                         output_ids.append(req.output_ids)
-                    output_skip_special_tokens.append(
-                        req.sampling_params.skip_special_tokens
-                    )
-                    output_spaces_between_special_tokens.append(
+                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
+                    spaces_between_special_tokens.append(
                         req.sampling_params.spaces_between_special_tokens
                     )
-                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)
-
-                    meta_info = {
-                        "prompt_tokens": len(req.origin_input_ids),
-                        "completion_tokens": len(req.output_ids),
-                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                        "cached_tokens": req.cached_tokens,
-                        "finish_reason": (
-                            req.finished_reason.to_json()
-                            if req.finished_reason is not None
-                            else None
-                        ),
-                    }
-                    if req.return_logprob:
-                        (
-                            meta_info["input_token_logprobs"],
-                            meta_info["output_token_logprobs"],
-                            meta_info["input_top_logprobs"],
-                            meta_info["output_top_logprobs"],
-                            meta_info["normalized_prompt_logprob"],
-                        ) = (
-                            req.input_token_logprobs,
-                            req.output_token_logprobs,
-                            req.input_top_logprobs,
-                            req.output_top_logprobs,
-                            req.normalized_prompt_logprob,
-                        )
-                    output_meta_info.append(meta_info)
-                else: # embedding or reward model
-                    output_embeddings.append(req.embedding)
-                    meta_info = {
-                        "prompt_tokens": len(req.origin_input_ids),
-                    }
-                    output_meta_info.append(meta_info)
-
-        # Send to detokenizer
-        if output_rids:
-            if self.is_generation:
+                    no_stop_trim.append(req.sampling_params.no_stop_trim)
+
+                    prompt_tokens.append(len(req.origin_input_ids))
+                    completion_tokens.append(len(req.output_ids))
+                    cached_tokens.append(req.cached_tokens)
+
+                    if return_logprob:
+                        input_token_logprobs_val.append(req.input_token_logprobs_val)
+                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                        output_token_logprobs_val.append(req.output_token_logprobs_val)
+                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
+                        input_top_logprobs_val.append(req.input_top_logprobs_val)
+                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                        output_top_logprobs_val.append(req.output_top_logprobs_val)
+                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+                        normalized_prompt_logprob.append(req.normalized_prompt_logprob)
+
+            # Send to detokenizer
+            if rids:
                 self.send_to_detokenizer.send_pyobj(
                     BatchTokenIDOut(
-                        output_rids,
-                        output_vids,
+                        rids,
+                        finished_reasons,
+                        vids,
                         decoded_texts,
-                        output_read_ids,
-                        output_read_offsets,
+                        decode_ids_list,
+                        read_offsets,
                         output_ids,
-                        output_skip_special_tokens,
-                        output_spaces_between_special_tokens,
-                        output_meta_info,
-                        output_finished_reason,
-                        output_no_stop_trim,
+                        skip_special_tokens,
+                        spaces_between_special_tokens,
+                        no_stop_trim,
+                        prompt_tokens,
+                        completion_tokens,
+                        cached_tokens,
+                        input_token_logprobs_val,
+                        input_token_logprobs_idx,
+                        output_token_logprobs_val,
+                        output_token_logprobs_idx,
+                        input_top_logprobs_val,
+                        input_top_logprobs_idx,
+                        output_top_logprobs_val,
+                        output_top_logprobs_idx,
+                        normalized_prompt_logprob,
                     )
                 )
-            else: # embedding or reward model
-                self.send_to_detokenizer.send_pyobj(
-                    BatchEmbeddingOut(
-                        output_rids,
-                        output_embeddings,
-                        output_meta_info,
-                        output_finished_reason,
-                    )
+        else: # embedding or reward model
+            embeddings = []
+            prompt_tokens = []
+            for req in reqs:
+                assert req.finished()
+                rids.append(req.rid)
+                finished_reasons.append(req.finished_reason.to_json())
+                embeddings.append(req.embedding)
+                prompt_tokens.append(len(req.origin_input_ids))
+            self.send_to_detokenizer.send_pyobj(
+                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
+            )
+
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
                 )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
 
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
@@ -1469,9 +1506,7 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-    # set cpu affinity to this gpu process
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    setproctitle.setproctitle("sglang::scheduler")
 
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
@@ -1482,6 +1517,10 @@ def run_scheduler_process(
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
 
+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()
     parent_process = psutil.Process().parent()