sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -17,10 +17,11 @@ import faulthandler
  import logging
  import os
  import signal
+ import sys
  import threading
  import time
  import warnings
- from collections import deque
+ from collections import defaultdict, deque
  from concurrent import futures
  from dataclasses import dataclass
  from http import HTTPStatus
@@ -44,17 +45,24 @@ from sglang.srt.managers.io_struct import (
  BatchTokenIDOut,
  CloseSessionReqInput,
  FlushCacheReq,
+ GetInternalStateReq,
+ GetInternalStateReqOutput,
  GetWeightsByNameReqInput,
  GetWeightsByNameReqOutput,
+ HealthCheckOutput,
  InitWeightsUpdateGroupReqInput,
  InitWeightsUpdateGroupReqOutput,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
+ ProfileReqOutput,
+ ProfileReqType,
  ReleaseMemoryOccupationReqInput,
  ReleaseMemoryOccupationReqOutput,
  ResumeMemoryOccupationReqInput,
  ResumeMemoryOccupationReqOutput,
+ SetInternalStateReq,
+ SetInternalStateReqOutput,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
  UpdateWeightFromDiskReqInput,
@@ -82,6 +90,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
  from sglang.srt.managers.utils import validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
+ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
@@ -94,6 +103,7 @@ from sglang.srt.utils import (
  crash_on_warnings,
  get_bool_env_var,
  get_zmq_socket,
+ pyspy_dump_schedulers,
  set_gpu_proc_affinity,
  set_random_seed,
  suppress_other_loggers,
@@ -103,13 +113,16 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
  logger = logging.getLogger(__name__)

  # Test retract decode for debugging purposes
- test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
+ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
+ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")


  @dataclass
  class GenerationBatchResult:
  logits_output: LogitsProcessorOutput
  next_token_ids: List[int]
+ extend_input_len_per_req: List[int]
+ extend_logprob_start_len_per_req: List[int]
  bid: int

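The two module-level debug switches above are plain environment variables read at import time. A minimal usage sketch, assuming get_bool_env_var treats values such as "1" or "true" as enabled:

import os

# Hypothetical debug session: set the flags before the scheduler process starts.
os.environ["SGLANG_TEST_RETRACT"] = "1"      # force a decode retraction once batch size > 10 (see update_running_batch below)
os.environ["SGLANG_RECORD_STEP_TIME"] = "1"  # record per-batch-size decode step times into step_time_dict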
@@ -135,21 +148,28 @@ class Scheduler:
  self.tp_rank = tp_rank
  self.tp_size = server_args.tp_size
  self.schedule_policy = server_args.schedule_policy
- self.disable_jump_forward = server_args.disable_jump_forward
  self.lora_paths = server_args.lora_paths
  self.max_loras_per_batch = server_args.max_loras_per_batch
  self.enable_overlap = not server_args.disable_overlap_schedule
  self.skip_tokenizer_init = server_args.skip_tokenizer_init
  self.enable_metrics = server_args.enable_metrics
+ self.stream_interval = server_args.stream_interval
  self.spec_algorithm = SpeculativeAlgorithm.from_string(
  server_args.speculative_algorithm
  )
+ self.gpu_id = gpu_id
+ self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
  self.decode_mem_cache_buf_multiplier = (
- self.server_args.speculative_num_draft_tokens
+ (
+ self.server_args.speculative_num_draft_tokens
+ + (
+ self.server_args.speculative_eagle_topk
+ * self.server_args.speculative_num_draft_tokens
+ )
+ )
  if not self.spec_algorithm.is_none()
  else 1
  )
- self.enable_hierarchical_cache = server_args.enable_hierarchical_cache

  # Distributed rank info
  self.dp_size = server_args.dp_size
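For reference, the enlarged decode buffer multiplier above works out to speculative_num_draft_tokens * (1 + speculative_eagle_topk). A small worked example with assumed EAGLE settings (not values taken from this diff):

speculative_num_draft_tokens = 4
speculative_eagle_topk = 8

decode_mem_cache_buf_multiplier = (
    speculative_num_draft_tokens
    + speculative_eagle_topk * speculative_num_draft_tokens
)
assert decode_mem_cache_buf_multiplier == 36  # 4 + 8 * 4, versus just 4 in 0.4.3.post2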
@@ -228,9 +248,6 @@ class Scheduler:
  self.enable_overlap = False
  logger.info("Overlap scheduler is disabled for multimodal models.")

- if self.enable_overlap:
- self.disable_jump_forward = True
-
  # Launch a tensor parallel worker
  if self.enable_overlap:
  TpWorkerClass = TpModelWorkerClient
@@ -245,7 +262,7 @@ class Scheduler:
  nccl_port=port_args.nccl_port,
  )

- # Launch a worker for speculative decoding if needed
+ # Launch a draft worker for speculative decoding
  if self.spec_algorithm.is_eagle():
  from sglang.srt.speculative.eagle_worker import EAGLEWorker

@@ -257,8 +274,10 @@ class Scheduler:
  target_worker=self.tp_worker,
  dp_rank=dp_rank,
  )
+ self.prefill_only_one_req = True
  else:
  self.draft_worker = None
+ self.prefill_only_one_req = False

  # Get token and memory info from the model worker
  (
@@ -279,6 +298,7 @@ class Scheduler:
  self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
  global_server_args_dict.update(worker_global_server_args_dict)
  set_random_seed(self.random_seed)
+
  # Print debug info
  logger.info(
  f"max_total_num_tokens={self.max_total_num_tokens}, "
@@ -289,7 +309,9 @@ class Scheduler:
  )

  # Init memory pool and cache
- self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
+ self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+ self.tp_worker.get_memory_pool()
+ )

  if (
  server_args.chunked_prefill_size is not None
@@ -297,19 +319,26 @@ class Scheduler:
  ):
  self.tree_cache = ChunkCache(
  req_to_token_pool=self.req_to_token_pool,
- token_to_kv_pool=self.token_to_kv_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
  )
  else:
- self.tree_cache = RadixCache(
- req_to_token_pool=self.req_to_token_pool,
- token_to_kv_pool=self.token_to_kv_pool,
- disable=server_args.disable_radix_cache,
- )
- self.tree_cache_metrics = {"total": 0, "hit": 0}
+ if self.enable_hierarchical_cache:
+ self.tree_cache = HiRadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ )
+ else:
+ self.tree_cache = RadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ disable=server_args.disable_radix_cache,
+ )
+
  self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

  # Init running status
  self.waiting_queue: List[Req] = []
+ self.staging_reqs = {}
  # The running decoding batch for continuous batching
  self.running_batch: Optional[ScheduleBatch] = None
  # The current forward batch
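Stripped of diff markers, the radix-cache branch above is a two-way choice keyed on the new enable_hierarchical_cache server argument. A condensed restatement (not new behavior), assuming the chunked-prefill ChunkCache case has already been ruled out:

from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache


def build_tree_cache(server_args, req_to_token_pool, token_to_kv_pool_allocator):
    if server_args.enable_hierarchical_cache:
        # New in 0.4.3.post3: hierarchical radix cache; later hunks in this file
        # load evicted prefixes back via tree_cache.init_load_back().
        return HiRadixCache(
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
        )
    return RadixCache(
        req_to_token_pool=req_to_token_pool,
        token_to_kv_pool_allocator=token_to_kv_pool_allocator,
        disable=server_args.disable_radix_cache,
    )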
@@ -321,12 +350,22 @@ class Scheduler:
  self.num_generated_tokens = 0
  self.spec_num_total_accepted_tokens = 0
  self.spec_num_total_forward_ct = 0
+ self.cum_spec_accept_length = 0
+ self.cum_spec_accept_count = 0
  self.last_decode_stats_tic = time.time()
- self.stream_interval = server_args.stream_interval
+ self.return_health_check_ct = 0
  self.current_stream = torch.get_device_module(self.device).current_stream()
  if self.device == "cpu":
  self.current_stream.synchronize = lambda: None # No-op for CPU

+ # For metrics only.
+ # The largest prefill length of a single request
+ self._largest_prefill_len: int = 0
+ # The largest context length (prefill + generation) of a single request
+ self._largest_prefill_decode_len: int = 0
+ self.last_gen_throughput: float = 0.0
+ self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
+
  # Session info
  self.sessions: Dict[str, Session] = {}

@@ -334,7 +373,7 @@ class Scheduler:
  self.chunked_prefill_size = server_args.chunked_prefill_size
  if self.chunked_prefill_size <= 0: # -1 means disable
  self.chunked_prefill_size = None
- self.being_chunked_req = None
+ self.chunked_req = None
  self.is_mixed_chunk = (
  self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
  )
@@ -368,7 +407,7 @@ class Scheduler:
  ) / global_config.default_new_token_ratio_decay_steps
  self.new_token_ratio = self.init_new_token_ratio

- # Tells whether the current running batch is full so that we can skip
+ # Tell whether the current running batch is full so that we can skip
  # the check of whether to prefill new requests.
  # This is an optimization to reduce the overhead of the prefill check.
  self.batch_is_full = False
@@ -379,26 +418,16 @@ class Scheduler:
  t.start()
  self.parent_process = psutil.Process().parent()

+ # Init memory saver
  self.memory_saver_adapter = TorchMemorySaverAdapter.create(
  enable=server_args.enable_memory_saver
  )

  # Init profiler
- if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
- self.profiler = None
- else:
- self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
- logger.info(
- "Profiling enabled. Traces will be saved to: %s",
- self.torch_profiler_trace_dir,
- )
- self.profiler = torch.profiler.profile(
- activities=[
- torch.profiler.ProfilerActivity.CPU,
- torch.profiler.ProfilerActivity.CUDA,
- ],
- with_stack=True,
- )
+ self.torch_profiler = None
+ self.torch_profiler_output_dir: Optional[str] = None
+ self.torch_profiler_activities: Optional[List[str]] = None
+ self.profiler_target_forward_ct: Optional[int] = None

  # Init metrics stats
  self.stats = SchedulerStats()
@@ -410,11 +439,6 @@ class Scheduler:
  },
  )

- # The largest prefill length of a single request
- self._largest_prefill_len: int = 0
- # The largest context length (prefill + generation) of a single request
- self._largest_prefill_decode_len: int = 0
-

  # Init request dispatcher
  self._request_dispatcher = TypeBasedDispatcher(
@@ -422,6 +446,8 @@ class Scheduler:
  (TokenizedEmbeddingReqInput, self.handle_embedding_request),
  (FlushCacheReq, self.flush_cache_wrapped),
  (AbortReq, self.abort_request),
+ (OpenSessionReqInput, self.open_session),
+ (CloseSessionReqInput, self.close_session),
  (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
  (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
  (
@@ -430,22 +456,15 @@ class Scheduler:
  ),
  (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
  (GetWeightsByNameReqInput, self.get_weights_by_name),
+ (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
+ (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
  (ProfileReq, self.profile),
- (OpenSessionReqInput, self.open_session),
- (CloseSessionReqInput, self.close_session),
- (
- ReleaseMemoryOccupationReqInput,
- lambda _: self.release_memory_occupation(),
- ),
- (
- ResumeMemoryOccupationReqInput,
- lambda _: self.resume_memory_occupation(),
- ),
+ (GetInternalStateReq, self.get_internal_state),
  ]
  )

  def watchdog_thread(self):
- """A watch dog thread that will try to kill the server itself if one batch takes too long."""
+ """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
  self.watchdog_last_forward_ct = 0
  self.watchdog_last_time = time.time()

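TypeBasedDispatcher itself lives in sglang/utils.py and is not part of this diff; the hunks above only reorder and extend its (request type, handler) table. A minimal sketch of the pattern, not the sglang implementation:

from typing import Any, Callable, List, Tuple


class TypeBasedDispatcher:
    def __init__(self, mapping: List[Tuple[type, Callable]]):
        self._mapping = mapping  # ordered list of (request type, handler) pairs

    def __call__(self, obj: Any):
        # Route each incoming request object to the handler registered
        # for the first matching type.
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"Invalid object: {obj}")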
@@ -460,7 +479,18 @@ class Scheduler:
  self.watchdog_last_forward_ct = self.forward_ct
  self.watchdog_last_time = current
  time.sleep(self.watchdog_timeout // 2)
- # Wait sometimes so that the parent process can print the error.
+
+ # Print batch size and memory pool info to check whether there are de-sync issues.
+ logger.error(
+ f"{self.cur_batch.batch_size()=}, "
+ f"{self.cur_batch.reqs=}, "
+ f"{self.token_to_kv_pool.available_size()=}, "
+ f"{self.tree_cache.evictable_size()=}, "
+ )
+ # Wait for some time so that the parent process can print the error.
+ pyspy_dump_schedulers()
+ print(file=sys.stderr, flush=True)
+ print(file=sys.stdout, flush=True)
  time.sleep(5)
  self.parent_process.send_signal(signal.SIGQUIT)

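The hunk above only adds diagnostics (a logger.error state dump and a py-spy trigger) before the existing SIGQUIT path. A rough sketch of the overall watchdog shape, simplified and with the scheduler interface assumed:

import signal
import time


def watchdog_loop(scheduler, timeout: float) -> None:
    last_ct, last_time = 0, time.time()
    while True:
        if scheduler.cur_batch is not None:
            if last_ct == scheduler.forward_ct:
                if time.time() > last_time + timeout:
                    break  # the same forward batch has been running for too long
            else:
                last_ct, last_time = scheduler.forward_ct, time.time()
        time.sleep(timeout // 2)
    # New in this release: dump batch/memory-pool state and py-spy stacks here,
    # then give the parent process a moment to print before SIGQUIT.
    time.sleep(5)
    scheduler.parent_process.send_signal(signal.SIGQUIT)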
@@ -577,6 +607,13 @@ class Scheduler:

  def process_input_requests(self, recv_reqs: List):
  for recv_req in recv_reqs:
+ # If it is a health check generation request and there are running requests, ignore it.
+ if is_health_check_generate_req(recv_req) and (
+ self.chunked_req is not None or self.running_batch is not None
+ ):
+ self.return_health_check_ct += 1
+ continue
+
  output = self._request_dispatcher(recv_req)
  if output is not None:
  self.send_to_tokenizer.send_pyobj(output)
@@ -591,7 +628,6 @@ class Scheduler:
  or recv_req.session_params.id is None
  or recv_req.session_params.id not in self.sessions
  ):
-
  if recv_req.input_embeds is not None:
  # Generate fake input_ids based on the length of input_embeds
  seq_length = len(recv_req.input_embeds)
@@ -618,10 +654,12 @@ class Scheduler:
  recv_req.sampling_params,
  return_logprob=recv_req.return_logprob,
  top_logprobs_num=recv_req.top_logprobs_num,
+ token_ids_logprob=recv_req.token_ids_logprob,
  stream=recv_req.stream,
  lora_path=recv_req.lora_path,
  input_embeds=recv_req.input_embeds,
  custom_logit_processor=custom_logit_processor,
+ return_hidden_states=recv_req.return_hidden_states,
  eos_token_ids=self.model_config.hf_eos_token_id,
  )
  req.tokenizer = self.tokenizer
@@ -633,14 +671,14 @@ class Scheduler:
  req.finished_reason = FINISH_ABORT(
  f"Invalid request: session id {recv_req.session_params.id} does not exist"
  )
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
  return
  else:
  # Create a new request from a previous session
  session = self.sessions[recv_req.session_params.id]
  req = session.create_req(recv_req, self.tokenizer)
  if isinstance(req.finished_reason, FINISH_ABORT):
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
  return

  # Handle multimodal inputs
@@ -664,7 +702,7 @@ class Scheduler:
  req.finished_reason = FINISH_ABORT(
  error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
  )
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
  return

  # Validate prompts length
@@ -674,16 +712,28 @@ class Scheduler:
  self.server_args.allow_auto_truncate,
  )
  if error_msg:
- self.waiting_queue.append(req)
+ req.origin_input_ids = [0]
+ req.sampling_params.max_new_tokens = 0
+ self._add_request_to_queue(req)
  return

  # Copy more attributes
- if recv_req.logprob_start_len == -1:
+ if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
  # By default, only return the logprobs for output tokens
  req.logprob_start_len = len(req.origin_input_ids) - 1
  else:
  req.logprob_start_len = recv_req.logprob_start_len

+ if req.logprob_start_len >= len(req.origin_input_ids):
+ req.finished_reason = FINISH_ABORT(
+ f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
+ HTTPStatus.BAD_REQUEST,
+ "BadRequestError",
+ )
+ req.logprob_start_len = len(req.origin_input_ids) - 1
+ self._add_request_to_queue(req)
+ return
+
  req.sampling_params.max_new_tokens = min(
  (
  req.sampling_params.max_new_tokens
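The new guard above rejects out-of-range logprob_start_len values with a BadRequestError instead of letting them fail later. For a hypothetical 8-token prompt, only 0 through 7 are accepted:

# Hypothetical request with 8 input tokens.
origin_input_ids = [101, 42, 7, 9, 13, 5, 77, 102]

for logprob_start_len in (0, 7, 8):
    if logprob_start_len >= len(origin_input_ids):
        print(f"{logprob_start_len}: rejected with BadRequestError (new in this diff)")
    else:
        print(f"{logprob_start_len}: accepted")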
@@ -699,6 +749,7 @@ class Scheduler:
  req.sampling_params.json_schema is not None
  or req.sampling_params.regex is not None
  or req.sampling_params.ebnf is not None
+ or req.sampling_params.structural_tag is not None
  ):
  assert self.grammar_backend is not None
  if req.sampling_params.json_schema is not None:
@@ -707,6 +758,8 @@ class Scheduler:
  key = ("regex", req.sampling_params.regex)
  elif req.sampling_params.ebnf is not None:
  key = ("ebnf", req.sampling_params.ebnf)
+ elif req.sampling_params.structural_tag:
+ key = ("structural_tag", req.sampling_params.structural_tag)

  req.grammar = self.grammar_backend.get_cached_value(key)
  if not req.grammar:
@@ -716,7 +769,13 @@ class Scheduler:
  if add_to_grammar_queue:
  self.grammar_queue.append(req)
  else:
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
+
+ def _add_request_to_queue(self, req: Req):
+ self.waiting_queue.append(req)
+
+ def _extend_requests_to_queue(self, reqs: List[Req]):
+ self.waiting_queue.extend(reqs)

  def handle_embedding_request(
  self,
@@ -737,61 +796,64 @@ class Scheduler:
  self.server_args.allow_auto_truncate,
  )
  if error_msg:
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
  return

  # Copy more attributes
  req.logprob_start_len = len(req.origin_input_ids) - 1
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)

  def log_prefill_stats(
  self,
  adder: PrefillAdder,
  can_run_list: List[Req],
- running_bs: ScheduleBatch,
- has_being_chunked: bool,
+ running_bs: int,
  ):
- self.tree_cache_metrics["total"] += (
- adder.log_input_tokens + adder.log_hit_tokens
- ) / 10**9
- self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
- tree_cache_hit_rate = (
- self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
- )
-
  num_used = self.max_total_num_tokens - (
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+ self.token_to_kv_pool_allocator.available_size()
+ + self.tree_cache.evictable_size()
+ )
+ self._largest_prefill_len = max(
+ self._largest_prefill_len, adder.log_input_tokens
  )

- logger.info(
+ f = (
  f"Prefill batch. "
  f"#new-seq: {len(can_run_list)}, "
  f"#new-token: {adder.log_input_tokens}, "
  f"#cached-token: {adder.log_hit_tokens}, "
- f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
  f"#running-req: {running_bs}, "
- f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
+ f"#queue-req: {len(self.waiting_queue)}, "
  )
+ logger.info(f)

  if self.enable_metrics:
+ cache_hit_rate = adder.log_hit_tokens / (
+ adder.log_input_tokens + adder.log_hit_tokens
+ )
  self.stats.num_running_reqs = running_bs
  self.stats.num_used_tokens = num_used
  self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
- self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
- self.stats.cache_hit_rate = tree_cache_hit_rate
+ self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.stats.cache_hit_rate = cache_hit_rate
  self.metrics_collector.log_stats(self.stats)

  def log_decode_stats(self):
- num_used = self.max_total_num_tokens - (
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
- )
- gen_throughput = self.num_generated_tokens / (
- time.time() - self.last_decode_stats_tic
- )
- self.num_generated_tokens = 0
+ gap_latency = time.time() - self.last_decode_stats_tic
  self.last_decode_stats_tic = time.time()
+ self.last_gen_throughput = self.num_generated_tokens / gap_latency
+ self.num_generated_tokens = 0
  num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+ num_used = self.max_total_num_tokens - (
+ self.token_to_kv_pool_allocator.available_size()
+ + self.tree_cache.evictable_size()
+ )
+
+ if RECORD_STEP_TIME:
+ self.step_time_dict[num_running_reqs].append(
+ gap_latency / self.server_args.decode_log_interval
+ )

  if self.spec_algorithm.is_none():
  msg = (
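Two small formula changes in the hunk above are easy to check by hand: the prefill cache hit rate is now computed per batch rather than from cumulative counters, and decode throughput comes from the gap between two stats ticks. With assumed numbers (not measurements from this release):

log_input_tokens, log_hit_tokens = 3_000, 1_000
cache_hit_rate = log_hit_tokens / (log_input_tokens + log_hit_tokens)
print(f"cache hit rate: {cache_hit_rate:.2%}")  # 25.00% for this batch

num_generated_tokens, gap_latency = 4_096, 2.0  # tokens, seconds
decode_log_interval = 40                        # decode steps per log line
last_gen_throughput = num_generated_tokens / gap_latency   # 2048.0 token/s
step_time = gap_latency / decode_log_interval              # 0.05 s per step
print(f"gen throughput: {last_gen_throughput:.2f} token/s, step time: {step_time:.3f} s")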
@@ -799,14 +861,17 @@ class Scheduler:
  f"#running-req: {num_running_reqs}, "
  f"#token: {num_used}, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
- f"gen throughput (token/s): {gen_throughput:.2f}, "
- f"#queue-req: {len(self.waiting_queue)}"
+ f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+ f"largest-len: {self._largest_prefill_decode_len}, "
+ f"#queue-req: {len(self.waiting_queue)}, "
  )
  spec_accept_length = 0
  else:
  spec_accept_length = (
  self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
  )
+ self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
+ self.cum_spec_accept_count += self.spec_num_total_forward_ct
  self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
  msg = (
  f"Decode batch. "
@@ -814,8 +879,9 @@ class Scheduler:
  f"#token: {num_used}, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
  f"accept len: {spec_accept_length:.2f}, "
- f"gen throughput (token/s): {gen_throughput:.2f}, "
- f"#queue-req: {len(self.waiting_queue)}"
+ f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+ f"largest-len: {self._largest_prefill_decode_len}, "
+ f"#queue-req: {len(self.waiting_queue)}, "
  )

  logger.info(msg)
@@ -823,14 +889,16 @@ class Scheduler:
  self.stats.num_running_reqs = num_running_reqs
  self.stats.num_used_tokens = num_used
  self.stats.token_usage = num_used / self.max_total_num_tokens
- self.stats.gen_throughput = gen_throughput
+ self.stats.cache_hit_rate = 0.0
+ self.stats.gen_throughput = self.last_gen_throughput
  self.stats.num_queue_reqs = len(self.waiting_queue)
  self.stats.spec_accept_length = spec_accept_length
  self.metrics_collector.log_stats(self.stats)

  def check_memory(self):
  available_size = (
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+ self.token_to_kv_pool_allocator.available_size()
+ + self.tree_cache.evictable_size()
  )
  protected_size = self.tree_cache.protected_size()
  memory_leak = available_size != (
@@ -857,21 +925,42 @@ class Scheduler:
  if crash_on_warnings():
  raise ValueError(msg)

+ if (
+ self.enable_metrics
+ and self.attn_tp_rank == 0
+ and time.time() > self.metrics_collector.last_log_time + 30
+ ):
+ # During idle time, also collect metrics every 30 seconds.
+ num_used = self.max_total_num_tokens - (
+ self.token_to_kv_pool.available_size()
+ + self.tree_cache.evictable_size()
+ )
+ num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+ self.stats.num_running_reqs = num_running_reqs
+ self.stats.num_used_tokens = num_used
+ self.stats.token_usage = num_used / self.max_total_num_tokens
+ self.stats.gen_throughput = 0
+ self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.metrics_collector.log_stats(self.stats)
+
  def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
  # Merge the prefill batch into the running batch
  if self.last_batch and self.last_batch.forward_mode.is_extend():
- if self.being_chunked_req:
- # Move the chunked request out of the batch
- self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
- self.tree_cache.cache_unfinished_req(self.being_chunked_req)
- # being chunked request keeps its rid but will get a new req_pool_idx
- self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
+ if self.chunked_req:
+ # Move the chunked request out of the batch so that we can merge
+ # only finished requests to running_batch.
+ self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
+ self.tree_cache.cache_unfinished_req(self.chunked_req)
+ # chunked request keeps its rid but will get a new req_pool_idx
+ self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
  self.batch_is_full = False

+ self.last_batch.filter_batch()
  if not self.last_batch.is_empty():
  if self.running_batch is None:
  self.running_batch = self.last_batch
  else:
+ # merge running_batch with prefill batch
  self.running_batch.merge_batch(self.last_batch)

  new_batch = self.get_new_batch_prefill()
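The idle-time branch added above is a simple rate-limited stats flush: at most one extra metrics log every 30 seconds while nothing is being scheduled. The generic pattern in isolation, with the collector API assumed (the 30-second interval is taken from the hunk; refreshing last_log_time is assumed to happen inside log_stats):

import time


def maybe_log_idle_stats(collector, stats, interval_s: float = 30.0) -> None:
    if time.time() > collector.last_log_time + interval_s:
        stats.gen_throughput = 0  # nothing is decoding during idle periods
        collector.log_stats(stats)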
@@ -900,7 +989,7 @@ class Scheduler:
  # Handle the cases where prefill is not allowed
  if (
  self.batch_is_full or len(self.waiting_queue) == 0
- ) and self.being_chunked_req is None:
+ ) and self.chunked_req is None:
  return None

  running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -914,7 +1003,7 @@ class Scheduler:
  # Prefill policy
  adder = PrefillAdder(
  self.tree_cache,
- self.token_to_kv_pool,
+ self.token_to_kv_pool_allocator,
  self.running_batch,
  self.new_token_ratio,
  self.max_prefill_tokens,
@@ -922,10 +1011,10 @@ class Scheduler:
  running_bs if self.is_mixed_chunk else 0,
  )

- has_being_chunked = self.being_chunked_req is not None
- if has_being_chunked:
- self.being_chunked_req.init_next_round_input()
- self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)
+ is_chunked = self.chunked_req is not None
+ if is_chunked:
+ self.chunked_req.init_next_round_input()
+ self.chunked_req = adder.add_chunked_req(self.chunked_req)

  if self.lora_paths:
  lora_set = (
@@ -933,7 +1022,6 @@ class Scheduler:
  if self.running_batch is not None
  else set([])
  )
-
  # Get requests from the waiting queue to a new prefill batch
  for req in self.waiting_queue:
  if (
@@ -953,7 +1041,31 @@ class Scheduler:
  break

  req.init_next_round_input(None if prefix_computed else self.tree_cache)
- res = adder.add_one_req(req)
+
+ if self.enable_hierarchical_cache and req.last_node is not None:
+ if req.last_node.evicted:
+ # loading KV cache for the request
+ req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
+ req.last_node,
+ req.prefix_indices,
+ adder.rem_total_tokens,
+ )
+ if req.last_node.loading:
+ # to prevent frequent cache invalidation
+ if req.rid in self.staging_reqs:
+ self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
+ self.tree_cache.inc_lock_ref(req.last_node)
+ self.staging_reqs[req.rid] = req.last_node
+ continue
+ elif req.last_node.loading:
+ if not self.tree_cache.loading_complete(req.last_node):
+ continue
+
+ if req.rid in self.staging_reqs:
+ self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
+ del self.staging_reqs[req.rid]
+
+ res = adder.add_one_req(req, self.chunked_req)
  if res != AddReqResult.CONTINUE:
  if res == AddReqResult.NO_TOKEN:
  if self.enable_hierarchical_cache:
@@ -965,39 +1077,38 @@ class Scheduler:
  else:
  self.batch_is_full = True
  break
- if self.server_args.prefill_only_one_req:
+ if self.prefill_only_one_req:
  break

  # Update waiting queue
- can_run_list = adder.can_run_list
+ can_run_list: List[Req] = adder.can_run_list
  if len(can_run_list) == 0:
  return None
  self.waiting_queue = [
  x for x in self.waiting_queue if x not in set(can_run_list)
  ]

- if adder.new_being_chunked_req is not None:
- assert self.being_chunked_req is None
- self.being_chunked_req = adder.new_being_chunked_req
+ if adder.new_chunked_req is not None:
+ assert self.chunked_req is None
+ self.chunked_req = adder.new_chunked_req

- if self.being_chunked_req:
- self.being_chunked_req.is_being_chunked += 1
+ if self.chunked_req:
+ self.chunked_req.is_chunked += 1

  # Print stats
  if self.attn_tp_rank == 0:
- self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
+ self.log_prefill_stats(adder, can_run_list, running_bs)

  # Create a new batch
  new_batch = ScheduleBatch.init_new(
  can_run_list,
  self.req_to_token_pool,
- self.token_to_kv_pool,
+ self.token_to_kv_pool_allocator,
  self.tree_cache,
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
  self.server_args.enable_custom_logit_processor,
- self.server_args.return_hidden_states,
  )
  new_batch.prepare_for_extend()

@@ -1021,8 +1132,6 @@ class Scheduler:

  def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
  """Update the current running decoding batch."""
- global test_retract
-
  initial_bs = batch.batch_size()

  batch.filter_batch()
@@ -1032,35 +1141,25 @@ class Scheduler:

  # Check if decode out of memory
  if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
- test_retract and batch.batch_size() > 10
+ TEST_RETRACT and batch.batch_size() > 10
  ):
  old_ratio = self.new_token_ratio

- retracted_reqs, new_token_ratio = batch.retract_decode()
+ retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
  self.new_token_ratio = new_token_ratio
- if self.draft_worker:
- self.draft_worker.finish_request(retracted_reqs)

  logger.info(
  "Decode out of memory happened. "
  f"#retracted_reqs: {len(retracted_reqs)}, "
  f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
  )
- self.waiting_queue.extend(retracted_reqs)
+ self._extend_requests_to_queue(retracted_reqs)
  else:
  self.new_token_ratio = max(
  self.new_token_ratio - self.new_token_ratio_decay,
  self.min_new_token_ratio,
  )

- # Check for jump-forward
- if not self.disable_jump_forward and batch.has_grammar:
- jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
- self.waiting_queue.extend(jump_forward_reqs)
- if batch.is_empty():
- self.batch_is_full = False
- return None
-
  if batch.batch_size() < initial_bs:
  self.batch_is_full = False

@@ -1074,17 +1173,25 @@ class Scheduler:
  """Run a batch."""
  self.forward_ct += 1

+ # Check profiler
+ if (
+ self.profiler_target_forward_ct
+ and self.profiler_target_forward_ct <= self.forward_ct
+ ):
+ self.stop_profile()
+
  if self.is_generation:
  if self.spec_algorithm.is_none():
  model_worker_batch = batch.get_model_worker_batch()
  logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
  model_worker_batch
  )
+ bid = model_worker_batch.bid
  else:
  (
  logits_output,
  next_token_ids,
- model_worker_batch,
+ bid,
  num_accepted_tokens,
  ) = self.draft_worker.forward_batch_speculative_generation(batch)
  self.spec_num_total_accepted_tokens += (
@@ -1093,11 +1200,24 @@ class Scheduler:
  self.spec_num_total_forward_ct += batch.batch_size()
  self.num_generated_tokens += num_accepted_tokens
  batch.output_ids = next_token_ids
+ # These 2 values are needed for processing the output, but the values can be
+ # modified by overlap schedule. So we have to copy them here so that
+ # we can use the correct values in output processing.
+ if batch.return_logprob:
+ extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+ extend_logprob_start_len_per_req = [
+ req.extend_logprob_start_len for req in batch.reqs
+ ]
+ else:
+ extend_input_len_per_req = None
+ extend_logprob_start_len_per_req = None

  ret = GenerationBatchResult(
  logits_output=logits_output,
  next_token_ids=next_token_ids,
- bid=model_worker_batch.bid,
+ extend_input_len_per_req=extend_input_len_per_req,
+ extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+ bid=bid,
  )
  else: # embedding or reward model
  model_worker_batch = batch.get_model_worker_batch()
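The copy of extend_input_len / extend_logprob_start_len above exists because, under the overlap scheduler, those per-request fields can be mutated before output processing runs, so the batch result now carries an immutable snapshot instead. An illustration of the idea in isolation (hypothetical names, not sglang API):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BatchResultSnapshot:
    next_token_ids: List[int]
    extend_input_len_per_req: Optional[List[int]]
    extend_logprob_start_len_per_req: Optional[List[int]]


def snapshot(batch) -> BatchResultSnapshot:
    # Copy the values *now*; the overlap schedule may advance the requests
    # (and overwrite these fields) before process_batch_result_prefill runs.
    need_logprob = batch.return_logprob
    return BatchResultSnapshot(
        next_token_ids=list(batch.output_ids),
        extend_input_len_per_req=(
            [req.extend_input_len for req in batch.reqs] if need_logprob else None
        ),
        extend_logprob_start_len_per_req=(
            [req.extend_logprob_start_len for req in batch.reqs] if need_logprob else None
        ),
    )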
@@ -1113,6 +1233,7 @@ class Scheduler:
  result: Union[GenerationBatchResult, EmbeddingBatchResult],
  ):
  if batch.forward_mode.is_decode():
+ assert isinstance(result, GenerationBatchResult)
  self.process_batch_result_decode(batch, result)
  if batch.is_empty():
  self.running_batch = None
@@ -1121,11 +1242,22 @@ class Scheduler:
  elif batch.forward_mode.is_idle():
  if self.enable_overlap:
  self.tp_worker.resolve_batch_result(result.bid)
+ if batch.next_batch_sampling_info:
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
+ self.current_stream.synchronize()
+ batch.next_batch_sampling_info.sampling_info_done.set()
  elif batch.forward_mode.is_dummy_first():
  batch.next_batch_sampling_info.update_regex_vocab_mask()
  self.current_stream.synchronize()
  batch.next_batch_sampling_info.sampling_info_done.set()

+ if self.return_health_check_ct:
+ # Return some signal for the health check.
+ # This is used to prevent the health check signal being blocked by long context prefill.
+ # However, one minor issue is that this code path does not check the status of detokenizer manager.
+ self.return_health_check_ct -= 1
+ self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
+
  def process_batch_result_prefill(
  self,
  batch: ScheduleBatch,
@@ -1137,10 +1269,14 @@ class Scheduler:
  (
  logits_output,
  next_token_ids,
+ extend_input_len_per_req,
+ extend_logprob_start_len_per_req,
  bid,
  ) = (
  result.logits_output,
  result.next_token_ids,
+ result.extend_input_len_per_req,
+ result.extend_logprob_start_len_per_req,
  result.bid,
  )

@@ -1150,12 +1286,14 @@ class Scheduler:
  # Move next_token_ids and logprobs to cpu
  next_token_ids = next_token_ids.tolist()
  if batch.return_logprob:
- logits_output.next_token_logprobs = (
- logits_output.next_token_logprobs.tolist()
- )
- logits_output.input_token_logprobs = (
- logits_output.input_token_logprobs.tolist()
- )
+ if logits_output.next_token_logprobs is not None:
+ logits_output.next_token_logprobs = (
+ logits_output.next_token_logprobs.tolist()
+ )
+ if logits_output.input_token_logprobs is not None:
+ logits_output.input_token_logprobs = tuple(
+ logits_output.input_token_logprobs.tolist()
+ )

  hidden_state_offset = 0

@@ -1168,25 +1306,38 @@ class Scheduler:
  if self.is_mixed_chunk and self.enable_overlap and req.finished():
  # Free the one delayed token for the mixed decode batch
  j = len(batch.out_cache_loc) - len(batch.reqs) + i
- self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+ self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
  continue

- if req.is_being_chunked <= 0:
+ if req.is_chunked <= 0:
+ # req output_ids are set here
  req.output_ids.append(next_token_id)
  req.check_finished()

  if req.finished():
  self.tree_cache.cache_finished_req(req)
  elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+ # This updates radix so others can match
  self.tree_cache.cache_unfinished_req(req)

  if req.return_logprob:
- logprob_pt += self.add_logprob_return_values(
- i, req, logprob_pt, next_token_ids, logits_output
+ assert extend_logprob_start_len_per_req is not None
+ assert extend_input_len_per_req is not None
+ extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+ extend_input_len = extend_input_len_per_req[i]
+ num_input_logprobs = extend_input_len - extend_logprob_start_len
+ self.add_logprob_return_values(
+ i,
+ req,
+ logprob_pt,
+ next_token_ids,
+ num_input_logprobs,
+ logits_output,
  )
+ logprob_pt += num_input_logprobs

  if (
- self.server_args.return_hidden_states
+ req.return_hidden_states
  and logits_output.hidden_states is not None
  ):
  req.hidden_states.append(
@@ -1205,12 +1356,31 @@ class Scheduler:
  req.grammar.finished = req.finished()
  else:
  # being chunked reqs' prefill is not finished
- req.is_being_chunked -= 1
+ req.is_chunked -= 1
  # There is only at most one request being currently chunked.
  # Because this request does not finish prefill,
  # we don't want to stream the request currently being chunked.
  skip_stream_req = req

+ # Incrementally update input logprobs.
+ if req.return_logprob:
+ extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+ extend_input_len = extend_input_len_per_req[i]
+ if extend_logprob_start_len < extend_input_len:
+ # Update input logprobs.
+ num_input_logprobs = (
+ extend_input_len - extend_logprob_start_len
+ )
+ self.add_input_logprob_return_values(
+ i,
+ req,
+ logits_output,
+ logprob_pt,
+ num_input_logprobs,
+ last_prefill_chunk=False,
+ )
+ logprob_pt += num_input_logprobs
+
  if batch.next_batch_sampling_info:
  batch.next_batch_sampling_info.update_regex_vocab_mask()
  self.current_stream.synchronize()
@@ -1226,7 +1396,7 @@ class Scheduler:
  continue

  req.embedding = embeddings[i]
- if req.is_being_chunked <= 0:
+ if req.is_chunked <= 0:
  # Dummy output token for embedding models
  req.output_ids.append(0)
  req.check_finished()
@@ -1237,7 +1407,7 @@ class Scheduler:
  self.tree_cache.cache_unfinished_req(req)
  else:
  # being chunked reqs' prefill is not finished
- req.is_being_chunked -= 1
+ req.is_chunked -= 1

  self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

@@ -1254,23 +1424,27 @@ class Scheduler:
1254
1424
  self.num_generated_tokens += len(batch.reqs)
1255
1425
 
1256
1426
  if self.enable_overlap:
1427
+ assert batch.spec_algorithm.is_none()
1257
1428
  logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
1258
1429
  next_token_logprobs = logits_output.next_token_logprobs
1259
- else:
1430
+ elif batch.spec_algorithm.is_none():
1431
+ # spec decoding handles output logprobs inside verify process.
1260
1432
  next_token_ids = next_token_ids.tolist()
1261
1433
  if batch.return_logprob:
1262
1434
  next_token_logprobs = logits_output.next_token_logprobs.tolist()
1263
1435
 
1264
- self.token_to_kv_pool.free_group_begin()
1436
+ self.token_to_kv_pool_allocator.free_group_begin()
1265
1437
 
1266
1438
  # Check finish condition
1439
+ # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
1440
+ # We should ignore using next_token_ids for spec decoding cases.
1267
1441
  for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
1268
1442
  if req.is_retracted:
1269
1443
  continue
1270
1444
 
1271
1445
  if self.enable_overlap and req.finished():
1272
1446
  # Free the one delayed token
1273
- self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
1447
+ self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
1274
1448
  continue
1275
1449
 
1276
1450
  if batch.spec_algorithm.is_none():
@@ -1278,11 +1452,11 @@ class Scheduler:
1278
1452
  req.output_ids.append(next_token_id)
1279
1453
 
1280
1454
  req.check_finished()
1281
-
1282
1455
  if req.finished():
1283
1456
  self.tree_cache.cache_finished_req(req)
1284
1457
 
1285
- if req.return_logprob:
1458
+ if req.return_logprob and batch.spec_algorithm.is_none():
1459
+ # speculative worker handles logprob in speculative decoding
1286
1460
  req.output_token_logprobs_val.append(next_token_logprobs[i])
1287
1461
  req.output_token_logprobs_idx.append(next_token_id)
1288
1462
  if req.top_logprobs_num > 0:
@@ -1292,14 +1466,18 @@ class Scheduler:
1292
1466
  req.output_top_logprobs_idx.append(
1293
1467
  logits_output.next_token_top_logprobs_idx[i]
1294
1468
  )
1469
+ if req.token_ids_logprob is not None:
1470
+ req.output_token_ids_logprobs_val.append(
1471
+ logits_output.next_token_token_ids_logprobs_val[i]
1472
+ )
1473
+ req.output_token_ids_logprobs_idx.append(
1474
+ logits_output.next_token_token_ids_logprobs_idx[i]
1475
+ )
1295
1476
 
1296
- if (
1297
- self.server_args.return_hidden_states
1298
- and logits_output.hidden_states is not None
1299
- ):
1477
+ if req.return_hidden_states and logits_output.hidden_states is not None:
1300
1478
  req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
1301
1479
 
1302
- if req.grammar is not None:
1480
+ if req.grammar is not None and batch.spec_algorithm.is_none():
1303
1481
  req.grammar.accept_token(next_token_id)
1304
1482
  req.grammar.finished = req.finished()
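Note: the per-token bookkeeping above is now guarded by `batch.spec_algorithm.is_none()`, since the speculative worker already appends verified tokens, logprobs, and grammar state during its verify step. A minimal sketch of that guard pattern, assuming simplified stand-ins for `Req` and the spec-algorithm flag (not sglang's actual classes):

```python
# Sketch of the guard pattern above: per-token bookkeeping only runs when
# speculative decoding is off. `Req` and `SpecAlgo` are simplified stand-ins.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Req:
    output_ids: List[int] = field(default_factory=list)
    output_token_logprobs_val: List[float] = field(default_factory=list)


@dataclass
class SpecAlgo:
    name: Optional[str] = None  # None => no speculative decoding

    def is_none(self) -> bool:
        return self.name is None


def post_process(req: Req, next_token_id: int, logprob: float, spec: SpecAlgo) -> None:
    if spec.is_none():
        req.output_ids.append(next_token_id)
        req.output_token_logprobs_val.append(logprob)
    # else: the speculative worker already appended verified tokens/logprobs


r = Req()
post_process(r, 42, -0.25, SpecAlgo())       # normal decoding path
post_process(r, 7, -1.0, SpecAlgo("eagle"))  # speculative path: no-op here
assert r.output_ids == [42]
```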
1305
1483
 
@@ -1307,10 +1485,9 @@ class Scheduler:
1307
1485
  batch.next_batch_sampling_info.update_regex_vocab_mask()
1308
1486
  self.current_stream.synchronize()
1309
1487
  batch.next_batch_sampling_info.sampling_info_done.set()
1310
-
1311
1488
  self.stream_output(batch.reqs, batch.return_logprob)
1312
1489
 
1313
- self.token_to_kv_pool.free_group_end()
1490
+ self.token_to_kv_pool_allocator.free_group_end()
1314
1491
 
1315
1492
  self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
1316
1493
  if (
@@ -1319,86 +1496,167 @@ class Scheduler:
  ):
  self.log_decode_stats()

- def add_logprob_return_values(
+ def add_input_logprob_return_values(
  self,
  i: int,
  req: Req,
- pt: int,
- next_token_ids: List[int],
  output: LogitsProcessorOutput,
+ logprob_pt: int,
+ num_input_logprobs: int,
+ last_prefill_chunk: bool, # If True, it means prefill is finished.
  ):
- """Attach logprobs to the return values."""
- req.output_token_logprobs_val.append(output.next_token_logprobs[i])
- req.output_token_logprobs_idx.append(next_token_ids[i])
+ """Incrementally add input logprobs to `req`.
+
+ Args:
+ i: The request index in a batch.
+ req: The request. Input logprobs inside req are modified as a
+ consequence of the API
+ fill_ids: The prefill ids processed.
+ output: Logit processor output that's used to compute input logprobs
+ last_prefill_chunk: True if it is the last prefill (when chunked).
+ Some of input logprob operation should only happen at the last
+ prefill (e.g., computing input token logprobs).
+ """
+ assert output.input_token_logprobs is not None
+ if req.input_token_logprobs is None:
+ req.input_token_logprobs = []
+ if req.temp_input_top_logprobs_val is None:
+ req.temp_input_top_logprobs_val = []
+ if req.temp_input_top_logprobs_idx is None:
+ req.temp_input_top_logprobs_idx = []
+ if req.temp_input_token_ids_logprobs_val is None:
+ req.temp_input_token_ids_logprobs_val = []
+ if req.temp_input_token_ids_logprobs_idx is None:
+ req.temp_input_token_ids_logprobs_idx = []
+
+ if req.input_token_logprobs_val is not None:
+ # The input logprob has been already computed. It only happens
+ # upon retract.
+ if req.top_logprobs_num > 0:
+ assert req.input_token_logprobs_val is not None
+ return

- # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
- num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
+ # Important for the performance.
+ assert isinstance(output.input_token_logprobs, tuple)
+ input_token_logprobs: Tuple[int] = output.input_token_logprobs
+ input_token_logprobs = input_token_logprobs[
+ logprob_pt : logprob_pt + num_input_logprobs
+ ]
+ req.input_token_logprobs.extend(input_token_logprobs)

- if req.input_token_logprobs_val is None:
- input_token_logprobs_val = output.input_token_logprobs[
- pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
- ]
+ if req.top_logprobs_num > 0:
+ req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
+ req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])

- input_token_logprobs_idx = req.fill_ids[
- len(req.fill_ids)
- - num_input_logprobs
- + 1 : len(req.fill_ids)
- - req.last_update_decode_tokens
- ]
+ if req.token_ids_logprob is not None:
+ req.temp_input_token_ids_logprobs_val.append(
+ output.input_token_ids_logprobs_val[i]
+ )
+ req.temp_input_token_ids_logprobs_idx.append(
+ output.input_token_ids_logprobs_idx[i]
+ )
+
+ if last_prefill_chunk:
+ input_token_logprobs = req.input_token_logprobs
+ req.input_token_logprobs = None
+ assert req.input_token_logprobs_val is None
+ assert req.input_token_logprobs_idx is None
+ assert req.input_top_logprobs_val is None
+ assert req.input_top_logprobs_idx is None
+
+ # Compute input_token_logprobs_val
+ # Always pad the first one with None.
+ req.input_token_logprobs_val = [None]
+ req.input_token_logprobs_val.extend(input_token_logprobs)
+ # The last input logprob is for sampling, so just pop it out.
+ req.input_token_logprobs_val.pop()
+
+ # Compute input_token_logprobs_idx
+ input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
  # Clip the padded hash values from image tokens.
  # Otherwise, it will lead to detokenization errors.
  input_token_logprobs_idx = [
  x if x < self.model_config.vocab_size - 1 else 0
  for x in input_token_logprobs_idx
  ]
+ req.input_token_logprobs_idx = input_token_logprobs_idx

- if (
- req.logprob_start_len == 0
- ): # The first token does not have logprob, pad it.
- input_token_logprobs_val = [None] + input_token_logprobs_val
- input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
+ if req.top_logprobs_num > 0:
+ req.input_top_logprobs_val = [None]
+ req.input_top_logprobs_idx = [None]
+ assert len(req.temp_input_token_ids_logprobs_val) == len(
+ req.temp_input_token_ids_logprobs_idx
+ )
+ for val, idx in zip(
+ req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
+ ):
+ req.input_top_logprobs_val.extend(val)
+ req.input_top_logprobs_idx.extend(idx)
+
+ # Last token is a sample token.
+ req.input_top_logprobs_val.pop()
+ req.input_top_logprobs_idx.pop()
+ req.temp_input_top_logprobs_idx = None
+ req.temp_input_top_logprobs_val = None
+
+ if req.token_ids_logprob is not None:
+ req.input_token_ids_logprobs_val = [None]
+ req.input_token_ids_logprobs_idx = [None]
+
+ for val, idx in zip(
+ req.temp_input_token_ids_logprobs_val,
+ req.temp_input_token_ids_logprobs_idx,
+ strict=True,
+ ):
+ req.input_token_ids_logprobs_val.extend(val)
+ req.input_token_ids_logprobs_idx.extend(idx)

- req.input_token_logprobs_val = input_token_logprobs_val
- req.input_token_logprobs_idx = input_token_logprobs_idx
+ # Last token is a sample token.
+ req.input_token_ids_logprobs_val.pop()
+ req.input_token_ids_logprobs_idx.pop()
+ req.temp_input_token_ids_logprobs_idx = None
+ req.temp_input_token_ids_logprobs_val = None

- if req.last_update_decode_tokens != 0:
- # Some decode tokens are re-computed in an extend batch
- req.output_token_logprobs_val.extend(
- output.input_token_logprobs[
- pt
- + num_input_logprobs
- - 1
- - req.last_update_decode_tokens : pt
- + num_input_logprobs
- - 1
- ],
- )
- req.output_token_logprobs_idx.extend(
- req.fill_ids[
- len(req.fill_ids)
- - req.last_update_decode_tokens : len(req.fill_ids)
- ]
- )
+ if req.return_logprob:
+ relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
+ assert len(req.input_token_logprobs_val) == relevant_tokens_len
+ assert len(req.input_token_logprobs_idx) == relevant_tokens_len
+ if req.top_logprobs_num > 0:
+ assert len(req.input_top_logprobs_val) == relevant_tokens_len
+ assert len(req.input_top_logprobs_idx) == relevant_tokens_len
+ if req.token_ids_logprob is not None:
+ assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
+ assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

- if req.top_logprobs_num > 0:
- if req.input_top_logprobs_val is None:
- req.input_top_logprobs_val = output.input_top_logprobs_val[i]
- req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
- if req.logprob_start_len == 0:
- req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
- req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx
-
- if req.last_update_decode_tokens != 0:
- req.output_top_logprobs_val.extend(
- output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
- )
- req.output_top_logprobs_idx.extend(
- output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
- )
+ def add_logprob_return_values(
+ self,
+ i: int,
+ req: Req,
+ pt: int,
+ next_token_ids: List[int],
+ num_input_logprobs: int,
+ output: LogitsProcessorOutput,
+ ):
+ """Attach logprobs to the return values."""
+ req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+ req.output_token_logprobs_idx.append(next_token_ids[i])

+ self.add_input_logprob_return_values(
+ i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+ )
+
+ if req.top_logprobs_num > 0:
  req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
  req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

+ if req.token_ids_logprob is not None:
+ req.output_token_ids_logprobs_val.append(
+ output.next_token_token_ids_logprobs_val[i]
+ )
+ req.output_token_ids_logprobs_idx.append(
+ output.next_token_token_ids_logprobs_idx[i]
+ )
+
  return num_input_logprobs
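Note: on the last prefill chunk, the accumulated input logprobs are finalized by padding the first prompt position with `None` (it has no conditioning context) and popping the trailing entry, which belongs to the sampled next token rather than an input token, so the list length stays `len(origin_input_ids) - logprob_start_len`. A small illustrative sketch of that step (the `finalize_input_logprobs` helper is hypothetical):

```python
# Minimal sketch of the finalization step above, with an illustrative helper:
# pad the first prompt position with None and drop the trailing entry, which
# is the logprob at the sampling position, not an input-token logprob.
from typing import List, Optional


def finalize_input_logprobs(accumulated: List[float]) -> List[Optional[float]]:
    vals: List[Optional[float]] = [None]
    vals.extend(accumulated)
    vals.pop()  # last accumulated logprob is the sampling position
    return vals


# Three prompt-token logprobs accumulated across chunks:
print(finalize_input_logprobs([-0.4, -0.9, -0.1]))  # [None, -0.4, -0.9]
```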

  def stream_output(
@@ -1409,7 +1667,6 @@ class Scheduler:
  finished_reasons: List[BaseFinishReason] = []

  if self.is_generation:
- vids = []
  decoded_texts = []
  decode_ids_list = []
  read_offsets = []
@@ -1422,7 +1679,7 @@ class Scheduler:
  completion_tokens = []
  cached_tokens = []
  spec_verify_ct = []
- hidden_states = []
+ output_hidden_states = None

  if return_logprob:
  input_token_logprobs_val = []
@@ -1433,33 +1690,46 @@ class Scheduler:
  input_top_logprobs_idx = []
  output_top_logprobs_val = []
  output_top_logprobs_idx = []
+ input_token_ids_logprobs_val = []
+ input_token_ids_logprobs_idx = []
+ output_token_ids_logprobs_val = []
+ output_token_ids_logprobs_idx = []
  else:
  input_token_logprobs_val = input_token_logprobs_idx = (
  output_token_logprobs_val
  ) = output_token_logprobs_idx = input_top_logprobs_val = (
  input_top_logprobs_idx
- ) = output_top_logprobs_val = output_top_logprobs_idx = None
+ ) = output_top_logprobs_val = output_top_logprobs_idx = (
+ input_token_ids_logprobs_val
+ ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
+ output_token_ids_logprobs_idx
+ ) = None

  for req in reqs:
  if req is skip_req:
  continue

- # TODO(lianmin): revisit this for overlap + retract + stream
+ # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
+ if self.model_config.is_multimodal_gen and req.to_abort:
+ continue
+
  if (
  req.finished()
  # If stream, follow the given stream_interval
  or (req.stream and len(req.output_ids) % self.stream_interval == 0)
  # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
- or (not req.stream and len(req.output_ids) % 50 == 0)
+ # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
+ # always increase one-by-one.
+ or (
+ not req.stream
+ and len(req.output_ids) % 50 == 0
+ and not self.model_config.is_multimodal_gen
+ )
  ):
- if self.draft_worker and req.finished():
- self.draft_worker.finish_request(req)
-
  rids.append(req.rid)
  finished_reasons.append(
  req.finished_reason.to_json() if req.finished_reason else None
  )
- vids.append(req.vid)
  decoded_texts.append(req.decoded_text)
  decode_ids, read_offset = req.init_incremental_detokenize()
  decode_ids_list.append(decode_ids)
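Note: with the condition above, a request is flushed to the detokenizer when it finishes, every `stream_interval` tokens when streaming, or every 50 tokens otherwise, except for multimodal generation. A minimal sketch of that rule (the `should_emit` helper and `is_multimodal_gen` flag are illustrative, not sglang APIs):

```python
# Sketch of the emission rule above; the fallback period of 50 mirrors the hunk.
def should_emit(
    finished: bool,
    stream: bool,
    num_output_ids: int,
    stream_interval: int,
    is_multimodal_gen: bool = False,
) -> bool:
    if finished:
        return True
    if stream:
        return num_output_ids % stream_interval == 0
    # Non-streaming requests still flush periodically so incremental
    # detokenization stays cheap, except for multimodal generation.
    return num_output_ids % 50 == 0 and not is_multimodal_gen


assert should_emit(False, True, 8, stream_interval=8)
assert not should_emit(False, False, 49, stream_interval=8)
```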
@@ -1488,16 +1758,32 @@ class Scheduler:
  input_top_logprobs_idx.append(req.input_top_logprobs_idx)
  output_top_logprobs_val.append(req.output_top_logprobs_val)
  output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+ input_token_ids_logprobs_val.append(
+ req.input_token_ids_logprobs_val
+ )
+ input_token_ids_logprobs_idx.append(
+ req.input_token_ids_logprobs_idx
+ )
+ output_token_ids_logprobs_val.append(
+ req.output_token_ids_logprobs_val
+ )
+ output_token_ids_logprobs_idx.append(
+ req.output_token_ids_logprobs_idx
+ )

- hidden_states.append(req.hidden_states)
+ if req.return_hidden_states:
+ if output_hidden_states is None:
+ output_hidden_states = []
+ output_hidden_states.append(req.hidden_states)

  # Send to detokenizer
  if rids:
+ if self.model_config.is_multimodal_gen:
+ raise NotImplementedError()
  self.send_to_detokenizer.send_pyobj(
  BatchTokenIDOut(
  rids,
  finished_reasons,
- vids,
  decoded_texts,
  decode_ids_list,
  read_offsets,
@@ -1517,7 +1803,11 @@ class Scheduler:
  input_top_logprobs_idx,
  output_top_logprobs_val,
  output_top_logprobs_idx,
- hidden_states,
+ input_token_ids_logprobs_val,
+ input_token_ids_logprobs_idx,
+ output_token_ids_logprobs_val,
+ output_token_ids_logprobs_idx,
+ output_hidden_states,
  )
  )
  else: # embedding or reward model
@@ -1575,13 +1865,12 @@ class Scheduler:
  idle_batch = ScheduleBatch.init_new(
  [],
  self.req_to_token_pool,
- self.token_to_kv_pool,
+ self.token_to_kv_pool_allocator,
  self.tree_cache,
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
  self.server_args.enable_custom_logit_processor,
- self.server_args.return_hidden_states,
  )
  idle_batch.prepare_for_idle()
  return idle_batch
@@ -1596,18 +1885,25 @@ class Scheduler:
  except futures._base.TimeoutError:
  break

- if self.tp_size > 1:
+ if self.server_args.enable_dp_attention:
+ tp_size = self.attn_tp_size
+ tp_group = self.attn_tp_cpu_group
+ else:
+ tp_size = self.tp_size
+ tp_group = self.tp_cpu_group
+
+ if tp_size > 1:
  # Sync across TP ranks to make sure they have the same number of ready requests
  tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
  torch.distributed.all_reduce(
- tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+ tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
  )
  num_ready_reqs_max = tensor.item()
  for i in range(num_ready_reqs, num_ready_reqs_max):
  self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
  num_ready_reqs = num_ready_reqs_max

- self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+ self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]
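Note: the grammar-queue sync above keeps ranks agreeing on how many grammar futures are ready by taking a MAX all-reduce over a CPU process group (the attention-TP group when DP attention is enabled). A single-process sketch of that pattern, assuming the gloo backend and `world_size=1` purely so it runs standalone:

```python
# Single-process sketch of the MAX all-reduce used above. The gloo backend,
# loopback address, and world_size=1 are assumptions to make it self-contained.
import torch
import torch.distributed as dist


def sync_num_ready(num_ready_local: int, group=None) -> int:
    # Every rank contributes its local count; all ranks receive the maximum.
    t = torch.tensor(num_ready_local, dtype=torch.int32)
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=group)
    return int(t.item())


if __name__ == "__main__":
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29600", rank=0, world_size=1
    )
    print(sync_num_ready(3))  # 3; with more ranks, everyone gets the max
    dist.destroy_process_group()
```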

  def flush_cache_wrapped(self, recv_req: FlushCacheReq):
@@ -1618,21 +1914,25 @@ class Scheduler:
  if len(self.waiting_queue) == 0 and (
  self.running_batch is None or len(self.running_batch.reqs) == 0
  ):
+ self.cur_batch = None
+ self.last_batch = None
  self.tree_cache.reset()
  self.tree_cache_metrics = {"total": 0, "hit": 0}
  if self.grammar_backend:
  self.grammar_backend.reset()
  self.req_to_token_pool.clear()
- self.token_to_kv_pool.clear()
+ self.token_to_kv_pool_allocator.clear()

  if not self.spec_algorithm.is_none():
  self.draft_worker.model_runner.req_to_token_pool.clear()
- self.draft_worker.model_runner.token_to_kv_pool.clear()
+ self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

  self.num_generated_tokens = 0
  self.forward_ct_decode = 0
  self.spec_num_total_accepted_tokens = 0
  self.spec_num_total_forward_ct = 0
+ self.cum_spec_accept_length = 0
+ self.cum_spec_accept_count = 0
  torch.cuda.empty_cache()
  logger.info("Cache flushed successfully!")
  if_success = True
@@ -1645,6 +1945,49 @@ class Scheduler:
  if_success = False
  return if_success

+ def get_internal_state(self, recv_req: GetInternalStateReq):
+ ret = dict(global_server_args_dict)
+ ret["last_gen_throughput"] = self.last_gen_throughput
+ if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+ ret["avg_spec_accept_length"] = (
+ self.cum_spec_accept_length / self.cum_spec_accept_count
+ )
+
+ if RECORD_STEP_TIME:
+ ret["step_time_dict"] = self.step_time_dict
+ return GetInternalStateReqOutput(
+ internal_state=ret,
+ )
+
+ def set_internal_state(self, recv_req: SetInternalStateReq):
+ server_args_dict = recv_req.server_args
+ args_allow_update = set(
+ [
+ "speculative_accept_threshold_single",
+ "speculative_accept_threshold_acc",
+ ]
+ )
+ if_success = True
+ for k, v in server_args_dict.items():
+ if k not in args_allow_update:
+ logging.warning(f"Updating {k} is not supported.")
+ if_success = False
+ break
+ if if_success:
+ if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+ avg_spec_accept_length = (
+ self.cum_spec_accept_length / self.cum_spec_accept_count
+ )
+ logger.info(f"{avg_spec_accept_length=}")
+ self.cum_spec_accept_length = self.cum_spec_accept_count = 0
+ for k, v in server_args_dict.items():
+ global_server_args_dict[k] = v
+ logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
+ return SetInternalStateReqOutput(
+ updated=True,
+ server_args=global_server_args_dict,
+ )
+
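Note: `set_internal_state` above only accepts the two speculative acceptance thresholds and rejects the whole update otherwise. A tiny sketch of that allowlist pattern (the `apply_updates` helper is illustrative, not the scheduler's API):

```python
# Sketch of the allowlist check above; keys mirror the two runtime-tunable
# speculative-decoding thresholds, everything else rejects the whole update.
from typing import Dict, Tuple

ALLOWED = {
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
}


def apply_updates(
    current: Dict[str, float], updates: Dict[str, float]
) -> Tuple[bool, Dict[str, float]]:
    if any(k not in ALLOWED for k in updates):
        return False, current  # reject the whole update
    return True, {**current, **updates}


ok, new_args = apply_updates(
    {"speculative_accept_threshold_single": 1.0},
    {"speculative_accept_threshold_single": 0.9},
)
assert ok and new_args["speculative_accept_threshold_single"] == 0.9
```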
  def abort_request(self, recv_req: AbortReq):
  # Delete requests in the waiting queue
  to_del = None
@@ -1674,7 +2017,7 @@ class Scheduler:
  assert flash_cache_success, "Cache flush failed after updating weights"
  else:
  logger.error(message)
- return UpdateWeightFromDiskReqOutput(success, message)
+ return UpdateWeightFromDiskReqOutput(success, message, 0)

  def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
  """Initialize the online model parameter update group."""
@@ -1699,8 +2042,9 @@ class Scheduler:
  success, message = self.tp_worker.update_weights_from_tensor(recv_req)
  # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
  if success:
- flash_cache_success = self.flush_cache()
- assert flash_cache_success, "Cache flush failed after updating weights"
+ if recv_req.flush_cache:
+ flash_cache_success = self.flush_cache()
+ assert flash_cache_success, "Cache flush failed after updating weights"
  else:
  logger.error(message)
  return UpdateWeightsFromTensorReqOutput(success, message)
@@ -1709,7 +2053,7 @@ class Scheduler:
  parameter = self.tp_worker.get_weights_by_name(recv_req)
  return GetWeightsByNameReqOutput(parameter)

- def release_memory_occupation(self):
+ def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
  self.stashed_model_static_state = _export_static_state(
  self.tp_worker.worker.model_runner.model
  )
@@ -1717,7 +2061,7 @@ class Scheduler:
  self.flush_cache()
  return ReleaseMemoryOccupationReqOutput()

- def resume_memory_occupation(self):
+ def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
  self.memory_saver_adapter.resume()
  _import_static_state(
  self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1726,24 +2070,96 @@ class Scheduler:
  return ResumeMemoryOccupationReqOutput()

  def profile(self, recv_req: ProfileReq):
- if recv_req == ProfileReq.START_PROFILE:
- self.start_profile()
+ if recv_req.type == ProfileReqType.START_PROFILE:
+ return self.start_profile(
+ recv_req.output_dir, recv_req.num_steps, recv_req.activities
+ )
  else:
- self.stop_profile()
+ return self.stop_profile()
+
+ def start_profile(
+ self,
+ output_dir: Optional[str],
+ num_steps: Optional[int],
+ activities: Optional[List[str]],
+ ) -> None:
+ if self.torch_profiler_activities:
+ return ProfileReqOutput(
+ success=False,
+ message="Profiling is already in progress. Call /stop_profile first.",
+ )
+
+ if output_dir is None:
+ output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
+ if activities is None:
+ activities = ["CPU", "GPU"]

- def start_profile(self) -> None:
- if self.profiler is None:
- raise RuntimeError("Profiler is not enabled.")
- self.profiler.start()
+ self.torch_profiler_output_dir = output_dir
+ self.torch_profiler_activities = activities
+ logger.info(
+ "Profiling starts. Traces will be saved to: %s",
+ self.torch_profiler_output_dir,
+ )
+
+ activity_map = {
+ "CPU": torch.profiler.ProfilerActivity.CPU,
+ "GPU": torch.profiler.ProfilerActivity.CUDA,
+ }
+ torchprof_activities = [
+ activity_map[a] for a in activities if a in activity_map
+ ]
+
+ if torchprof_activities:
+ self.torch_profiler = torch.profiler.profile(
+ activities=torchprof_activities,
+ with_stack=True,
+ )
+ self.torch_profiler.start()
+
+ if "MEM" in activities:
+ torch.cuda.memory._record_memory_history(max_entries=100000)
+
+ if num_steps:
+ self.profiler_target_forward_ct = self.forward_ct + num_steps
+ # The caller will be notified when reaching profiler_target_forward_ct
+ else:
+ self.profiler_target_forward_ct = None
+ return ProfileReqOutput(success=True, message="Succeeded")

  def stop_profile(self) -> None:
- if self.profiler is None:
- raise RuntimeError("Profiler is not enabled.")
- self.profiler.stop()
- self.profiler.export_chrome_trace(
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+ if self.torch_profiler_activities is None:
+ return
+
+ logger.info("Stop profiling...")
+ if self.torch_profiler is not None:
+ self.torch_profiler.stop()
+ self.torch_profiler.export_chrome_trace(
+ os.path.join(
+ self.torch_profiler_output_dir,
+ str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
+ )
+ )
+
+ if "MEM" in self.torch_profiler_activities:
+ memory_profile_path = os.path.join(
+ self.torch_profiler_trace_dir,
+ str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
+ )
+ torch.cuda.memory._dump_snapshot(memory_profile_path)
+ torch.cuda.memory._record_memory_history(enabled=None)
+
+ logger.info(
+ "Profiling done. Traces are saved to: %s",
+ self.torch_profiler_output_dir,
  )
- logger.info("Profiler is done")
+ self.torch_profiler = None
+ self.torch_profiler_output_dir = None
+ self.torch_profiler_activities = None
+
+ if self.profiler_target_forward_ct:
+ self.send_to_tokenizer.send_pyobj(
+ ProfileReqOutput(success=True, message="Succeeded.")
+ )
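Note: the new profiling path above maps "CPU"/"GPU" activity strings onto `torch.profiler.ProfilerActivity`, exports a Chrome trace named per TP rank, and handles "MEM" via CUDA memory-history snapshots. A standalone, CPU-only sketch of the same torch.profiler usage (the output path and the profiled workload are placeholders, not the scheduler's code):

```python
# Standalone sketch of the torch.profiler pattern used above: map activity
# names to ProfilerActivity, profile a region, then export a Chrome trace.
# CPU-only so it runs without a GPU; the scheduler additionally handles
# "GPU" (CUDA) and "MEM" (CUDA memory history) activities.
import os
import time

import torch

activity_map = {
    "CPU": torch.profiler.ProfilerActivity.CPU,
    "GPU": torch.profiler.ProfilerActivity.CUDA,
}
requested = ["CPU"]
activities = [activity_map[a] for a in requested if a in activity_map]

prof = torch.profiler.profile(activities=activities, with_stack=True)
prof.start()
torch.matmul(torch.randn(256, 256), torch.randn(256, 256))  # work to profile
prof.stop()

out_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
prof.export_chrome_trace(os.path.join(out_dir, f"{time.time()}.trace.json.gz"))
```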

  def open_session(self, recv_req: OpenSessionReqInput):
  # handle error
@@ -1752,7 +2168,7 @@ class Scheduler:
  logger.warning(f"session id {session_id} already exist, cannot open.")
  return OpenSessionReqOutput(session_id, False)
  elif session_id is None:
- logger.warning(f"session id is None, cannot open.")
+ logger.warning("session id is None, cannot open.")
  return OpenSessionReqOutput(session_id, False)
  else:
  self.sessions[session_id] = Session(
@@ -1769,6 +2185,10 @@ class Scheduler:
  del self.sessions[session_id]


+ def is_health_check_generate_req(recv_req):
+ return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+
+
  def _export_static_state(model):
  return dict(
  buffers=[
@@ -1791,26 +2211,28 @@ def run_scheduler_process(
  dp_rank: Optional[int],
  pipe_writer,
  ):
- setproctitle.setproctitle("sglang::scheduler")
+ # Config the process
+ # kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
+ setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
  faulthandler.enable()
+ parent_process = psutil.Process().parent()

  # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
  if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
  dp_rank = int(os.environ["SGLANG_DP_RANK"])

- # Configue the logger
+ # Configure the logger
  if dp_rank is None:
- configure_logger(server_args, prefix=f" TP{tp_rank}")
+ prefix = f" TP{tp_rank}"
  else:
- configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+ prefix = f" DP{dp_rank} TP{tp_rank}"
+ configure_logger(server_args, prefix=prefix)
  suppress_other_loggers()

  # Set cpu affinity to this gpu process
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
  set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

- parent_process = psutil.Process().parent()
-
  # Create a scheduler and run the event loop
  try:
  scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)