sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -15,14 +15,12 @@
15
15
  """Qwen3Hybrid model configuration"""
16
16
 
17
17
  import enum
18
- import os
19
18
 
20
- import numpy as np
21
- import torch
22
19
  from transformers.configuration_utils import PretrainedConfig
23
20
  from transformers.modeling_rope_utils import rope_config_validation
24
21
  from transformers.utils import logging
25
22
 
23
+ from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
26
24
  from sglang.srt.distributed.utils import divide
27
25
  from sglang.srt.layers.dp_attention import get_attention_tp_size
28
26
 
@@ -282,45 +280,15 @@ class Qwen3NextConfig(PretrainedConfig):
282
280
  ]
283
281
 
284
282
  @property
285
- def hybrid_gdn_params(self):
286
- world_size = get_attention_tp_size()
287
- conv_dim = (
288
- self.linear_key_head_dim * self.linear_num_key_heads * 2
289
- + self.linear_value_head_dim * self.linear_num_value_heads
283
+ def mamba2_cache_params(self) -> Mamba2CacheParams:
284
+ shape = Mamba2StateShape.create(
285
+ tp_world_size=get_attention_tp_size(),
286
+ intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
287
+ n_groups=self.linear_num_key_heads,
288
+ num_heads=self.linear_num_value_heads,
289
+ head_dim=self.linear_value_head_dim,
290
+ state_size=self.linear_key_head_dim,
291
+ conv_kernel=self.linear_conv_kernel_dim,
290
292
  )
291
- conv_state_shape = (
292
- divide(conv_dim, world_size),
293
- self.linear_conv_kernel_dim - 1,
294
- )
295
-
296
- temporal_state_shape = (
297
- divide(self.linear_num_value_heads, world_size),
298
- self.linear_key_head_dim,
299
- self.linear_value_head_dim,
300
- )
301
- conv_dtype = torch.bfloat16
302
- dtype_map = {
303
- "float32": torch.float32,
304
- "bfloat16": torch.bfloat16,
305
- }
306
- ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
307
- mamba_layers = self.linear_layer_ids
308
- return (
309
- conv_state_shape,
310
- temporal_state_shape,
311
- conv_dtype,
312
- ssm_dtype,
313
- mamba_layers,
314
- )
315
-
316
- @property
317
- def mamba_cache_per_req(self):
318
- conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
319
- self.hybrid_gdn_params
320
- )
321
- mamba_layers_len = len(mamba_layers)
322
293
 
323
- return (
324
- int(np.prod(conv_state_shape)) * conv_dtype.itemsize
325
- + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
326
- ) * mamba_layers_len
294
+ return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
@@ -747,11 +747,13 @@ class SchedulerDisaggregationDecodeMixin:
747
747
 
748
748
  @torch.no_grad()
749
749
  def event_loop_overlap_disagg_decode(self: Scheduler):
750
- result_queue = deque()
750
+ self.result_queue = deque()
751
751
  self.last_batch: Optional[ScheduleBatch] = None
752
752
  self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
753
753
 
754
754
  while True:
755
+ self.launch_last_batch_sample_if_needed()
756
+
755
757
  recv_reqs = self.recv_requests()
756
758
  self.process_input_requests(recv_reqs)
757
759
  # polling and allocating kv cache
@@ -774,23 +776,13 @@ class SchedulerDisaggregationDecodeMixin:
774
776
  None, delay_process=True
775
777
  )
776
778
  if batch_:
777
- result_queue.append((batch_.copy(), result))
779
+ self.result_queue.append((batch_.copy(), result))
778
780
  last_batch_in_queue = True
779
781
  else:
780
782
  if prepare_mlp_sync_flag:
781
783
  self.prepare_mlp_sync_batch(batch)
782
784
  result = self.run_batch(batch)
783
- result_queue.append((batch.copy(), result))
784
-
785
- if (self.last_batch is None) or (not self.last_batch_in_queue):
786
- # Create a dummy first batch to start the pipeline for overlap schedule.
787
- # It is now used for triggering the sampling_info_done event.
788
- tmp_batch = ScheduleBatch(
789
- reqs=None,
790
- forward_mode=ForwardMode.DUMMY_FIRST,
791
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
792
- )
793
- self.set_next_batch_sampling_info_done(tmp_batch)
785
+ self.result_queue.append((batch.copy(), result))
794
786
  last_batch_in_queue = True
795
787
 
796
788
  elif prepare_mlp_sync_flag:
