sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_policy.py
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+CLIP_MAX_NEW_TOKENS = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )
 
@@ -305,7 +305,7 @@ class PrefillAdder:
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                        CLIP_MAX_NEW_TOKENS,
                     )
                     * self.new_token_ratio
                     for r in running_batch.reqs
@@ -388,7 +388,7 @@ class PrefillAdder:
             0,
             req.extend_input_len,
             (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
                 if not truncated
                 else 0
             ),
@@ -477,7 +477,7 @@ class PrefillAdder:
             self._update_prefill_budget(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
             )
         else:
             if self.rem_chunk_tokens == 0:
@@ -499,7 +499,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req, has_chunked_req)
 
         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
        )
 
        # adjusting the input_tokens based on host_hit_length and page_size
@@ -544,7 +544,7 @@ class PrefillAdder:
                 input_tokens,
                 min(
                     req.sampling_params.max_new_tokens,
-                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                    CLIP_MAX_NEW_TOKENS,
                 ),
             )
         else:
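
Note: only the Python identifier is renamed; the `SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION` environment variable keeps its old name, so existing deployments are unaffected. A minimal sketch of the clipping behavior (the helper function below is illustrative, not part of sglang):

```python
import os

# Same default as schedule_policy.py; the env var name is unchanged.
CLIP_MAX_NEW_TOKENS = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)

def estimated_decode_budget(max_new_tokens: int, generated: int) -> int:
    # The scheduler reserves memory for at most CLIP_MAX_NEW_TOKENS more
    # tokens per request; the real stop condition still uses max_new_tokens.
    return min(max_new_tokens - generated, CLIP_MAX_NEW_TOKENS)

print(estimated_decode_budget(32768, 100))  # -> 4096 under the default clip
```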
sglang/srt/managers/scheduler.py
@@ -120,6 +120,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
 from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
 from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
@@ -129,6 +130,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
+from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
@@ -472,8 +474,10 @@ class Scheduler(
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
+        self.offload_tags = set()
         self.init_profier()
 
+        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
             SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
             if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
@@ -608,14 +612,10 @@
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
-                hicache_io_backend=(
-                    "direct"
-                    if server_args.attention_backend
-                    == "fa3"  # hot fix for incompatibility
-                    else server_args.hicache_io_backend
-                ),
+                hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
+                hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -631,7 +631,19 @@
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
-
+        elif self.enable_lora:
+            assert (
+                not self.enable_hierarchical_cache
+            ), "LoRA radix cache doesn't support hierarchical cache"
+            assert (
+                self.schedule_policy == "fcfs"
+            ), "LoRA radix cache only supports FCFS policy"
+            self.tree_cache = LoRARadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+            )
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
@@ -946,6 +958,14 @@
 
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
+
+        if self.recv_skipper is not None:
+            last_forward_mode = (
+                self.last_batch.forward_mode if self.last_batch is not None else None
+            )
+            if not self.recv_skipper.handle(last_forward_mode):
+                return []
+
         if self.pp_rank == 0:
             if self.attn_tp_rank == 0:
                 recv_reqs = []
@@ -1029,7 +1049,9 @@
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None or not self.running_batch.is_empty()
+                self.chunked_req is not None
+                or not self.running_batch.is_empty()
+                or len(self.offload_tags) > 0
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -1090,7 +1112,7 @@
             top_logprobs_num=recv_req.top_logprobs_num,
             token_ids_logprob=recv_req.token_ids_logprob,
             stream=recv_req.stream,
-            lora_path=recv_req.lora_path,
+            lora_id=recv_req.lora_id,
             input_embeds=recv_req.input_embeds,
             custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
@@ -1534,18 +1556,15 @@
             self.chunked_req = adder.add_chunked_req(self.chunked_req)
 
         if self.enable_lora:
-            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+            lora_set = set([req.lora_id for req in self.running_batch.reqs])
 
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
-            if (
-                self.enable_lora
-                and len(
-                    lora_set
-                    | set([req.lora_path for req in adder.can_run_list])
-                    | set([req.lora_path])
-                )
-                > self.max_loras_per_batch
+
+            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
+                lora_set
+                | set([req.lora_id for req in adder.can_run_list])
+                | set([req.lora_id])
             ):
                 self.running_batch.batch_is_full = True
                 break
@@ -1562,7 +1581,10 @@
                 break
 
             if self.enable_hicache_storage:
-                self.tree_cache.check_prefetch_progress(req.rid)
+                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+                if not prefetch_done:
+                    # skip staging requests that are ongoing prefetch
+                    continue
 
             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin:
 
                 req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
-                if self.skip_tokenizer_init:
-                    output_ids.append(req.output_ids[send_token_offset:])
+                output_ids.append(req.output_ids[send_token_offset:])
                 req.send_token_offset = len(req.output_ids)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
sglang/srt/managers/scheduler_profiler_mixin.py
@@ -8,6 +8,18 @@ import torch
 
 from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu
+
+    patches = [
+        ["profiler.profile", torch_npu.profiler.profile],
+        ["profiler.ProfilerActivity.CUDA", torch_npu.profiler.ProfilerActivity.NPU],
+        ["profiler.ProfilerActivity.CPU", torch_npu.profiler.ProfilerActivity.CPU],
+    ]
+    torch_npu._apply_patches(patches)
 
 logger = logging.getLogger(__name__)
 
@@ -136,6 +148,13 @@ class SchedulerProfilerMixin:
                 activities=torchprof_activities,
                 with_stack=with_stack if with_stack is not None else True,
                 record_shapes=record_shapes if record_shapes is not None else False,
+                on_trace_ready=(
+                    None
+                    if not _is_npu
+                    else torch_npu.profiler.tensorboard_trace_handler(
+                        self.torch_profiler_output_dir
+                    )
+                ),
             )
             self.torch_profiler.start()
             self.profile_in_progress = True
@@ -166,15 +185,16 @@
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
-            self.torch_profiler.export_chrome_trace(
-                os.path.join(
-                    self.torch_profiler_output_dir,
-                    self.profile_id
-                    + f"-TP-{self.tp_rank}"
-                    + stage_suffix
-                    + ".trace.json.gz",
+            if not _is_npu:
+                self.torch_profiler.export_chrome_trace(
+                    os.path.join(
+                        self.torch_profiler_output_dir,
+                        self.profile_id
+                        + f"-TP-{self.tp_rank}"
+                        + stage_suffix
+                        + ".trace.json.gz",
+                    )
                 )
-            )
 
         torch.distributed.barrier(self.tp_cpu_group)
 
         if self.rpd_profiler is not None:
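
Note: on Ascend NPU, `torch_npu` patches `torch.profiler.*` at import time and traces are emitted through `tensorboard_trace_handler` rather than `export_chrome_trace`. A sketch of the same split on stock PyTorch (`output_dir` and `is_npu_platform` are illustrative stand-ins, not sglang names):

```python
import torch

# CUDA path: no on_trace_ready handler, export a Chrome trace after stop().
# NPU path: register a TensorBoard trace handler up front instead.
def make_profiler(output_dir: str, is_npu_platform: bool) -> torch.profiler.profile:
    return torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        on_trace_ready=(
            None
            if not is_npu_platform
            else torch.profiler.tensorboard_trace_handler(output_dir)
        ),
    )

prof = make_profiler("/tmp/traces", is_npu_platform=False)
prof.start()
torch.ones(4) @ torch.ones(4)  # some work to profile
prof.stop()
prof.export_chrome_trace("/tmp/traces/example.trace.json.gz")
```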
sglang/srt/managers/scheduler_recv_skipper.py (new file)
@@ -0,0 +1,37 @@
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.server_args import ServerArgs
+
+
+class SchedulerRecvSkipper:
+    @staticmethod
+    def maybe_create(server_args: ServerArgs):
+        if server_args.scheduler_recv_interval <= 1:
+            return None
+        return SchedulerRecvSkipper(server_args)
+
+    def __init__(self, server_args: ServerArgs):
+        # Can be supported if needed, but may need e.g. `global_forward_mode`
+        assert not server_args.enable_dp_attention
+        self._counter = 0
+        self._threshold = server_args.scheduler_recv_interval
+
+    def handle(self, last_forward_mode: ForwardMode):
+        should_recv = False
+
+        last_weight = _WEIGHT_OF_FORWARD_MODE.get(last_forward_mode, _DEFAULT_WEIGHT)
+        self._counter += last_weight
+
+        if self._counter >= self._threshold:
+            self._counter = 0
+            should_recv = True
+
+        return should_recv
+
+
+# All can be tuned if needed
+_DEFAULT_WEIGHT = 1000
+_WEIGHT_OF_FORWARD_MODE = {
+    ForwardMode.DECODE: 1,
+    ForwardMode.TARGET_VERIFY: 1,
+    None: 1,
+}
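
Taken together with the `recv_requests` change above: decode steps, target-verify steps, and gaps with no last batch each add weight 1, while anything else (e.g. a prefill) adds `_DEFAULT_WEIGHT = 1000` and forces a receive on the next check. A minimal sketch assuming a hypothetical `--scheduler-recv-interval 4`, with a stand-in args object:

```python
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
from sglang.srt.model_executor.forward_batch_info import ForwardMode

# Hypothetical stand-in for ServerArgs; only the fields the skipper reads.
class FakeArgs:
    scheduler_recv_interval = 4
    enable_dp_attention = False

skipper = SchedulerRecvSkipper.maybe_create(FakeArgs())

# Four consecutive decode steps accumulate weight 1 each; only the 4th
# reaches the threshold, so the scheduler polls ZMQ once per 4 steps.
print([skipper.handle(ForwardMode.DECODE) for _ in range(4)])
# -> [False, False, False, True]

# A non-decode step falls back to the default weight (1000) and
# immediately re-enables receiving.
print(skipper.handle(ForwardMode.EXTEND))  # -> True
```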
sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -78,6 +78,9 @@ class SchedulerUpdateWeightsMixin:
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
+        for tag in tags:
+            self.offload_tags.add(tag)
+
         if GPU_MEMORY_TYPE_KV_CACHE in tags:
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
             self.flush_cache()
@@ -97,6 +100,9 @@
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
+        for tag in tags:
+            self.offload_tags.remove(tag)
+
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
sglang/srt/managers/template_manager.py
@@ -21,6 +21,7 @@ and code completion templates, eliminating global state and improving modularity
 import json
 import logging
 import os
+import re
 from typing import Optional
 
 from sglang.srt.code_completion_parser import (
@@ -54,6 +55,7 @@ class TemplateManager:
         self._chat_template_name: Optional[str] = None
         self._completion_template_name: Optional[str] = None
         self._jinja_template_content_format: Optional[str] = "openai"
+        self._force_reasoning: bool = False
 
     @property
     def chat_template_name(self) -> Optional[str]:
@@ -70,6 +72,31 @@
         """Get the detected template content format ('string' or 'openai' or None)."""
         return self._jinja_template_content_format
 
+    @property
+    def force_reasoning(self) -> bool:
+        """
+        Check if the current chat template enforces reasoning/thinking.
+
+        Returns:
+            True if the template contains reasoning patterns like <think> tags
+        """
+        return self._force_reasoning
+
+    def _detect_reasoning_pattern(self, template: str) -> bool:
+        """
+        Detect if the chat template contains reasoning/thinking patterns.
+        """
+        if template is None:
+            return False
+
+        force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"
+        has_reasoning = re.search(force_reasoning_pattern, template) is not None
+
+        if has_reasoning:
+            logger.info("Detected the force reasoning pattern in chat template.")
+
+        return has_reasoning
+
     def load_chat_template(
         self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
     ) -> None:
@@ -93,7 +120,8 @@
             hf_template = self._resolve_hf_chat_template(tokenizer_manager)
             if hf_template:
                 # override the chat template
-                tokenizer_manager.tokenizer.chat_template = hf_template
+                if tokenizer_manager.tokenizer:
+                    tokenizer_manager.tokenizer.chat_template = hf_template
                 self._jinja_template_content_format = (
                     detect_jinja_template_content_format(hf_template)
                 )
@@ -106,6 +134,12 @@
             self._jinja_template_content_format = "string"
             logger.info("No chat template found, defaulting to 'string' content format")
 
+        # Detect reasoning pattern from chat template
+        if tokenizer_manager.tokenizer:
+            self._force_reasoning = self._detect_reasoning_pattern(
+                tokenizer_manager.tokenizer.chat_template
+            )
+
     def _load_explicit_chat_template(
         self, tokenizer_manager, chat_template_arg: str
     ) -> None:
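
Note the double backslash in the detection pattern: `\\n` in a raw string matches the two characters `\n` as they appear inside the Jinja source, not an actual newline. A self-contained check against an illustrative Qwen-style template fragment:

```python
import re

# Pattern from TemplateManager._detect_reasoning_pattern above.
pattern = r"<\|im_start\|>assistant\\n<think>\\n"

# Illustrative fragment of a chat template that forces a thinking block;
# the `\n` here is literal backslash-n in the template source.
template = r"{{ '<|im_start|>assistant\n<think>\n' }}"

print(re.search(pattern, template) is not None)  # -> True
```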
sglang/srt/managers/tokenizer_manager.py
@@ -269,10 +269,9 @@ class TokenizerManager:
         self.asyncio_tasks = set()
 
         # Health check
-        self.health_check_failed = False
+        self.server_status = ServerStatus.Starting
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
-        self.server_status = ServerStatus.Starting
 
         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -291,8 +290,8 @@
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self._is_updating = False
-        self._is_updating_cond = asyncio.Condition()
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()
 
         # LoRA
         # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
@@ -476,16 +475,20 @@
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()
 
-        async with self._is_updating_cond:
-            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
                 f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )
 
+        async with self.is_pause_cond:
+            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
         async with self.model_update_lock.reader_lock:
+            if self.server_args.enable_lora and obj.lora_path:
+                # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
+                obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
+
             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
@@ -553,11 +556,6 @@
         else:
             mm_inputs = None
 
-        if self.server_args.enable_lora and obj.lora_path:
-            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
-            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
-
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -665,7 +663,7 @@
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
-            lora_path=obj.lora_path,
+            lora_id=obj.lora_id,
             input_embeds=input_embeds,
             session_params=session_params,
             custom_logit_processor=obj.custom_logit_processor,
@@ -750,7 +748,11 @@
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -771,10 +773,6 @@
                 msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                 logger.info(msg)
 
-            # Mark ongoing LoRA request as finished.
-            if self.server_args.enable_lora and obj.lora_path:
-                await self.lora_registry.release(obj.lora_path)
-
             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
                 finish_reason = out["meta_info"]["finish_reason"]
@@ -793,6 +791,11 @@
                     # Delete the key to prevent resending abort request to the scheduler and
                     # to ensure aborted request state is cleaned up.
                     del self.rid_to_state[state.obj.rid]
+
+                    # Mark ongoing LoRA request as finished.
+                    if self.server_args.enable_lora and state.obj.lora_path:
+                        await self.lora_registry.release(state.obj.lora_id)
+
                     raise fastapi.HTTPException(
                         status_code=finish_reason["status_code"],
                         detail=finish_reason["message"],
@@ -805,7 +808,11 @@
             if obj.stream:
                 yield out
             else:
-                if request is not None and await request.is_disconnected():
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -974,14 +981,14 @@
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
     async def pause_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = True
+        async with self.is_pause_cond:
+            self.is_pause = True
             self.abort_request(abort_all=True)
 
     async def continue_generation(self):
-        async with self._is_updating_cond:
-            self._is_updating = False
-            self._is_updating_cond.notify_all()
+        async with self.is_pause_cond:
+            self.is_pause = False
+            self.is_pause_cond.notify_all()
 
     async def update_weights_from_disk(
         self,
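
The rename from `_is_updating` to `is_pause` reflects that this gate is now a general pause switch rather than a weight-update flag. The underlying pattern is a plain `asyncio.Condition` gate; a self-contained sketch:

```python
import asyncio

class PauseGate:
    """Minimal version of the is_pause / is_pause_cond pattern above."""

    def __init__(self):
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

    async def pause(self):
        async with self.is_pause_cond:
            self.is_pause = True

    async def resume(self):
        async with self.is_pause_cond:
            self.is_pause = False
            # Wake every request coroutine blocked in wait_for().
            self.is_pause_cond.notify_all()

    async def wait_until_running(self):
        # New requests park here while the server is paused.
        async with self.is_pause_cond:
            await self.is_pause_cond.wait_for(lambda: not self.is_pause)

async def main():
    gate = PauseGate()
    await gate.pause()
    waiter = asyncio.create_task(gate.wait_until_running())
    await asyncio.sleep(0.1)   # waiter stays blocked while paused
    await gate.resume()
    await waiter               # returns once resumed

asyncio.run(main())
```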
@@ -1121,6 +1128,7 @@ class TokenizerManager:
         new_adapter = LoRARef(
             lora_name=obj.lora_name,
             lora_path=obj.lora_path,
+            pinned=obj.pinned,
         )
 
         # Trigger the actual loading operation at the backend processes.
@@ -1178,7 +1186,7 @@
 
             return result
         except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, rror_message=str(e))
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
 
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1465,7 +1473,7 @@
         while True:
             remain_num_req = len(self.rid_to_state)
 
-            if self.health_check_failed:
+            if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
@@ -1548,8 +1556,17 @@
 
             if isinstance(recv_obj, BatchStrOut):
                 state.text += recv_obj.output_strs[i]
+                if state.obj.stream:
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids[state.last_output_offset :]
+                    state.last_output_offset = len(state.output_ids)
+                else:
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids.copy()
+
                 out_dict = {
                     "text": state.text,
+                    "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
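
`BatchStrOut` responses now always carry `output_ids`; for streaming requests only the delta since the last chunk is emitted, tracked by `last_output_offset`. A small sketch of that incremental-slice bookkeeping (the `StreamState` class is a hypothetical stand-in for the per-request state above):

```python
# Hypothetical stand-in for the per-request state object.
class StreamState:
    def __init__(self):
        self.output_ids: list[int] = []
        self.last_output_offset = 0

def next_chunk_token_ids(state: StreamState, new_ids: list[int]) -> list[int]:
    # Accumulate all ids, but emit only the ones added since the last chunk.
    state.output_ids.extend(new_ids)
    delta = state.output_ids[state.last_output_offset:]
    state.last_output_offset = len(state.output_ids)
    return delta

state = StreamState()
print(next_chunk_token_ids(state, [1, 2, 3]))  # -> [1, 2, 3]
print(next_chunk_token_ids(state, [4]))        # -> [4]
print(state.output_ids)                        # full history: [1, 2, 3, 4]
```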
@@ -1582,6 +1599,10 @@ class TokenizerManager:
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
                 del self.rid_to_state[rid]
 
+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and state.obj.lora_path:
+                    asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
+
             state.out_list.append(out_dict)
             state.event.set()
 
@@ -1947,10 +1968,6 @@ class ServerStatus(Enum):
     Up = "Up"
     Starting = "Starting"
     UnHealthy = "UnHealthy"
-    Crashed = "Crashed"
-
-    def is_healthy(self) -> bool:
-        return self == ServerStatus.Up
 
 
 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
sglang/srt/managers/tp_worker.py
@@ -311,3 +311,6 @@ class TpModelWorker:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
+
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -288,6 +288,9 @@ class TpModelWorkerClient:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         return self.worker.unload_lora_adapter(recv_req)
 
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.worker.can_run_lora_batch(lora_ids)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
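
With this change the scheduler no longer compares the batch's adapter set against a flat `max_loras_per_batch`; it asks the worker, which delegates to `LoRAManager.validate_lora_batch`. One motivation for the indirection is pinned adapters (note the new `pinned=obj.pinned` field on `LoRARef` above), which can occupy memory-pool slots regardless of the current batch. A hedged sketch of such a capacity check; field names are assumptions, not sglang's real `LoRAManager` API:

```python
# Illustrative capacity check; `pinned_lora_ids` and `max_loras_per_batch`
# are assumed names, not the actual validate_lora_batch implementation.
def validate_lora_batch_sketch(
    lora_ids: set[str],
    pinned_lora_ids: set[str],
    max_loras_per_batch: int,
) -> bool:
    # Pinned adapters hold slots even when no request in the batch uses
    # them, so the batch fits only if both groups fit together.
    return len(lora_ids | pinned_lora_ids) <= max_loras_per_batch

ok = validate_lora_batch_sketch({"a", "b"}, {"base"}, max_loras_per_batch=3)
print(ok)  # -> True: three distinct adapters fit in three slots
```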