sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_policy.py
@@ -174,7 +174,7 @@ class SchedulePolicy:
         self.waiting_queue_radix_tree.reset()

         for r in waiting_queue:
-            prefix_ids = r.adjust_max_prefix_ids()
+            prefix_ids = r.origin_input_ids + r.output_ids
             extra_key = r.extra_key

             # NOTE: the prefix_indices must always be aligned with last_node
sglang/srt/managers/scheduler.py
@@ -25,12 +25,14 @@ from concurrent import futures
 from dataclasses import dataclass
 from http import HTTPStatus
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Deque, Dict, List, Optional, Tuple, Union

 import psutil
 import setproctitle
 import torch
 import zmq
+from torch.cuda import Stream as CudaStream
+from torch.cuda import StreamContext as CudaStreamContext
 from torch.distributed import barrier

 from sglang.global_config import global_config
@@ -112,8 +114,10 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
+from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
+    ModelWorkerBatch,
     MultimodalInputs,
     Req,
     RequestStage,
@@ -139,15 +143,13 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
 from sglang.srt.managers.session_controller import Session
-from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatchOutput,
+    ForwardBatch,
     ForwardMode,
     PPProxyTensors,
 )
@@ -201,40 +203,48 @@ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

 @dataclass
 class GenerationBatchResult:
-    logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
-    next_token_ids: Optional[List[int]]
-    can_run_cuda_graph: bool
+    logits_output: Optional[LogitsProcessorOutput] = None
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    can_run_cuda_graph: bool = False

     # For output processing
-    extend_input_len_per_req: List[int]
-    extend_logprob_start_len_per_req: List[int]
-
-    @classmethod
-    def from_forward_batch_output(
-        cls,
-        forward_batch_output: ForwardBatchOutput,
-        extend_input_len_per_req: List[int],
-        extend_logprob_start_len_per_req: List[int],
-    ):
-        # TODO(lsyin): remove this workaround logic and try to unify output classes
-
-        return cls(
-            logits_output=forward_batch_output.logits_output,
-            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
-            next_token_ids=forward_batch_output.next_token_ids,
-            extend_input_len_per_req=extend_input_len_per_req,
-            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
-        )
+    extend_input_len_per_req: Optional[List[int]] = None
+    extend_logprob_start_len_per_req: Optional[List[int]] = None
+
+    # For overlap scheduling
+    copy_done: Optional[torch.cuda.Event] = None
+    delay_sample_launch: bool = False
+    forward_batch: Optional[ForwardBatch] = None
+    future_indices: Optional[FutureIndices] = None
+
+    def copy_to_cpu(self, return_logprob: bool = False):
+        """Copy tensors to CPU in overlap scheduling.
+        Only the tensors which are needed for processing results are copied,
+        e.g., next_token_ids, logits outputs
+        """
+        if return_logprob:
+            if self.logits_output.next_token_logits is not None:
+                self.logits_output.next_token_logits = (
+                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
+                )
+            if self.logits_output.input_token_logprobs is not None:
+                self.logits_output.input_token_logprobs = (
+                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                )
+        if self.logits_output.hidden_states is not None:
+            self.logits_output.hidden_states = self.logits_output.hidden_states.to(
+                "cpu", non_blocking=True
+            )
+        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
+        self.copy_done.record()

     @classmethod
     def from_pp_proxy(
         cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
     ):
-        # TODO(lsyin): also simplify this logic
-        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
-        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        # TODO(lsyin): refactor PP and avoid using dict
         proxy_dict = next_pp_outputs.tensors
         return cls(
             logits_output=logits_output,
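The new copy_to_cpu above uses a standard CUDA pattern: enqueue non-blocking device-to-host copies, record an event on the same stream, and have the consumer synchronize on that event before touching the host tensors. The sketch below illustrates only that pattern; async_copy_to_cpu and its arguments are illustrative names, not part of sglang.

```python
import torch

def async_copy_to_cpu(gpu_tensor: torch.Tensor, copy_done: torch.cuda.Event) -> torch.Tensor:
    # Enqueue the device-to-host copy on the current stream without blocking
    # the Python thread (it is truly asynchronous only with pinned host memory).
    cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
    # Mark the point after the copy; consumers wait on this event.
    copy_done.record()
    return cpu_tensor

if torch.cuda.is_available():
    next_token_ids = torch.randint(0, 32000, (8,), device="cuda")
    done = torch.cuda.Event()
    host_ids = async_copy_to_cpu(next_token_ids, done)
    done.synchronize()  # host_ids is only safe to read after this
    print(host_ids.tolist())
```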
@@ -263,6 +273,48 @@ class Scheduler(
 ):
     """A scheduler that manages a tensor parallel GPU worker."""

+    def launch_draft_worker(
+        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+    ):
+        if self.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+            self.draft_worker = EAGLEWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+            self.draft_worker = NGRAMWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        else:
+            self.draft_worker = None
+
     def __init__(
         self,
         server_args: ServerArgs,
@@ -388,12 +440,10 @@ class Scheduler(
             logger.info("Overlap scheduler is disabled for embedding models.")

         # Launch a tensor parallel worker
-        if self.enable_overlap:
-            TpWorkerClass = TpModelWorkerClient
-        else:
-            TpWorkerClass = TpModelWorker

-        self.tp_worker = TpWorkerClass(
+        from sglang.srt.managers.tp_worker import TpModelWorker
+
+        self.tp_worker = TpModelWorker(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -404,44 +454,9 @@ class Scheduler(
         )

         # Launch a draft worker for speculative decoding
-        if self.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_worker import EAGLEWorker
-
-            self.draft_worker = EAGLEWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_standalone():
-            from sglang.srt.speculative.standalone_worker import StandaloneWorker
-
-            self.draft_worker = StandaloneWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        elif self.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_worker import NGRAMWorker
-
-            self.draft_worker = NGRAMWorker(
-                gpu_id=gpu_id,
-                tp_rank=tp_rank,
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
-        else:
-            self.draft_worker = None
+        self.launch_draft_worker(
+            gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+        )

         # Dispatch the model worker
         if self.spec_algorithm.is_none():
@@ -464,8 +479,8 @@ class Scheduler(
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-        if global_server_args_dict["max_micro_batch_size"] is None:
-            global_server_args_dict["max_micro_batch_size"] = max(
+        if global_server_args_dict["pp_max_micro_batch_size"] is None:
+            global_server_args_dict["pp_max_micro_batch_size"] = max(
                 self.max_running_requests // server_args.pp_size, 1
             )

@@ -525,9 +540,11 @@ class Scheduler(
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0
         self.sessions: Dict[str, Session] = {}
-        self.current_stream = torch.get_device_module(self.device).current_stream()
+        self.default_stream: CudaStream = torch.get_device_module(
+            self.device
+        ).current_stream()
         if self.device == "cpu":
-            self.current_stream.synchronize = lambda: None  # No-op for CPU
+            self.default_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None

         # Init chunked prefill
@@ -618,6 +635,9 @@ class Scheduler(
         # Init prefill kv split size when deterministic inference is enabled with various attention backends
         self.init_deterministic_inference_config()

+        # Init overlap
+        self.init_overlap()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -777,6 +797,7 @@ class Scheduler(
                 sliding_window_size=self.sliding_window_size,
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
+                is_eagle=self.spec_algorithm.is_eagle(),
             )
         elif server_args.enable_lmcache:
             from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
@@ -931,6 +952,32 @@ class Scheduler(
         # The prefill requests that are in the middle of kv sending
         self.disagg_prefill_inflight_queue: List[Req] = []

+    def init_overlap(self):
+        if not self.enable_overlap:
+            return
+
+        self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.forward_stream)
+        self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.copy_stream)
+
+        self.future_map = FutureMap(self.max_running_requests, self.device)
+        self.batch_record_buf = [None] * 2
+        self.batch_record_ct = 0
+
+    def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
+        # FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
+        # NOTE: More Reliable: record all tensors into the forward stream
+        # NOTE: - for all future tensors, we shall always read from future map
+        #       - for all non-future tensors (produced only by schedule stream),
+        #         we shall keep its reference not being release during all the forwarding pass
+        self.batch_record_ct = (self.batch_record_ct + 1) % 2
+        self.batch_record_buf[self.batch_record_ct] = model_worker_batch
+
     def init_moe_config(self):
         if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
             initialize_moe_config(self.server_args)
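The FutureMap used above (imported from sglang.srt.managers.overlap_utils in an earlier hunk) lets the scheduler build the next batch before the previous batch's tokens exist: it hands out slot indices, the schedule batch stores their negation as placeholder token ids, and the forward stream later fills the slots and resolves the placeholders. The toy class below only illustrates that idea under simplifying assumptions (no ring buffer, no slot reuse); it is not the real implementation.

```python
import torch

class ToyFutureMap:
    """Placeholder-token bookkeeping, greatly simplified."""

    def __init__(self, max_slots: int, device: str = "cpu"):
        # Slot 0 is unused so that -index is always a valid non-zero placeholder.
        self.token_buf = torch.zeros(max_slots + 1, dtype=torch.int64, device=device)
        self.next_slot = 1

    def alloc_future_indices(self, bs: int) -> torch.Tensor:
        indices = torch.arange(
            self.next_slot, self.next_slot + bs, device=self.token_buf.device
        )
        self.next_slot += bs
        return indices

    def store_to_map(self, indices: torch.Tensor, next_token_ids: torch.Tensor):
        # Called on the forward stream once sampling has produced real token ids.
        self.token_buf[indices] = next_token_ids

    def resolve(self, maybe_future_ids: torch.Tensor) -> torch.Tensor:
        # Negative entries are placeholders (-slot); swap them for the real ids.
        out = maybe_future_ids.clone()
        mask = out < 0
        out[mask] = self.token_buf[-out[mask]]
        return out

fm = ToyFutureMap(max_slots=16)
idx = fm.alloc_future_indices(4)          # e.g. tensor([1, 2, 3, 4])
placeholders = -idx                        # what the schedule batch carries
fm.store_to_map(idx, torch.tensor([11, 22, 33, 44]))
print(fm.resolve(placeholders).tolist())   # [11, 22, 33, 44]
```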
@@ -957,9 +1004,11 @@ class Scheduler(
     @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        self.result_queue = deque()
+        self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

@@ -967,30 +1016,13 @@
             self.cur_batch = batch

             if batch:
-                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))

-                if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap schedule.
-                    # It is now used for triggering the sampling_info_done event.
-                    tmp_batch = ScheduleBatch(
-                        reqs=None,
-                        forward_mode=ForwardMode.DUMMY_FIRST,
-                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                    )
-                    self.process_batch_result(tmp_batch, None, batch.launch_done)
-
             if self.last_batch:
                 # Process the results of the last batch
                 tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
-                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
-                self.process_batch_result(
-                    tmp_batch, tmp_result, batch.launch_done if batch else None
-                )
+                self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.self_check_during_idle()
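The reworked loop keeps a one-entry pipeline: each iteration launches the next batch and only then processes the result of the previous one, so CPU-side output processing overlaps with GPU execution. A schematic of that control flow, with run_batch_async and process_result as stand-in callables rather than scheduler APIs:

```python
from collections import deque

def overlap_loop(batches, run_batch_async, process_result):
    # run_batch_async(batch) must only enqueue GPU work and return a handle;
    # process_result(batch, handle) does the CPU-side bookkeeping.
    result_queue = deque()
    last_batch = None
    for batch in batches:
        result_queue.append((batch, run_batch_async(batch)))  # launch batch i+1
        if last_batch is not None:
            prev_batch, prev_handle = result_queue.popleft()
            process_result(prev_batch, prev_handle)           # overlaps with GPU work
        last_batch = batch
    while result_queue:                                        # drain the tail
        prev_batch, prev_handle = result_queue.popleft()
        process_result(prev_batch, prev_handle)
```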
@@ -1745,7 +1777,7 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+            if self.tp_worker.worker.model_runner.mambaish_config is not None:
                 self.req_to_token_pool.free(
                     self.chunked_req.req_pool_idx, free_mamba_cache=False
                 )
@@ -1802,7 +1834,7 @@ class Scheduler(
         return ret

     def get_num_allocatable_reqs(self, running_bs):
-        res = global_server_args_dict["max_micro_batch_size"] - running_bs
+        res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
         if self.pp_size > 1:
             res = min(res, self.req_to_token_pool.available_size())
         return res
@@ -2055,18 +2087,59 @@ class Scheduler(
             # FIXME(lsyin): remove this if and finally unify the abstraction
             batch_or_worker_batch = batch.get_model_worker_batch()

-            forward_batch_output = self.model_worker.forward_batch_generation(
-                batch_or_worker_batch
-            )
+            if self.enable_overlap:
+                # FIXME: remove this assert
+                assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
+                model_worker_batch = batch_or_worker_batch
+                self.record_batch_in_overlap(model_worker_batch)
+
+                # Sampling info will be modified during forward
+                model_worker_batch.sampling_info = (
+                    model_worker_batch.sampling_info.copy_for_forward()
+                )
+
+                bs = len(model_worker_batch.seq_lens)
+                future_indices = self.future_map.alloc_future_indices(bs)
+
+                with self.forward_stream_ctx:
+                    self.forward_stream.wait_stream(self.default_stream)
+                    self.future_map.resolve_future(model_worker_batch)
+                    if batch.sampling_info.grammars is not None:
+                        model_worker_batch.delay_sample_launch = True
+                    batch_result = self.model_worker.forward_batch_generation(
+                        batch_or_worker_batch
+                    )
+                    # FIXME(lsyin): maybe move this to forward_batch_generation
+                    batch_result.copy_done = torch.get_device_module(
+                        self.device
+                    ).Event()
+                    if not model_worker_batch.delay_sample_launch:
+                        self.future_map.store_to_map(
+                            future_indices, batch_result.next_token_ids
+                        )
+                        batch_result.copy_to_cpu()
+                    else:
+                        batch_result.future_indices = future_indices
+
+                # FIXME(lsyin): move this assignment elsewhere
+                maybe_future_next_token_ids = -future_indices.indices
+            else:
+                batch_result = self.model_worker.forward_batch_generation(
+                    batch_or_worker_batch
+                )
+                maybe_future_next_token_ids = batch_result.next_token_ids

             if not self.spec_algorithm.is_none():
                 # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
-                self.udpate_spec_metrics(
-                    batch.batch_size(), forward_batch_output.num_accepted_tokens
+                self.update_spec_metrics(
+                    batch.batch_size(), batch_result.num_accepted_tokens
                 )

-            # update batch's output ids
-            batch.output_ids = forward_batch_output.next_token_ids
+            # NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
+            # which can probably be replaced by future_indices later [TODO(lsyin)].
+            # we shall still keep the original outputs, e.g. next_token_ids
+            # in the GenerationBatchOutput for processing after copy_done.
+            batch.output_ids = maybe_future_next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
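In the overlap branch above, the forward pass is enqueued on a dedicated forward stream that first waits on the default (schedule) stream, so GPU ordering is preserved while the Python thread returns immediately. A minimal version of that launch pattern, assuming fn only enqueues CUDA work:

```python
import torch

def launch_on_side_stream(forward_stream: torch.cuda.Stream, fn, *args):
    # Order the side stream after everything already queued on the current
    # (default) stream, run fn's kernels there, and return an event that
    # marks their completion; the caller can later synchronize on it.
    forward_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(forward_stream):
        out = fn(*args)
        done = torch.cuda.Event()
        done.record()  # recorded on forward_stream
    return out, done
```

The scheduler's version additionally resolves future token ids from the FutureMap and, when grammar constraints are present, defers sampling before recording copy_done.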
@@ -2083,39 +2156,60 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-            return GenerationBatchResult.from_forward_batch_output(
-                forward_batch_output=forward_batch_output,
-                extend_input_len_per_req=extend_input_len_per_req,
-                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            batch_result.extend_input_len_per_req = extend_input_len_per_req
+            batch_result.extend_logprob_start_len_per_req = (
+                extend_logprob_start_len_per_req
             )
+            return batch_result
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret

+    def launch_last_batch_sample_if_needed(
+        self,
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
+        if len(self.result_queue) == 0:
+            return
+
+        tmp_batch, tmp_result = self.result_queue.popleft()
+
+        tmp_result: GenerationBatchResult
+        if not tmp_result.delay_sample_launch:
+            self.result_queue.appendleft((tmp_batch, tmp_result))
+            return
+
+        with self.forward_stream_ctx:
+            self.forward_stream.wait_stream(self.default_stream)
+            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
+                tmp_result.logits_output,
+                tmp_result.forward_batch,
+            )
+            future_indices = tmp_result.future_indices
+            self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
+            tmp_result.copy_to_cpu()
+        self.result_queue.appendleft((tmp_batch, tmp_result))
+
     def process_batch_result(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result, launch_done)
+            self.process_batch_result_decode(batch, result)
             if self.enable_trace:
                 trace_slice_batch("decode loop", batch.reqs)

         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result, launch_done)
+            self.process_batch_result_prefill(batch, result)
             if self.enable_trace:
                 trace_slice_batch("prefill", batch.reqs)

         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_last_batch_result(launch_done)
-                self.set_next_batch_sampling_info_done(batch)
-            elif batch.forward_mode.is_dummy_first():
-                self.set_next_batch_sampling_info_done(batch)
+                if result.copy_done is not None:
+                    result.copy_done.synchronize()

         self.maybe_send_health_check_signal()

@@ -2325,13 +2419,6 @@ class Scheduler(
                 self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

-    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
-        if batch.next_batch_sampling_info:
-            if batch.next_batch_sampling_info.grammars is not None:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
@@ -2510,7 +2597,7 @@ class Scheduler(
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
-                "max_micro_batch_size",
+                "pp_max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -2521,7 +2608,7 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
-            elif k == "max_micro_batch_size" and (
+            elif k == "pp_max_micro_batch_size" and (
                 v > self.max_running_requests // self.pp_size or v < 1
             ):
                 logging.warning(
sglang/srt/managers/scheduler_metrics_mixin.py
@@ -69,7 +69,7 @@ class SchedulerMetricsMixin:
             kv_events_config, self.attn_dp_rank
         )

-    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+    def update_spec_metrics(self, bs: int, num_accepted_tokens: int):
         self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
         self.spec_num_total_forward_ct += bs
         self.num_generated_tokens += num_accepted_tokens