@@ -798,15 +790,12 @@ class SchedulerDisaggregationDecodeMixin:
798
790
  None, delay_process=True
799
791
  )
800
792
  if batch:
801
- result_queue.append((batch.copy(), result))
793
+ self.result_queue.append((batch.copy(), result))
802
794
  last_batch_in_queue = True
803
795
 
804
796
  # Process the results of the previous batch but skip if the last batch is extend
805
797
  if self.last_batch and self.last_batch_in_queue:
806
- tmp_batch, tmp_result = result_queue.popleft()
807
- tmp_batch.next_batch_sampling_info = (
808
- self.tp_worker.cur_sampling_info if batch else None
809
- )
798
+ tmp_batch, tmp_result = self.result_queue.popleft()
810
799
  self.process_batch_result(tmp_batch, tmp_result)
811
800
 
812
801
  queue_size = (
@@ -4,7 +4,6 @@ import time
4
4
 
5
5
  import torch
6
6
 
7
- from sglang import ServerArgs
8
7
  from sglang.srt.managers.cache_controller import HiCacheController
9
8
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
10
9
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -17,6 +16,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
17
16
  MHATokenToKVPoolHost,
18
17
  MLATokenToKVPoolHost,
19
18
  )
19
+ from sglang.srt.server_args import ServerArgs
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
@@ -319,14 +319,44 @@ class NixlKVManager(CommonKVManager):
319
319
 
320
320
  logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
321
321
  # Make descs
322
- num_layers = len(self.kv_args.kv_data_ptrs)
322
+ if self.is_mla_backend:
323
+ src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
324
+ self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
325
+ )
326
+ kv_item_len = self.kv_args.kv_item_lens[0]
327
+ layers_params = [
328
+ (
329
+ src_kv_ptrs[layer_id],
330
+ dst_kv_ptrs[layer_id],
331
+ kv_item_len,
332
+ )
333
+ for layer_id in range(layers_current_pp_stage)
334
+ ]
335
+ else:
336
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
337
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
338
+ )
339
+
340
+ kv_item_len = self.kv_args.kv_item_lens[0]
341
+ layers_params = [
342
+ (
343
+ src_k_ptrs[layer_id],
344
+ dst_k_ptrs[layer_id],
345
+ kv_item_len,
346
+ )
347
+ for layer_id in range(layers_current_pp_stage)
348
+ ] + [
349
+ (
350
+ src_v_ptrs[layer_id],
351
+ dst_v_ptrs[layer_id],
352
+ kv_item_len,
353
+ )
354
+ for layer_id in range(layers_current_pp_stage)
355
+ ]
356
+
323
357
  src_addrs = []
324
358
  dst_addrs = []
325
- for layer_id in range(num_layers):
326
- src_ptr = self.kv_args.kv_data_ptrs[layer_id]
327
- dst_ptr = dst_kv_ptrs[layer_id]
328
- item_len = self.kv_args.kv_item_lens[layer_id]
329
-
359
+ for src_ptr, dst_ptr, item_len in layers_params:
330
360
  for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
331
361
  src_addr = src_ptr + int(prefill_index[0]) * item_len
332
362
  dst_addr = dst_ptr + int(decode_index[0]) * item_len
@@ -397,6 +427,9 @@ class NixlKVManager(CommonKVManager):
397
427
  num_heads_to_send = dst_heads_per_rank
398
428
  dst_head_start_offset = 0
399
429
 
430
+ src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
431
+ self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
432
+ )
400
433
  # Create transfer descriptors
401
434
  src_addrs = []
402
435
  dst_addrs = []
@@ -404,12 +437,6 @@ class NixlKVManager(CommonKVManager):
404
437
  bytes_per_token_on_prefill = src_kv_item_len // page_size
405
438
  bytes_per_token_on_decode = dst_kv_item_len // page_size
406
439
 
407
- num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
408
- src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
409
- src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
410
- dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
411
- dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
412
-
413
440
  # Calculate precise byte offset and length for the sub-slice within the token
414
441
  src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
415
442
  dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
@@ -420,13 +447,13 @@ class NixlKVManager(CommonKVManager):
420
447
  src_k_ptrs[layer_id],
421
448
  dst_k_ptrs[layer_id],
422
449
  )
423
- for layer_id in range(len(src_k_ptrs))
450
+ for layer_id in range(layers_current_pp_stage)
424
451
  ] + [
425
452
  (
426
453
  src_v_ptrs[layer_id],
427
454
  dst_v_ptrs[layer_id],
428
455
  )
429
- for layer_id in range(len(src_v_ptrs))
456
+ for layer_id in range(layers_current_pp_stage)
430
457
  ]
