sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
  )
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
  from sglang.srt.managers.io_struct import (
      AbortReq,
      CloseSessionReqInput,
@@ -119,13 +120,14 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
      SchedulerOutputProcessorMixin,
  )
  from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+ from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
  from sglang.srt.managers.scheduler_update_weights_mixin import (
      SchedulerUpdateWeightsMixin,
  )
  from sglang.srt.managers.session_controller import Session
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
- from sglang.srt.managers.utils import validate_input_length
+ from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -137,7 +139,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
  from sglang.srt.utils import (
-     DeepEPMode,
      DynamicGradMode,
      broadcast_pyobj,
      configure_gc_logger,
@@ -203,6 +204,7 @@ class Scheduler(
          moe_ep_rank: int,
          pp_rank: int,
          dp_rank: Optional[int],
+         dp_balance_meta: Optional[DPBalanceMeta] = None,
      ):
          # Parse args
          self.server_args = server_args
@@ -471,8 +473,10 @@ class Scheduler(
          self.memory_saver_adapter = TorchMemorySaverAdapter.create(
              enable=server_args.enable_memory_saver
          )
+         self.offload_tags = set()
          self.init_profier()

+         self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
          self.input_blocker = (
              SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
              if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
@@ -522,6 +526,15 @@ class Scheduler(
              ]
          )

+         self.balance_meta = dp_balance_meta
+         if (
+             server_args.enable_dp_attention
+             and server_args.load_balance_method == "minimum_tokens"
+         ):
+             assert dp_balance_meta is not None
+
+             self.recv_dp_balance_id_this_term = []
+
      def init_tokenizer(self):
          server_args = self.server_args

@@ -569,7 +582,23 @@ class Scheduler(
                  page_size=self.page_size,
              )
          else:
-             if self.enable_hierarchical_cache:
+             if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
+                 # lazy import to avoid JIT overhead
+                 from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
+
+                 self.tree_cache = RadixCacheCpp(
+                     disable=False,
+                     use_hicache=self.enable_hierarchical_cache,
+                     req_to_token_pool=self.req_to_token_pool,
+                     token_to_kv_pool=self.token_to_kv_pool_allocator,
+                     tp_cache_group=self.tp_cpu_group,
+                     page_size=self.page_size,
+                     hicache_ratio=server_args.hicache_ratio,
+                     hicache_size=server_args.hicache_size,
+                     hicache_write_policy=server_args.hicache_write_policy,
+                     enable_kv_cache_events=self.enable_kv_cache_events,
+                 )
+             elif self.enable_hierarchical_cache:
                  self.tree_cache = HiRadixCache(
                      req_to_token_pool=self.req_to_token_pool,
                      token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
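The experimental C++ radix tree above is gated purely by an environment variable rather than a server argument. A minimal launch sketch, assuming the standard python -m sglang.launch_server entry point (the model path and other flags are illustrative, not part of this diff):

import os
import subprocess
import sys

# Opt in to the experimental C++ radix tree; leaving the variable unset (or set to
# anything other than "1") keeps the existing Python RadixCache / HiRadixCache path.
env = {**os.environ, "SGLANG_EXPERIMENTAL_CPP_RADIX_TREE": "1"}
subprocess.run(
    [sys.executable, "-m", "sglang.launch_server", "--model-path", sys.argv[1]],
    env=env,
    check=True,
)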
@@ -590,6 +619,7 @@ class Scheduler(
                      ),
                      hicache_mem_layout=server_args.hicache_mem_layout,
                      hicache_storage_backend=server_args.hicache_storage_backend,
+                     hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                  )
                  self.tp_worker.register_hicache_layer_transfer_counter(
                      self.tree_cache.cache_controller.layer_done_counter
@@ -920,6 +950,14 @@ class Scheduler(

      def recv_requests(self) -> List[Req]:
          """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
+
+         if self.recv_skipper is not None:
+             last_forward_mode = (
+                 self.last_batch.forward_mode if self.last_batch is not None else None
+             )
+             if not self.recv_skipper.handle(last_forward_mode):
+                 return []
+
          if self.pp_rank == 0:
              if self.attn_tp_rank == 0:
                  recv_reqs = []
@@ -1003,7 +1041,9 @@ class Scheduler(
          for recv_req in recv_reqs:
              # If it is a health check generation request and there are running requests, ignore it.
              if is_health_check_generate_req(recv_req) and (
-                 self.chunked_req is not None or not self.running_batch.is_empty()
+                 self.chunked_req is not None
+                 or not self.running_batch.is_empty()
+                 or len(self.offload_tags) > 0
              ):
                  self.return_health_check_ct += 1
                  continue
@@ -1033,6 +1073,12 @@ class Scheduler(
          self,
          recv_req: TokenizedGenerateReqInput,
      ):
+         if (
+             self.server_args.enable_dp_attention
+             and self.server_args.load_balance_method == "minimum_tokens"
+         ):
+             self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
          # Create a new request
          if (
              recv_req.session_params is None
@@ -1058,7 +1104,7 @@ class Scheduler(
                  top_logprobs_num=recv_req.top_logprobs_num,
                  token_ids_logprob=recv_req.token_ids_logprob,
                  stream=recv_req.stream,
-                 lora_path=recv_req.lora_path,
+                 lora_id=recv_req.lora_id,
                  input_embeds=recv_req.input_embeds,
                  custom_logit_processor=recv_req.custom_logit_processor,
                  return_hidden_states=recv_req.return_hidden_states,
@@ -1443,6 +1489,11 @@ class Scheduler(

          # Handle DP attention
          if need_dp_attn_preparation:
+             if (
+                 self.server_args.load_balance_method == "minimum_tokens"
+                 and self.forward_ct % 40 == 0
+             ):
+                 self.handle_dp_balance_data(ret)
              ret = self.prepare_mlp_sync_batch(ret)

          return ret
@@ -1497,18 +1548,15 @@ class Scheduler(
              self.chunked_req = adder.add_chunked_req(self.chunked_req)

          if self.enable_lora:
-             lora_set = set([req.lora_path for req in self.running_batch.reqs])
+             lora_set = set([req.lora_id for req in self.running_batch.reqs])

          # Get requests from the waiting queue to a new prefill batch
          for req in self.waiting_queue:
-             if (
-                 self.enable_lora
-                 and len(
-                     lora_set
-                     | set([req.lora_path for req in adder.can_run_list])
-                     | set([req.lora_path])
-                 )
-                 > self.max_loras_per_batch
+
+             if self.enable_lora and not self.tp_worker.can_run_lora_batch(
+                 lora_set
+                 | set([req.lora_id for req in adder.can_run_list])
+                 | set([req.lora_id])
              ):
                  self.running_batch.batch_is_full = True
                  break
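The inline adapter-count check is replaced by a call into the TP worker. A standalone sketch of the set-union capacity test the removed code expressed (the helper name and its arguments are hypothetical; the real TpModelWorker.can_run_lora_batch may also consult the LoRA memory pool):

from typing import Optional, Set

def can_run_lora_batch_sketch(
    running_lora_ids: Set[Optional[str]],
    staged_lora_ids: Set[Optional[str]],
    new_lora_id: Optional[str],
    max_loras_per_batch: int,
) -> bool:
    # The union of adapters already running, adapters staged for this prefill batch, and
    # the candidate request's adapter must not exceed the per-batch adapter limit.
    candidate = running_lora_ids | staged_lora_ids | {new_lora_id}
    return len(candidate) <= max_loras_per_batch

# Example: two adapters running, one staged, a request for a third adapter, limit of 3.
print(can_run_lora_batch_sketch({"a", "b"}, {"b"}, "c", max_loras_per_batch=3))  # True
print(can_run_lora_batch_sketch({"a", "b"}, {"c"}, "d", max_loras_per_batch=3))  # False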
@@ -1525,7 +1573,10 @@ class Scheduler(
                  break

              if self.enable_hicache_storage:
-                 self.tree_cache.check_prefetch_progress(req.rid)
+                 prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+                 if not prefetch_done:
+                     # skip staging requests that are ongoing prefetch
+                     continue

              req.init_next_round_input(self.tree_cache)
              res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
@@ -1744,6 +1795,9 @@ class Scheduler(
          elif batch.forward_mode.is_dummy_first():
              self.set_next_batch_sampling_info_done(batch)

+         self.maybe_send_health_check_signal()
+
+     def maybe_send_health_check_signal(self):
          if self.return_health_check_ct:
              # Return some signal for the health check.
              # This is used to prevent the health check signal being blocked by long context prefill.
@@ -1762,12 +1816,94 @@ class Scheduler(
              spec_algorithm=self.spec_algorithm,
              speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
              enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-             enable_deepep_moe=self.server_args.enable_deepep_moe,
-             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+             enable_deepep_moe=MoeA2ABackend(
+                 self.server_args.moe_a2a_backend
+             ).is_deepep(),
+             deepep_mode=DeepEPMode(self.server_args.deepep_mode),
              require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
              disable_overlap_schedule=self.server_args.disable_overlap_schedule,
          )

+     def handle_dp_balance_data(self, local_batch: ScheduleBatch):
+         def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
+             """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+             recv_list = self.recv_dp_balance_id_this_term
+             assert len(recv_list) <= 511, (
+                 "The number of requests received this round is too large. "
+                 "Please increase gather_tensor_size and onfly_info_size."
+             )
+             # The maximum size of the tensor used for gathering data from all workers.
+             gather_tensor_size = 512
+
+             # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+             recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+             recv_tensor[0] = holding_tokens_list
+             recv_tensor[1] = len(
+                 recv_list
+             )  # The first element is the length of the list.
+             recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
+                 recv_list, dtype=torch.int32
+             )
+
+             if self.tp_rank == 0:
+                 gathered_list = [
+                     torch.zeros(gather_tensor_size, dtype=torch.int32)
+                     for _ in range(self.balance_meta.num_workers)
+                 ]
+             else:
+                 gathered_list = None
+
+             torch.distributed.gather(
+                 recv_tensor, gathered_list, group=self.tp_cpu_group
+             )
+
+             gathered_id_list_per_worker = None
+             if self.tp_rank == 0:
+                 gathered_id_list_per_worker = []
+                 holding_tokens_list = []
+                 for tensor in gathered_list:
+                     holding_tokens_list.append(tensor[0].item())
+                     list_length = tensor[1].item()
+                     gathered_id_list_per_worker.append(
+                         tensor[2 : list_length + 2].tolist()
+                     )
+
+             return gathered_id_list_per_worker, holding_tokens_list
+
+         def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
+             meta = self.balance_meta
+
+             with meta.mutex:
+                 onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+                 assert len(new_recv_rid_lists) == len(
+                     onfly_list
+                 ), "num_worker not equal"
+                 # 1.Check if the rid received by each worker this round is present in onfly.
+                 # If it is, remove the corresponding onfly item.
+                 worker_id = 0
+                 for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                     for new_recv_rid in new_recv_rids:
+                         assert (
+                             new_recv_rid in on_fly_reqs
+                         ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                         del on_fly_reqs[new_recv_rid]
+                     worker_id += 1
+                 # 2. Atomically write local_tokens and onfly into shm under the mutex
+                 meta.set_shared_onfly_info(onfly_list)
+                 meta.set_shared_local_tokens(local_tokens)
+
+         holding_tokens = self.get_load()
+
+         new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
+             holding_tokens
+         )
+
+         self.recv_dp_balance_id_this_term.clear()
+         if self.tp_rank == 0:  # only first worker write info
+             write_shared_dp_balance_info(
+                 new_recv_dp_balance_id_list, holding_token_list
+             )
+
      @staticmethod
      def prepare_mlp_sync_batch_raw(
          local_batch: ScheduleBatch,
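Each DP worker packs its load report into a single fixed-size int32 tensor before the torch.distributed.gather above: slot 0 holds the worker's token load, slot 1 the number of newly received balance ids, and slots 2.. the ids themselves. A self-contained sketch of that layout (the helper names are hypothetical):

from typing import List, Tuple

import torch

GATHER_TENSOR_SIZE = 512  # matches the fixed buffer size used in gather_dp_balance_info

def pack_balance_tensor(holding_tokens: int, recv_ids: List[int]) -> torch.Tensor:
    # Layout: | holding_tokens | len(recv_ids) | recv_ids... | zero padding |
    assert len(recv_ids) <= GATHER_TENSOR_SIZE - 2, "too many ids for the fixed-size buffer"
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : 2 + len(recv_ids)] = torch.tensor(recv_ids, dtype=torch.int32)
    return t

def unpack_balance_tensor(t: torch.Tensor) -> Tuple[int, List[int]]:
    # Inverse of pack_balance_tensor: read the two header slots, then slice out the ids.
    num_ids = int(t[1].item())
    return int(t[0].item()), t[2 : 2 + num_ids].tolist()

# Example: a worker holding 1234 tokens that received balance ids 7, 8, 9 this term.
packed = pack_balance_tensor(1234, [7, 8, 9])
assert unpack_balance_tensor(packed) == (1234, [7, 8, 9])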
@@ -2344,11 +2480,19 @@ class IdleSleeper:

      def __init__(self, sockets):
          self.poller = zmq.Poller()
+         self.last_empty_time = time.time()
          for s in sockets:
              self.poller.register(s, zmq.POLLIN)

      def maybe_sleep(self):
          self.poller.poll(1000)
+         if (
+             global_config.torch_empty_cache_interval > 0
+             and time.time() - self.last_empty_time
+             > global_config.torch_empty_cache_interval
+         ):
+             self.last_empty_time = time.time()
+             torch.cuda.empty_cache()


  def is_health_check_generate_req(recv_req):
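The idle-path change above is a simple rate limiter around torch.cuda.empty_cache(), driven by the new global_config.torch_empty_cache_interval setting. A standalone sketch of the same pattern (the class name and interval value are illustrative):

import time

import torch

class PeriodicCacheTrimmer:
    """Release cached CUDA blocks at most once per `interval` seconds; interval <= 0 disables it."""

    def __init__(self, interval: float):
        self.interval = interval
        self.last_empty_time = time.time()

    def maybe_trim(self) -> bool:
        if self.interval > 0 and time.time() - self.last_empty_time > self.interval:
            self.last_empty_time = time.time()
            if torch.cuda.is_available():  # guard so the sketch also runs on CPU-only hosts
                torch.cuda.empty_cache()
            return True
        return False

trimmer = PeriodicCacheTrimmer(interval=5.0)
trimmer.maybe_trim()  # called from the idle loop; a no-op until 5 seconds have elapsed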
@@ -2368,6 +2512,7 @@ def run_scheduler_process(
      pp_rank: int,
      dp_rank: Optional[int],
      pipe_writer,
+     balance_meta: Optional[DPBalanceMeta] = None,
  ):
      # Generate the prefix
      prefix = ""
@@ -2401,7 +2546,14 @@
      # Create a scheduler and run the event loop
      try:
          scheduler = Scheduler(
-             server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+             server_args,
+             port_args,
+             gpu_id,
+             tp_rank,
+             moe_ep_rank,
+             pp_rank,
+             dp_rank,
+             dp_balance_meta=balance_meta,
          )
          pipe_writer.send(
              {
@@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin:

              req.send_decode_id_offset = len(decode_ids)
              read_offsets.append(read_offset)
-             if self.skip_tokenizer_init:
-                 output_ids.append(req.output_ids[send_token_offset:])
+             output_ids.append(req.output_ids[send_token_offset:])
              req.send_token_offset = len(req.output_ids)
              skip_special_tokens.append(req.sampling_params.skip_special_tokens)
              spaces_between_special_tokens.append(
@@ -0,0 +1,37 @@
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
+ from sglang.srt.server_args import ServerArgs
+
+
+ class SchedulerRecvSkipper:
+     @staticmethod
+     def maybe_create(server_args: ServerArgs):
+         if server_args.scheduler_recv_interval <= 1:
+             return None
+         return SchedulerRecvSkipper(server_args)
+
+     def __init__(self, server_args: ServerArgs):
+         # Can be supported if needed, but may need e.g. `global_forward_mode`
+         assert not server_args.enable_dp_attention
+         self._counter = 0
+         self._threshold = server_args.scheduler_recv_interval
+
+     def handle(self, last_forward_mode: ForwardMode):
+         should_recv = False
+
+         last_weight = _WEIGHT_OF_FORWARD_MODE.get(last_forward_mode, _DEFAULT_WEIGHT)
+         self._counter += last_weight
+
+         if self._counter >= self._threshold:
+             self._counter = 0
+             should_recv = True
+
+         return should_recv
+
+
+ # All can be tuned if needed
+ _DEFAULT_WEIGHT = 1000
+ _WEIGHT_OF_FORWARD_MODE = {
+     ForwardMode.DECODE: 1,
+     ForwardMode.TARGET_VERIFY: 1,
+     None: 1,
+ }
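The skipper lets the scheduler poll its ZMQ sockets less often while it is purely decoding: decode and target-verify steps add weight 1, anything else adds the default weight of 1000 and forces a poll on the next iteration. A standalone replica of the counter logic (no sglang imports; the string forward-mode names stand in for ForwardMode members, and the threshold corresponds to server_args.scheduler_recv_interval):

DEFAULT_WEIGHT = 1000
WEIGHT_OF_FORWARD_MODE = {"DECODE": 1, "TARGET_VERIFY": 1, None: 1}

class RecvSkipperSketch:
    def __init__(self, threshold: int):
        self._counter = 0
        self._threshold = threshold

    def handle(self, last_forward_mode) -> bool:
        self._counter += WEIGHT_OF_FORWARD_MODE.get(last_forward_mode, DEFAULT_WEIGHT)
        if self._counter >= self._threshold:
            self._counter = 0
            return True   # poll the ZMQ sockets on this scheduler iteration
        return False      # skip recv_requests() and keep decoding

skipper = RecvSkipperSketch(threshold=4)
# Four consecutive decode steps: only the fourth one actually polls for new requests.
print([skipper.handle("DECODE") for _ in range(4)])  # [False, False, False, True]
# A prefill/extend step carries the default weight, so the next poll happens immediately.
print(skipper.handle("EXTEND"))  # True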
@@ -78,6 +78,9 @@ class SchedulerUpdateWeightsMixin:
          if tags is None or len(tags) == 0:
              tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

+         for tag in tags:
+             self.offload_tags.add(tag)
+
          if GPU_MEMORY_TYPE_KV_CACHE in tags:
              self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
              self.flush_cache()
@@ -97,6 +100,9 @@ class SchedulerUpdateWeightsMixin:
          if tags is None or len(tags) == 0:
              tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

+         for tag in tags:
+             self.offload_tags.remove(tag)
+
          if GPU_MEMORY_TYPE_WEIGHTS in tags:
              self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
              torch.distributed.barrier(self.tp_cpu_group)
@@ -21,6 +21,7 @@ and code completion templates, eliminating global state and improving modularity
  import json
  import logging
  import os
+ import re
  from typing import Optional

  from sglang.srt.code_completion_parser import (
@@ -54,6 +55,7 @@ class TemplateManager:
          self._chat_template_name: Optional[str] = None
          self._completion_template_name: Optional[str] = None
          self._jinja_template_content_format: Optional[str] = "openai"
+         self._force_reasoning: bool = False

      @property
      def chat_template_name(self) -> Optional[str]:
@@ -70,6 +72,31 @@ class TemplateManager:
          """Get the detected template content format ('string' or 'openai' or None)."""
          return self._jinja_template_content_format

+     @property
+     def force_reasoning(self) -> bool:
+         """
+         Check if the current chat template enforces reasoning/thinking.
+
+         Returns:
+             True if the template contains reasoning patterns like <think> tags
+         """
+         return self._force_reasoning
+
+     def _detect_reasoning_pattern(self, template: str) -> bool:
+         """
+         Detect if the chat template contains reasoning/thinking patterns.
+         """
+         if template is None:
+             return False
+
+         force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"
+         has_reasoning = re.search(force_reasoning_pattern, template) is not None
+
+         if has_reasoning:
+             logger.info("Detected the force reasoning pattern in chat template.")
+
+         return has_reasoning
+
      def load_chat_template(
          self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
      ) -> None:
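Note that the doubled backslashes in the pattern above make it match the literal two-character escape sequence \n as it appears in Jinja template source, not a real newline. A quick check with illustrative template fragments (not taken from any real model's chat template):

import re

force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"

# Template source that hard-codes an assistant turn opening with a <think> block.
forcing_template = r"{{ '<|im_start|>assistant\n<think>\n' }}"
# Template source that opens the assistant turn without forcing reasoning.
plain_template = r"{{ '<|im_start|>assistant\n' }}"

print(re.search(force_reasoning_pattern, forcing_template) is not None)  # True
print(re.search(force_reasoning_pattern, plain_template) is not None)    # False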
@@ -84,26 +111,34 @@ class TemplateManager:
          if chat_template_arg:
              self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
          else:
-             # Try HuggingFace template first
-             hf_template = self._resolve_hf_chat_template(tokenizer_manager)
-             if hf_template:
-                 self._jinja_template_content_format = (
-                     detect_jinja_template_content_format(hf_template)
-                 )
-                 logger.info(
-                     f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
-                 )
-                 return
-
-             # Fallback to SGLang template guessing
+             # Guess chat template from model path
              self.guess_chat_template_from_model_path(model_path)

-             # Set default format if no template was found
+             # If no pre-defined template was found, fallback to HuggingFace template
              if self._chat_template_name is None:
-                 self._jinja_template_content_format = "string"
-                 logger.info(
-                     "No chat template found, defaulting to 'string' content format"
-                 )
+                 # Try HuggingFace template first
+                 hf_template = self._resolve_hf_chat_template(tokenizer_manager)
+                 if hf_template:
+                     # override the chat template
+                     if tokenizer_manager.tokenizer:
+                         tokenizer_manager.tokenizer.chat_template = hf_template
+                     self._jinja_template_content_format = (
+                         detect_jinja_template_content_format(hf_template)
+                     )
+                     logger.info(
+                         f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
+                     )
+                     return
+
+                 # Default to string content format if no template was found
+                 self._jinja_template_content_format = "string"
+                 logger.info("No chat template found, defaulting to 'string' content format")
+
+         # Detect reasoning pattern from chat template
+         if tokenizer_manager.tokenizer:
+             self._force_reasoning = self._detect_reasoning_pattern(
+                 tokenizer_manager.tokenizer.chat_template
+             )

      def _load_explicit_chat_template(
          self, tokenizer_manager, chat_template_arg: str
@@ -257,13 +292,15 @@ class TemplateManager:

          Returns the chat template string if found, None otherwise.
          """
-         tokenizer = tokenizer_manager.tokenizer
-
-         # Try to get AutoTokenizer chat template
          try:
-             return tokenizer.get_chat_template()
+             if processor := tokenizer_manager.processor:
+                 if hasattr(processor, "chat_template") and processor.chat_template:
+                     return processor.chat_template
+             if tokenizer := tokenizer_manager.tokenizer:
+                 if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+                     return tokenizer.chat_template
          except Exception as e:
-             logger.debug(f"Error getting chat template via get_chat_template(): {e}")
+             logger.debug(f"Error getting chat template: {e}")

          logger.debug("No HuggingFace chat template found")
          return None