sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """A scheduler that manages a tensor parallel GPU worker."""

  import logging
@@ -36,9 +34,12 @@ from sglang.srt.managers.io_struct import (
  AbortReq,
  BatchEmbeddingOut,
  BatchTokenIDOut,
+ CloseSessionReqInput,
  FlushCacheReq,
  GetMemPoolSizeReq,
  GetMemPoolSizeReqOutput,
+ OpenSessionReqInput,
+ OpenSessionReqOutput,
  ProfileReq,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
@@ -58,16 +59,20 @@ from sglang.srt.managers.schedule_policy import (
  PrefillAdder,
  SchedulePolicy,
  )
+ from sglang.srt.managers.session_controller import Session
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
  broadcast_pyobj,
  configure_logger,
+ crash_on_warnings,
  get_zmq_socket,
+ gpu_proc_affinity,
  kill_parent_process,
  set_random_seed,
  suppress_other_loggers,
@@ -76,12 +81,8 @@ from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)

-
- # Crash on warning if we are running CI tests
- crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
  # Test retract decode
- test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"
+ test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"


  class Scheduler:
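
The module-level crash_on_warning flag removed above gives way to a crash_on_warnings() helper imported from sglang.srt.utils (see the import hunk earlier). A minimal sketch of such a helper, assuming it keeps the SGLANG_IS_IN_CI convention and the same case-insensitive parsing that the SGLANG_TEST_RETRACT line gains in this hunk:

    import os

    def crash_on_warnings() -> bool:
        # Sketch only: accept "true"/"True"/"TRUE" alike, mirroring the
        # .lower() fix applied to SGLANG_TEST_RETRACT above.
        return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"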
@@ -103,14 +104,17 @@ class Scheduler:
  self.disable_jump_forward = server_args.disable_jump_forward
  self.lora_paths = server_args.lora_paths
  self.max_loras_per_batch = server_args.max_loras_per_batch
- self.enable_overlap = server_args.enable_overlap_schedule
+ self.enable_overlap = not server_args.disable_overlap_schedule
  self.skip_tokenizer_init = server_args.skip_tokenizer_init
  self.enable_metrics = server_args.enable_metrics

+ # Session info
+ self.sessions = {}
+
  # Init inter-process communication
  context = zmq.Context(2)

- if self.tp_rank == 0:
+ if self.tp_rank == 0 or self.server_args.enable_dp_attention:
  self.recv_from_tokenizer = get_zmq_socket(
  context, zmq.PULL, port_args.scheduler_input_ipc_name
  )
@@ -160,6 +164,14 @@ class Scheduler:
  trust_remote_code=server_args.trust_remote_code,
  )

+ # Check whether overlap can be enabled
+ if not self.is_generation:
+ self.enable_overlap = False
+ logger.info("Overlap scheduler is disabled for embedding models.")
+
+ if self.enable_overlap:
+ self.disable_jump_forward = True
+
  # Launch a tensor parallel worker
  if self.enable_overlap:
  TpWorkerClass = TpModelWorkerClient
@@ -223,8 +235,12 @@ class Scheduler:

  # Init running status
  self.waiting_queue: List[Req] = []
+ # The running decoding batch for continuous batching
  self.running_batch: Optional[ScheduleBatch] = None
+ # The current forward batch
  self.cur_batch: Optional[ScheduleBatch] = None
+ # The last forward batch
+ self.last_batch: Optional[ScheduleBatch] = None
  self.forward_ct = 0
  self.forward_ct_decode = 0
  self.num_generated_tokens = 0
@@ -286,6 +302,9 @@ class Scheduler:
  ) / global_config.default_new_token_ratio_decay_steps
  self.new_token_ratio = self.init_new_token_ratio

+ # Tells whether the current running batch is full so that we can skip
+ # the check of whether to prefill new requests.
+ # This is an optimization to reduce the overhead of the prefill check.
  self.batch_is_full = False

  # Init watchdog thread
@@ -337,46 +356,34 @@ class Scheduler:

  kill_parent_process()

- @torch.inference_mode()
+ @torch.no_grad()
  def event_loop_normal(self):
- """A normal blocking scheduler loop."""
- self.last_batch = None
-
+ """A normal scheduler loop."""
  while True:
  recv_reqs = self.recv_requests()
  self.process_input_requests(recv_reqs)

  batch = self.get_next_batch_to_run()
+ if self.server_args.enable_dp_attention:
+ batch = self.prepare_dp_attn_batch(batch)
+
  self.cur_batch = batch

  if batch:
  result = self.run_batch(batch)
  self.process_batch_result(batch, result)
-
- # Decode multiple steps to reduce the overhead
- if batch.forward_mode.is_decode():
- for _ in range(self.server_args.num_continuous_decode_steps - 1):
- if not self.running_batch:
- break
- self.update_running_batch()
- if not self.running_batch:
- break
- result = self.run_batch(batch)
- self.process_batch_result(batch, result)
  else:
+ # Self-check and re-init some states when the server is idle
  self.check_memory()
  self.new_token_ratio = self.init_new_token_ratio

  self.last_batch = batch

- @torch.inference_mode()
+ @torch.no_grad()
  def event_loop_overlap(self):
  """A scheduler loop that overlaps the CPU processing and GPU computation."""
  result_queue = deque()

- self.last_batch = None
- self.running_batch = None
-
  while True:
  recv_reqs = self.recv_requests()
  self.process_input_requests(recv_reqs)
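
Both event loops switch from torch.inference_mode() to torch.no_grad(). The diff does not state why, but a plausible reason is that inference_mode marks its outputs as inference tensors, which raise an error if autograd-tracked code later touches them, while no_grad outputs stay ordinary tensors. A small self-contained illustration of the difference:

    import torch

    with torch.no_grad():
        a = torch.ones(2)       # ordinary tensor, built without grad tracking
    with torch.inference_mode():
        b = torch.ones(2)       # an "inference tensor"
    print(b.is_inference())     # True

    w = torch.ones(2, requires_grad=True)
    (a * w).sum().backward()    # fine: a can re-enter autograd code
    try:
        (b * w).sum().backward()
    except RuntimeError as e:   # inference tensors cannot be saved for backward
        print("rejected:", e)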
@@ -387,17 +394,86 @@ class Scheduler:
  result = self.run_batch(batch)
  result_queue.append((batch.copy(), result))

+ if self.last_batch is None:
+ # A dummy first batch to start the pipeline for overlap scheduler.
+ # It is now used for triggering the sampling_info_done event.
+ tmp_batch = ScheduleBatch(
+ reqs=None,
+ forward_mode=ForwardMode.DUMMY_FIRST,
+ next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+ )
+ self.process_batch_result(tmp_batch, None)
+
  if self.last_batch:
  tmp_batch, tmp_result = result_queue.popleft()
+ tmp_batch.next_batch_sampling_info = (
+ self.tp_worker.cur_sampling_info if batch else None
+ )
  self.process_batch_result(tmp_batch, tmp_result)
  elif batch is None:
+ # Self-check and re-init some states when the server is idle
  self.check_memory()
  self.new_token_ratio = self.init_new_token_ratio

  self.last_batch = batch

+ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+ # Check if other DP workers have running batches
+ if local_batch is None:
+ num_tokens = 0
+ elif local_batch.forward_mode.is_decode():
+ num_tokens = local_batch.batch_size()
+ else:
+ num_tokens = local_batch.extend_num_tokens
+
+ local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+ global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+ torch.distributed.all_gather_into_tensor(
+ global_num_tokens,
+ local_num_tokens,
+ group=self.tp_cpu_group,
+ )
+
+ if local_batch is None and global_num_tokens.max().item() > 0:
+ local_batch = self.get_idle_batch()
+
+ if local_batch is not None:
+ local_batch.global_num_tokens = global_num_tokens.tolist()
+
+ # Check forward mode for cuda graph
+ if not self.server_args.disable_cuda_graph:
+ forward_mode_state = torch.tensor(
+ (
+ 1
+ if local_batch.forward_mode.is_decode()
+ or local_batch.forward_mode.is_idle()
+ else 0
+ ),
+ dtype=torch.int32,
+ )
+ torch.distributed.all_reduce(
+ forward_mode_state,
+ op=torch.distributed.ReduceOp.MIN,
+ group=self.tp_cpu_group,
+ )
+ local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+ return local_batch
+
+ def get_idle_batch(self):
+ idle_batch = ScheduleBatch.init_new(
+ [],
+ self.req_to_token_pool,
+ self.token_to_kv_pool,
+ self.tree_cache,
+ self.model_config,
+ self.enable_overlap,
+ )
+ idle_batch.prepare_for_idle()
+ return idle_batch
+
  def recv_requests(self):
- if self.tp_rank == 0:
+ if self.tp_rank == 0 or self.server_args.enable_dp_attention:
  recv_reqs = []

  while True:
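
prepare_dp_attn_batch keeps the data-parallel attention workers in lock-step: every rank learns every rank's token count (so idle ranks can still join the collective with a dummy batch), and a MIN-reduce of a 0/1 mode flag means CUDA graphs are used only when all ranks are in a decode or idle step. A toy single-process run of the same two collectives, using the list-based all_gather (the scheduler uses the fused all_gather_into_tensor on its CPU group):

    import torch
    import torch.distributed as dist

    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )

    # 1) Gather per-rank token counts.
    local_num_tokens = torch.tensor([7], dtype=torch.int64)
    gathered = [torch.empty(1, dtype=torch.int64) for _ in range(1)]
    dist.all_gather(gathered, local_num_tokens)
    print([t.item() for t in gathered])  # [7]

    # 2) MIN-reduce a 0/1 flag: the result is 1 only if every rank set 1.
    state = torch.tensor(1, dtype=torch.int32)
    dist.all_reduce(state, op=dist.ReduceOp.MIN)
    print(state.item())  # 1 -> all ranks may take the cuda-graph path

    dist.destroy_process_group()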
@@ -409,7 +485,7 @@ class Scheduler:
  else:
  recv_reqs = None

- if self.tp_size != 1:
+ if self.tp_size != 1 and not self.server_args.enable_dp_attention:
  recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
  return recv_reqs

@@ -433,6 +509,11 @@ class Scheduler:
  self.start_profile()
  else:
  self.stop_profile()
+ elif isinstance(recv_req, OpenSessionReqInput):
+ session_id = self.open_session(recv_req)
+ self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+ elif isinstance(recv_req, CloseSessionReqInput):
+ self.close_session(recv_req)
  elif isinstance(recv_req, GetMemPoolSizeReq):
  self.send_to_tokenizer.send_pyobj(
  GetMemPoolSizeReqOutput(self.max_total_num_tokens)
@@ -444,14 +525,37 @@ class Scheduler:
  self,
  recv_req: TokenizedGenerateReqInput,
  ):
- req = Req(
- recv_req.rid,
- recv_req.input_text,
- recv_req.input_ids,
- recv_req.sampling_params,
- lora_path=recv_req.lora_path,
- )
- req.tokenizer = self.tokenizer
+ if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+ # Create a new request
+ if recv_req.input_embeds is not None:
+ # Generate fake input_ids based on the length of input_embeds
+ seq_length = len(recv_req.input_embeds)
+ fake_input_ids = [1] * seq_length
+ recv_req.input_ids = fake_input_ids
+
+ req = Req(
+ recv_req.rid,
+ recv_req.input_text,
+ recv_req.input_ids,
+ recv_req.sampling_params,
+ lora_path=recv_req.lora_path,
+ input_embeds=recv_req.input_embeds,
+ )
+ req.tokenizer = self.tokenizer
+
+ if recv_req.session_id is not None:
+ req.finished_reason = FINISH_ABORT(
+ f"Invalid request: session id {recv_req.session_id} does not exist"
+ )
+ self.waiting_queue.append(req)
+ return
+ else:
+ # Create a new request from a previous session
+ session = self.sessions[recv_req.session_id]
+ req = session.create_req(recv_req, self.tokenizer)
+ if isinstance(req.finished_reason, FINISH_ABORT):
+ self.waiting_queue.append(req)
+ return

  # Image inputs
  if recv_req.image_inputs is not None:
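
The rewritten handle_generate_request has three entry paths: a brand-new request (with placeholder token ids fabricated when only input_embeds is given), an abort when a session_id is supplied but unknown, and continuation of a stored session. A toy dispatcher mirroring just that branching (the names here are stand-ins, not sglang's classes):

    sessions: dict[str, object] = {}

    def dispatch(session_id, input_ids, input_embeds=None):
        if session_id is None or session_id not in sessions:
            if input_embeds is not None:
                # Fake ids of matching length keep length bookkeeping working.
                input_ids = [1] * len(input_embeds)
            if session_id is not None:
                return f"abort: session id {session_id} does not exist"
            return f"new request, input_ids={input_ids}"
        return f"continue session {session_id}"

    print(dispatch(None, [5, 6]))            # new request, input_ids=[5, 6]
    print(dispatch(None, None, [[0.0]] * 3)) # new request, input_ids=[1, 1, 1]
    print(dispatch("s1", [5, 6]))            # abort: session id s1 does not exist
    sessions["s1"] = object()
    print(dispatch("s1", [5, 6]))            # continue session s1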
@@ -462,6 +566,15 @@ class Scheduler:
  req.origin_input_ids_unpadded, req.image_inputs
  )

+ if len(req.origin_input_ids) > self.max_req_input_len:
+ req.finished_reason = FINISH_ABORT(
+ "Image request length is longer than the KV cache pool size or "
+ "the max context length, aborting because you cannot truncate the image embeds"
+ )
+ req.sampling_params.max_new_tokens = 0
+ self.waiting_queue.append(req)
+ return
+
  req.return_logprob = recv_req.return_logprob
  req.top_logprobs_num = recv_req.top_logprobs_num
  req.stream = recv_req.stream
@@ -599,58 +712,50 @@ class Scheduler:
  self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
  )
  if available_size != self.max_total_num_tokens:
- warnings.warn(
- "Warning: "
- f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+ msg = (
  "KV cache pool leak detected!"
+ f"{available_size=}, {self.max_total_num_tokens=}\n"
  )
- exit(1) if crash_on_warning else None
+ warnings.warn(msg)
+ if crash_on_warnings():
+ raise ValueError(msg)

  if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
- warnings.warn(
- "Warning: "
- f"available req slots={len(self.req_to_token_pool.free_slots)}, "
- f"total slots={self.req_to_token_pool.size}\n"
+ msg = (
  "Memory pool leak detected!"
+ f"available_size={len(self.req_to_token_pool.free_slots)}, "
+ f"total_size={self.req_to_token_pool.size}\n"
  )
- exit(1) if crash_on_warning else None
+ warnings.warn(msg)
+ if crash_on_warnings():
+ raise ValueError(msg)

  def get_next_batch_to_run(self):
  # Merge the prefill batch into the running batch
- if (
- self.last_batch
- and not self.last_batch.forward_mode.is_decode()
- and not self.last_batch.is_empty()
- ):
+ if self.last_batch and self.last_batch.forward_mode.is_extend():
  if self.being_chunked_req:
+ # Move the chunked request out of the batch
  self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
  self.tree_cache.cache_unfinished_req(self.being_chunked_req)
- # Inflight request keeps its rid but will get a new req_pool_idx.
+ # Inflight request keeps its rid but will get a new req_pool_idx
  self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
  self.batch_is_full = False
+
  if not self.last_batch.is_empty():
  if self.running_batch is None:
  self.running_batch = self.last_batch
  else:
  self.running_batch.merge_batch(self.last_batch)

- # Prefill first
+ # Run prefill first if possible
  new_batch = self.get_new_batch_prefill()
  if new_batch is not None:
  return new_batch

- # Check memory
- if self.running_batch is None:
- return
-
  # Run decode
- before_bs = self.running_batch.batch_size()
- self.update_running_batch()
- if not self.running_batch:
- self.batch_is_full = False
+ if self.running_batch is None:
  return None
- if before_bs != self.running_batch.batch_size():
- self.batch_is_full = False
+ self.running_batch = self.update_running_batch(self.running_batch)
  return self.running_batch

  def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
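
The reworked checks encode a simple invariant: when the scheduler is idle, every KV slot must be either free in the pool or held by the radix cache as evictable, and every request slot must be free. In toy numbers:

    # No leak: free pool slots plus evictable cache slots cover the whole pool.
    max_total_num_tokens = 1000
    available_size = 940   # token_to_kv_pool.available_size()
    evictable_size = 60    # tree_cache.evictable_size()
    assert available_size + evictable_size == max_total_num_tokens

    # A shortfall here now warns, and raises ValueError on CI,
    # where crash_on_warnings() is true.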
@@ -746,14 +851,20 @@ class Scheduler:
  self.token_to_kv_pool,
  self.tree_cache,
  self.model_config,
+ self.enable_overlap,
  )
  new_batch.prepare_for_extend()

  # Mixed-style chunked prefill
- if self.is_mixed_chunk and self.running_batch is not None:
+ if (
+ self.is_mixed_chunk
+ and self.running_batch is not None
+ and not (new_batch.return_logprob or self.running_batch.return_logprob)
+ ):
+ # TODO (lianmin): support return_logprob + mixed chunked prefill
  self.running_batch.filter_batch()
  if not self.running_batch.is_empty():
- self.running_batch.prepare_for_decode(self.enable_overlap)
+ self.running_batch.prepare_for_decode()
  new_batch.mix_with_running(self.running_batch)
  new_batch.decoding_reqs = self.running_batch.reqs
  self.running_batch = None
@@ -762,15 +873,16 @@ class Scheduler:

  return new_batch

- def update_running_batch(self):
+ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
  """Update the current running decoding batch."""
  global test_retract
- batch = self.running_batch
+
+ initial_bs = batch.batch_size()

  batch.filter_batch()
  if batch.is_empty():
- self.running_batch = None
- return
+ self.batch_is_full = False
+ return None

  # Check if decode out of memory
  if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
@@ -796,11 +908,15 @@ class Scheduler:
  jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
  self.waiting_queue.extend(jump_forward_reqs)
  if batch.is_empty():
- self.running_batch = None
- return
+ self.batch_is_full = False
+ return None
+
+ if batch.batch_size() < initial_bs:
+ self.batch_is_full = False

  # Update batch tensors
- batch.prepare_for_decode(self.enable_overlap)
+ batch.prepare_for_decode()
+ return batch

  def run_batch(self, batch: ScheduleBatch):
  """Run a batch."""
@@ -812,6 +928,10 @@ class Scheduler:
  logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
  model_worker_batch
  )
+ elif batch.forward_mode.is_idle():
+ model_worker_batch = batch.get_model_worker_batch()
+ self.tp_worker.forward_batch_idle(model_worker_batch)
+ return
  else:
  logits_output = None
  if self.skip_tokenizer_init:
@@ -834,8 +954,12 @@ class Scheduler:
  self.process_batch_result_decode(batch, result)
  if batch.is_empty():
  self.running_batch = None
- else:
+ elif batch.forward_mode.is_extend():
  self.process_batch_result_prefill(batch, result)
+ elif batch.forward_mode.is_dummy_first():
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
+ torch.cuda.current_stream().synchronize()
+ batch.next_batch_sampling_info.sampling_info_done.set()

  def process_batch_result_prefill(self, batch: ScheduleBatch, result):

@@ -843,7 +967,7 @@ class Scheduler:
  logits_output, next_token_ids, bid = result

  if self.enable_overlap:
- logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+ logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
  else:
  # Move next_token_ids and logprobs to cpu
  if batch.return_logprob:
@@ -863,14 +987,19 @@ class Scheduler:

  # Check finish conditions
  logprob_pt = 0
- for i, req in enumerate(batch.reqs):
+ for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
  if req.is_retracted:
  continue

+ if self.is_mixed_chunk and self.enable_overlap and req.finished():
+ # Free the one delayed token for the mixed decode batch
+ j = len(batch.out_cache_loc) - len(batch.reqs) + i
+ self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+ continue
+
  if req.is_being_chunked <= 0:
- # Inflight reqs' prefill is not finished
  req.completion_tokens_wo_jump_forward += 1
- req.output_ids.append(next_token_ids[i])
+ req.output_ids.append(next_token_id)
  req.check_finished()

  if req.finished():
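
The new early-continue frees the "one delayed token" that the overlap scheduler pre-allocates for a finished request in a mixed batch. Assuming out_cache_loc lays out all prefill slots first and then one decode slot per request (which is what the index arithmetic implies), the slot for request i sits at len(out_cache_loc) - len(reqs) + i:

    # 10 prefill slots followed by one decode slot for each of 3 requests.
    out_cache_loc = list(range(10)) + [100, 101, 102]
    reqs = ["r0", "r1", "r2"]
    for i in range(len(reqs)):
        j = len(out_cache_loc) - len(reqs) + i
        print(reqs[i], "->", out_cache_loc[j])  # r0 -> 100, r1 -> 101, r2 -> 102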
@@ -878,16 +1007,22 @@ class Scheduler:
  elif not batch.decoding_reqs or req not in batch.decoding_reqs:
  self.tree_cache.cache_unfinished_req(req)

- if req.grammar is not None:
- req.grammar.accept_token(next_token_ids[i])
-
  if req.return_logprob:
  logprob_pt += self.add_logprob_return_values(
  i, req, logprob_pt, next_token_ids, logits_output
  )
+
+ if req.grammar is not None:
+ req.grammar.accept_token(next_token_id)
  else:
+ # Inflight reqs' prefill is not finished
  req.is_being_chunked -= 1

+ if batch.next_batch_sampling_info:
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
+ torch.cuda.current_stream().synchronize()
+ batch.next_batch_sampling_info.sampling_info_done.set()
+
  else: # embedding or reward model
  embeddings, bid = result
  embeddings = embeddings.tolist()
@@ -898,18 +1033,18 @@ class Scheduler:
  continue

  req.embedding = embeddings[i]
- if req.is_being_chunked > 0:
- req.is_being_chunked -= 1
- else:
- # Inflight reqs' prefill is not finished
- # dummy output token for embedding models
+ if req.is_being_chunked <= 0:
+ # Dummy output token for embedding models
  req.output_ids.append(0)
  req.check_finished()

- if req.finished():
- self.tree_cache.cache_finished_req(req)
+ if req.finished():
+ self.tree_cache.cache_finished_req(req)
+ else:
+ self.tree_cache.cache_unfinished_req(req)
  else:
- self.tree_cache.cache_unfinished_req(req)
+ # Inflight reqs' prefill is not finished
+ req.is_being_chunked -= 1

  self.stream_output(batch.reqs)

@@ -918,7 +1053,7 @@ class Scheduler:
  self.num_generated_tokens += len(batch.reqs)

  if self.enable_overlap:
- logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+ logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
  next_token_logprobs = logits_output.next_token_logprobs
  else:
  # Move next_token_ids and logprobs to cpu
@@ -936,7 +1071,8 @@ class Scheduler:
  if req.is_retracted:
  continue

- if self.server_args.enable_overlap_schedule and (req.finished()):
+ if self.enable_overlap and req.finished():
+ # Free the one delayed token
  self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
  continue

@@ -944,9 +1080,6 @@ class Scheduler:
  req.output_ids.append(next_token_id)
  req.check_finished()

- if req.grammar is not None:
- req.grammar.accept_token(next_token_id)
-
  if req.finished():
  self.tree_cache.cache_finished_req(req)

@@ -957,6 +1090,14 @@ class Scheduler:
  if req.top_logprobs_num > 0:
  req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

+ if req.grammar is not None:
+ req.grammar.accept_token(next_token_id)
+
+ if batch.next_batch_sampling_info:
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
+ torch.cuda.current_stream().synchronize()
+ batch.next_batch_sampling_info.sampling_info_done.set()
+
  self.stream_output(batch.reqs)

  self.token_to_kv_pool.free_group_end()
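
The same three-line handoff now appears after dummy-first, prefill, and decode processing: build the vocab mask for the next batch, synchronize the CUDA stream so the mask is materialized, then set sampling_info_done so the overlap worker may sample. Stripped of the CUDA details, it is a plain event handoff between two threads:

    import threading

    sampling_info_done = threading.Event()

    def overlap_worker():
        sampling_info_done.wait()  # block until the scheduler signals readiness
        print("sampling with up-to-date vocab mask")

    t = threading.Thread(target=overlap_worker)
    t.start()

    print("building vocab mask")   # update_regex_vocab_mask()
    # torch.cuda.current_stream().synchronize() would go here
    sampling_info_done.set()       # unblock the worker
    t.join()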
@@ -1234,6 +1375,25 @@ class Scheduler:
  )
  logger.info("Profiler is done")

+ def open_session(self, recv_req: OpenSessionReqInput) -> str:
+ # handle error
+ session_id = recv_req.session_id
+ if session_id in self.sessions:
+ logger.warning(f"session id {session_id} already exists, cannot open.")
+ else:
+ self.sessions[session_id] = Session(
+ recv_req.capacity_of_str_len, session_id
+ )
+ return session_id
+
+ def close_session(self, recv_req: CloseSessionReqInput):
+ # handle error
+ session_id = recv_req.session_id
+ if session_id not in self.sessions:
+ logger.warning(f"session id {session_id} does not exist, cannot delete.")
+ else:
+ del self.sessions[session_id]
+

  def run_scheduler_process(
  server_args: ServerArgs,
@@ -1243,6 +1403,13 @@ def run_scheduler_process(
  dp_rank: Optional[int],
  pipe_writer,
  ):
+ # set cpu affinity to this gpu process
+ gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
+ # [For Router] if env var "DP_RANK" exists, set dp_rank to its value
+ if dp_rank is None and "DP_RANK" in os.environ:
+ dp_rank = int(os.environ["DP_RANK"])
+
  if dp_rank is None:
  configure_logger(server_args, prefix=f" TP{tp_rank}")
  else:
@@ -1252,8 +1419,10 @@ def run_scheduler_process(

  try:
  scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
- pipe_writer.send("ready")
- if server_args.enable_overlap_schedule:
+ pipe_writer.send(
+ {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+ )
+ if scheduler.enable_overlap:
  scheduler.event_loop_overlap()
  else:
  scheduler.event_loop_normal()
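
The startup handshake also changes shape: instead of the bare string "ready", the scheduler reports a dict carrying max_total_num_tokens, so the parent process learns the KV pool size without an extra round trip. A sketch of both ends over a multiprocessing.Pipe (the launcher side here is assumed, not taken from this diff):

    import multiprocessing as mp

    reader, writer = mp.Pipe(duplex=False)

    # Scheduler side, as in the diff:
    writer.send({"status": "ready", "max_total_num_tokens": 123456})

    # Launcher side (assumed): it previously compared against the literal "ready".
    msg = reader.recv()
    assert msg["status"] == "ready"
    print(msg["max_total_num_tokens"])  # 123456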