431
458
 
432
459
  src_addrs = []
@@ -496,14 +523,19 @@ class NixlKVManager(CommonKVManager):
496
523
  dst_aux_index: int,
497
524
  notif: str,
498
525
  ):
499
- # Make descs
500
- aux_item_len = self.kv_args.aux_item_lens[0]
501
- prefill_aux_addr = (
502
- self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
503
- )
504
- decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
505
- src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
506
- dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
526
+ src_addrs = []
527
+ dst_addrs = []
528
+
529
+ prefill_aux_ptrs = self.kv_args.aux_data_ptrs
530
+ prefill_aux_item_lens = self.kv_args.aux_item_lens
531
+
532
+ for i, _ in enumerate(dst_aux_ptrs):
533
+ length = prefill_aux_item_lens[i]
534
+ src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
535
+ dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
536
+ src_addrs.append((src_addr, length, 0))
537
+ dst_addrs.append((dst_addr, length, 0))
538
+
507
539
  src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
508
540
  dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
509
541
  # Transfer data
@@ -576,7 +608,7 @@ class NixlKVManager(CommonKVManager):
576
608
 
577
609
  handles.append(kv_xfer_handle)
578
610
  # Only the last chunk we need to send the aux data.
579
- if is_last:
611
+ if is_last and self.pp_group.is_last_rank:
580
612
  assert aux_index is not None
581
613
  aux_xfer_handle = self.send_aux(
582
614
  req.agent_name,
@@ -321,6 +321,8 @@ class SchedulerDisaggregationPrefillMixin:
321
321
  self.result_queue = deque()
322
322
 
323
323
  while True:
324
+ self.launch_last_batch_sample_if_needed()
325
+
324
326
  recv_reqs = self.recv_requests()
325
327
  self.process_input_requests(recv_reqs)
326
328
  self.waiting_queue.extend(
@@ -336,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin:
336
338
  result = self.run_batch(batch)
337
339
  self.result_queue.append((batch.copy(), result))
338
340
 
339
- if self.last_batch is None:
340
- # Create a dummy first batch to start the pipeline for overlap schedule.
341
- # It is now used for triggering the sampling_info_done event.
342
- tmp_batch = ScheduleBatch(
343
- reqs=None,
344
- forward_mode=ForwardMode.DUMMY_FIRST,
345
- next_batch_sampling_info=self.tp_worker.cur_sampling_info,
346
- )
347
- self.set_next_batch_sampling_info_done(tmp_batch)
348
-
349
341
  if self.last_batch:
350
342
  tmp_batch, tmp_result = self.result_queue.popleft()
351
- tmp_batch.next_batch_sampling_info = (
352
- self.tp_worker.cur_sampling_info if batch else None
353
- )
354
343
  self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
355
344
 
356
345
  if len(self.disagg_prefill_inflight_queue) > 0:
@@ -368,7 +357,6 @@ class SchedulerDisaggregationPrefillMixin:
368
357
  self: Scheduler,
369
358
  batch: ScheduleBatch,
370
359
  result: GenerationBatchResult,
371
- launch_done: Optional[threading.Event] = None,
372
360
  ) -> None:
373
361
  """
374
362
  Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -379,31 +367,30 @@ class SchedulerDisaggregationPrefillMixin:
379
367
  next_token_ids,
380
368
  extend_input_len_per_req,
381
369
  extend_logprob_start_len_per_req,
370
+ copy_done,
382
371
  ) = (
383
372
  result.logits_output,
384
373
  result.next_token_ids,
385
374
  result.extend_input_len_per_req,
386
375
  result.extend_logprob_start_len_per_req,
376
+ result.copy_done,
387
377
  )
388
378
 
379
+ if copy_done is not None:
380
+ copy_done.synchronize()
381
+
389
382
  logprob_pt = 0
390
383
  # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
391
- if self.enable_overlap:
392
- # wait
393
- logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
394
- launch_done
395
- )
396
- else:
397
- next_token_ids = result.next_token_ids.tolist()
398
- if batch.return_logprob:
399
- if logits_output.next_token_logprobs is not None:
400
- logits_output.next_token_logprobs = (
401
- logits_output.next_token_logprobs.tolist()
402
- )
403
- if logits_output.input_token_logprobs is not None:
404
- logits_output.input_token_logprobs = tuple(
405
- logits_output.input_token_logprobs.tolist()
406
- )
384
+ next_token_ids = result.next_token_ids.tolist()
385
+ if batch.return_logprob:
386
+ if logits_output.next_token_logprobs is not None:
387
+ logits_output.next_token_logprobs = (
388
+ logits_output.next_token_logprobs.tolist()
389
+ )
390
+ if logits_output.input_token_logprobs is not None:
391
+ logits_output.input_token_logprobs = tuple(
392
+ logits_output.input_token_logprobs.tolist()
393
+ )
407
394
 
408
395
  hidden_state_offset = 0
409
396
  for i, (req, next_token_id) in enumerate(
@@ -491,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin:
491
478
  if self.enable_overlap:
492
479
  self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
493
480
 
494
- # We need to remove the sync in the following function for overlap schedule.
495
- self.set_next_batch_sampling_info_done(batch)
496
481
  self.maybe_send_health_check_signal()
497
482
 
498
483
  def process_disagg_prefill_inflight_queue(
@@ -703,7 +703,7 @@ def _set_envs_and_config(server_args: ServerArgs):
703
703
  if server_args.attention_backend == "flashinfer":
704
704
  assert_pkg_version(
705
705
  "flashinfer_python",
706
- "0.4.0rc3",
706
+ "0.4.0",
707
707
  "Please uninstall the old version and "
708
708
  "reinstall the latest version by following the instructions "
709
709
  "at https://docs.flashinfer.ai/installation.html.",
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
711
711
  if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
712
712
  assert_pkg_version(
713
713
  "sgl-kernel",
714
- "0.3.14",
714
+ "0.3.15",
715
715
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
716
716
  )
717
717
 
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
27
27
  TokenizedEmbeddingReqInput,
28
28
  TokenizedGenerateReqInput,
29
29
  )
30
+ from sglang.srt.managers.scheduler import is_health_check_generate_req
30
31
  from sglang.srt.server_args import PortArgs, ServerArgs
31
32
  from sglang.srt.utils import get_zmq_socket, kill_process_tree
32
33
  from sglang.utils import get_exception_traceback
@@ -263,8 +264,8 @@ class GrpcRequestManager:
263
264
  response = await task
264
265
 
265
266
  # Add index for client-side ordering
266
- if isinstance(response, dict) and "meta_info" in response:
267
- response_rid = response["meta_info"].get("id", "")
267
+ if isinstance(response, dict):
268
+ response_rid = response.get("request_id", "")
268
269
  if response_rid in rid_to_index:
269
270
  response["index"] = rid_to_index[response_rid]
270
271
 
@@ -338,12 +339,9 @@ class GrpcRequestManager:
338
339
  break
339
340
 
340
341
  except asyncio.TimeoutError:
341
- # Timeout waiting for response - abort and cleanup
342
- logger.warning(
343
- f"Timeout waiting for response for request {request_id}"
344
- )
345
- await self.abort_request(request_id)
346
- return
342
+ # Timeout is for periodic client cancellation check
343
+ # Continue waiting for scheduler response
344
+ continue
347
345
 
348
346
  finally:
349
347
  # Always clean up request state when exiting
@@ -397,9 +395,7 @@ class GrpcRequestManager:
397
395
  # Wait for result in background
398
396
  async def wait_for_result():
399
397
  try:
400
- # Wait for completion
401
398
  await state.event.wait()
402
- # Get result from queue
403
399
  result = await state.out_queue.get()
404
400
  future.set_result(result)
405
401
  except Exception as e:
@@ -414,6 +410,10 @@ class GrpcRequestManager:
414
410
 
415
411
  async def abort_request(self, request_id: str) -> bool:
416
412
  """Abort a running request."""
413
+ # Skip aborting health check requests (they clean themselves up)
414
+ if request_id.startswith("HEALTH_CHECK"):
415
+ return False
416
+
417
417
  if request_id not in self.rid_to_state:
418
418
  return False
419
419
 
@@ -437,19 +437,6 @@ class GrpcRequestManager:
437
437
 
438
438
  return True
439
439
 
440
- async def pause_generation(self):
441
- """Pause generation processing."""
442
- async with self.is_pause_cond:
443
- self.is_pause = True
444
- logger.info("Generation paused")
445
-
446
- async def resume_generation(self):
447
- """Resume generation processing."""
448
- async with self.is_pause_cond:
449
- self.is_pause = False
450
- self.is_pause_cond.notify_all()
451
- logger.info("Generation resumed")
452
-
453
440
  async def handle_loop(self):
454
441
  """
455
442
  Main event loop - processes outputs from scheduler.