sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff shows the content changes between package versions that have been publicly released to a supported registry. The information is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
@@ -17,10 +17,11 @@ import faulthandler
  import logging
  import os
  import signal
+ import sys
  import threading
  import time
  import warnings
- from collections import deque
+ from collections import defaultdict, deque
  from concurrent import futures
  from dataclasses import dataclass
  from http import HTTPStatus
@@ -44,17 +45,24 @@ from sglang.srt.managers.io_struct import (
  BatchTokenIDOut,
  CloseSessionReqInput,
  FlushCacheReq,
+ GetInternalStateReq,
+ GetInternalStateReqOutput,
  GetWeightsByNameReqInput,
  GetWeightsByNameReqOutput,
+ HealthCheckOutput,
  InitWeightsUpdateGroupReqInput,
  InitWeightsUpdateGroupReqOutput,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
+ ProfileReqOutput,
+ ProfileReqType,
  ReleaseMemoryOccupationReqInput,
  ReleaseMemoryOccupationReqOutput,
  ResumeMemoryOccupationReqInput,
  ResumeMemoryOccupationReqOutput,
+ SetInternalStateReq,
+ SetInternalStateReqOutput,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
  UpdateWeightFromDiskReqInput,
@@ -82,6 +90,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
  from sglang.srt.managers.utils import validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
+ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
@@ -94,6 +103,7 @@ from sglang.srt.utils import (
  crash_on_warnings,
  get_bool_env_var,
  get_zmq_socket,
+ pyspy_dump_schedulers,
  set_gpu_proc_affinity,
  set_random_seed,
  suppress_other_loggers,
@@ -103,13 +113,16 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
  logger = logging.getLogger(__name__)

  # Test retract decode for debugging purposes
- test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
+ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
+ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")


  @dataclass
  class GenerationBatchResult:
  logits_output: LogitsProcessorOutput
  next_token_ids: List[int]
+ extend_input_len_per_req: List[int]
+ extend_logprob_start_len_per_req: List[int]
  bid: int


@@ -135,20 +148,16 @@ class Scheduler:
  self.tp_rank = tp_rank
  self.tp_size = server_args.tp_size
  self.schedule_policy = server_args.schedule_policy
- self.disable_jump_forward = server_args.disable_jump_forward
  self.lora_paths = server_args.lora_paths
  self.max_loras_per_batch = server_args.max_loras_per_batch
  self.enable_overlap = not server_args.disable_overlap_schedule
  self.skip_tokenizer_init = server_args.skip_tokenizer_init
  self.enable_metrics = server_args.enable_metrics
+ self.stream_interval = server_args.stream_interval
  self.spec_algorithm = SpeculativeAlgorithm.from_string(
  server_args.speculative_algorithm
  )
- self.decode_mem_cache_buf_multiplier = (
- self.server_args.speculative_num_draft_tokens
- if not self.spec_algorithm.is_none()
- else 1
- )
+ self.gpu_id = gpu_id
  self.enable_hierarchical_cache = server_args.enable_hierarchical_cache

  # Distributed rank info
@@ -188,49 +197,16 @@ class Scheduler:
  self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

  # Init tokenizer
- self.model_config = ModelConfig(
- server_args.model_path,
- trust_remote_code=server_args.trust_remote_code,
- revision=server_args.revision,
- context_length=server_args.context_length,
- model_override_args=server_args.json_model_override_args,
- is_embedding=server_args.is_embedding,
- dtype=server_args.dtype,
- quantization=server_args.quantization,
- )
- self.is_generation = self.model_config.is_generation
-
- if server_args.skip_tokenizer_init:
- self.tokenizer = self.processor = None
- else:
- if self.model_config.is_multimodal:
- self.processor = get_processor(
- server_args.tokenizer_path,
- tokenizer_mode=server_args.tokenizer_mode,
- trust_remote_code=server_args.trust_remote_code,
- revision=server_args.revision,
- )
- self.tokenizer = self.processor.tokenizer
- else:
- self.tokenizer = get_tokenizer(
- server_args.tokenizer_path,
- tokenizer_mode=server_args.tokenizer_mode,
- trust_remote_code=server_args.trust_remote_code,
- revision=server_args.revision,
- )
+ self.init_tokenizer()

  # Check whether overlap can be enabled
  if not self.is_generation:
  self.enable_overlap = False
  logger.info("Overlap scheduler is disabled for embedding models.")
-
  if self.model_config.is_multimodal:
  self.enable_overlap = False
  logger.info("Overlap scheduler is disabled for multimodal models.")

- if self.enable_overlap:
- self.disable_jump_forward = True
-
  # Launch a tensor parallel worker
  if self.enable_overlap:
  TpWorkerClass = TpModelWorkerClient
@@ -245,7 +221,7 @@ class Scheduler:
  nccl_port=port_args.nccl_port,
  )

- # Launch a worker for speculative decoding if needed
+ # Launch a draft worker for speculative decoding
  if self.spec_algorithm.is_eagle():
  from sglang.srt.speculative.eagle_worker import EAGLEWorker

@@ -279,6 +255,7 @@ class Scheduler:
  self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
  global_server_args_dict.update(worker_global_server_args_dict)
  set_random_seed(self.random_seed)
+
  # Print debug info
  logger.info(
  f"max_total_num_tokens={self.max_total_num_tokens}, "
@@ -289,27 +266,11 @@ class Scheduler:
  )

  # Init memory pool and cache
- self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
-
- if (
- server_args.chunked_prefill_size is not None
- and server_args.disable_radix_cache
- ):
- self.tree_cache = ChunkCache(
- req_to_token_pool=self.req_to_token_pool,
- token_to_kv_pool=self.token_to_kv_pool,
- )
- else:
- self.tree_cache = RadixCache(
- req_to_token_pool=self.req_to_token_pool,
- token_to_kv_pool=self.token_to_kv_pool,
- disable=server_args.disable_radix_cache,
- )
- self.tree_cache_metrics = {"total": 0, "hit": 0}
- self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+ self.init_memory_pool_and_cache()

  # Init running status
  self.waiting_queue: List[Req] = []
+ self.staging_reqs = {}
  # The running decoding batch for continuous batching
  self.running_batch: Optional[ScheduleBatch] = None
  # The current forward batch
@@ -319,22 +280,20 @@ class Scheduler:
  self.forward_ct = 0
  self.forward_ct_decode = 0
  self.num_generated_tokens = 0
- self.spec_num_total_accepted_tokens = 0
- self.spec_num_total_forward_ct = 0
  self.last_decode_stats_tic = time.time()
- self.stream_interval = server_args.stream_interval
+ self.return_health_check_ct = 0
  self.current_stream = torch.get_device_module(self.device).current_stream()
  if self.device == "cpu":
  self.current_stream.synchronize = lambda: None # No-op for CPU

- # Session info
+ # Init session info
  self.sessions: Dict[str, Session] = {}

  # Init chunked prefill
  self.chunked_prefill_size = server_args.chunked_prefill_size
  if self.chunked_prefill_size <= 0: # -1 means disable
  self.chunked_prefill_size = None
- self.being_chunked_req = None
+ self.chunked_req = None
  self.is_mixed_chunk = (
  self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
  )
@@ -348,11 +307,11 @@ class Scheduler:
  else:
  self.grammar_backend = None

- # Init new token estimation
+ # Init schedule policy and new token estimation
+ self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
  assert (
  server_args.schedule_conservativeness >= 0
  ), "Invalid schedule_conservativeness"
-
  self.init_new_token_ratio = min(
  global_config.default_init_new_token_ratio
  * server_args.schedule_conservativeness,
@@ -368,7 +327,7 @@ class Scheduler:
  ) / global_config.default_new_token_ratio_decay_steps
  self.new_token_ratio = self.init_new_token_ratio

- # Tells whether the current running batch is full so that we can skip
+ # Tell whether the current running batch is full so that we can skip
  # the check of whether to prefill new requests.
  # This is an optimization to reduce the overhead of the prefill check.
  self.batch_is_full = False
@@ -379,41 +338,19 @@ class Scheduler:
  t.start()
  self.parent_process = psutil.Process().parent()

+ # Init memory saver
  self.memory_saver_adapter = TorchMemorySaverAdapter.create(
  enable=server_args.enable_memory_saver
  )

  # Init profiler
- if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
- self.profiler = None
- else:
- self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
- logger.info(
- "Profiling enabled. Traces will be saved to: %s",
- self.torch_profiler_trace_dir,
- )
- self.profiler = torch.profiler.profile(
- activities=[
- torch.profiler.ProfilerActivity.CPU,
- torch.profiler.ProfilerActivity.CUDA,
- ],
- with_stack=True,
- )
+ self.torch_profiler = None
+ self.torch_profiler_output_dir: Optional[str] = None
+ self.torch_profiler_activities: Optional[List[str]] = None
+ self.profiler_target_forward_ct: Optional[int] = None

  # Init metrics stats
- self.stats = SchedulerStats()
- if self.enable_metrics:
- self.metrics_collector = SchedulerMetricsCollector(
- labels={
- "model_name": self.server_args.served_model_name,
- # TODO: Add lora name/path in the future,
- },
- )
-
- # The largest prefill length of a single request
- self._largest_prefill_len: int = 0
- # The largest context length (prefill + generation) of a single request
- self._largest_prefill_decode_len: int = 0
+ self.init_metrics()

  # Init request dispatcher
  self._request_dispatcher = TypeBasedDispatcher(
@@ -422,6 +359,8 @@ class Scheduler:
  (TokenizedEmbeddingReqInput, self.handle_embedding_request),
  (FlushCacheReq, self.flush_cache_wrapped),
  (AbortReq, self.abort_request),
+ (OpenSessionReqInput, self.open_session),
+ (CloseSessionReqInput, self.close_session),
  (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
  (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
  (
@@ -430,39 +369,108 @@ class Scheduler:
  ),
  (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
  (GetWeightsByNameReqInput, self.get_weights_by_name),
+ (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
+ (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
  (ProfileReq, self.profile),
- (OpenSessionReqInput, self.open_session),
- (CloseSessionReqInput, self.close_session),
- (
- ReleaseMemoryOccupationReqInput,
- lambda _: self.release_memory_occupation(),
- ),
- (
- ResumeMemoryOccupationReqInput,
- lambda _: self.resume_memory_occupation(),
- ),
+ (GetInternalStateReq, self.get_internal_state),
+ (SetInternalStateReq, self.set_internal_state),
  ]
  )

- def watchdog_thread(self):
- """A watch dog thread that will try to kill the server itself if one batch takes too long."""
- self.watchdog_last_forward_ct = 0
- self.watchdog_last_time = time.time()
+ def init_tokenizer(self):
+ server_args = self.server_args

- while True:
- current = time.time()
- if self.cur_batch is not None:
- if self.watchdog_last_forward_ct == self.forward_ct:
- if current > self.watchdog_last_time + self.watchdog_timeout:
- logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
- break
- else:
- self.watchdog_last_forward_ct = self.forward_ct
- self.watchdog_last_time = current
- time.sleep(self.watchdog_timeout // 2)
- # Wait sometimes so that the parent process can print the error.
- time.sleep(5)
- self.parent_process.send_signal(signal.SIGQUIT)
+ self.model_config = ModelConfig(
+ server_args.model_path,
+ trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
+ context_length=server_args.context_length,
+ model_override_args=server_args.json_model_override_args,
+ is_embedding=server_args.is_embedding,
+ dtype=server_args.dtype,
+ quantization=server_args.quantization,
+ )
+ self.is_generation = self.model_config.is_generation
+
+ if server_args.skip_tokenizer_init:
+ self.tokenizer = self.processor = None
+ else:
+ if self.model_config.is_multimodal:
+ self.processor = get_processor(
+ server_args.tokenizer_path,
+ tokenizer_mode=server_args.tokenizer_mode,
+ trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
+ )
+ self.tokenizer = self.processor.tokenizer
+ else:
+ self.tokenizer = get_tokenizer(
+ server_args.tokenizer_path,
+ tokenizer_mode=server_args.tokenizer_mode,
+ trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
+ )
+
+ def init_memory_pool_and_cache(self):
+ server_args = self.server_args
+
+ self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+ self.tp_worker.get_memory_pool()
+ )
+
+ if (
+ server_args.chunked_prefill_size is not None
+ and server_args.disable_radix_cache
+ ):
+ self.tree_cache = ChunkCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ )
+ else:
+ if self.enable_hierarchical_cache:
+ self.tree_cache = HiRadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ )
+ else:
+ self.tree_cache = RadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ disable=server_args.disable_radix_cache,
+ )
+
+ self.decode_mem_cache_buf_multiplier = (
+ 1
+ if self.spec_algorithm.is_none()
+ else (
+ server_args.speculative_num_draft_tokens
+ + (
+ server_args.speculative_eagle_topk
+ * server_args.speculative_num_steps
+ )
+ )
+ )
+
+ def init_metrics(self):
+ # The largest prefill length of a single request
+ self._largest_prefill_len: int = 0
+ # The largest context length (prefill + generation) of a single request
+ self._largest_prefill_decode_len: int = 0
+ self.last_gen_throughput: float = 0.0
+ self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
+ self.spec_num_total_accepted_tokens = 0
+ self.spec_num_total_forward_ct = 0
+ self.cum_spec_accept_length = 0
+ self.cum_spec_accept_count = 0
+ self.stats = SchedulerStats()
+ if self.enable_metrics:
+ engine_type = "unified"
+ self.metrics_collector = SchedulerMetricsCollector(
+ labels={
+ "model_name": self.server_args.served_model_name,
+ "engine_type": engine_type,
+ },
+ )

  @torch.no_grad()
  def event_loop_normal(self):
@@ -577,6 +585,13 @@ class Scheduler:

  def process_input_requests(self, recv_reqs: List):
  for recv_req in recv_reqs:
+ # If it is a health check generation request and there are running requests, ignore it.
+ if is_health_check_generate_req(recv_req) and (
+ self.chunked_req is not None or self.running_batch is not None
+ ):
+ self.return_health_check_ct += 1
+ continue
+
  output = self._request_dispatcher(recv_req)
  if output is not None:
  self.send_to_tokenizer.send_pyobj(output)
@@ -591,7 +606,6 @@ class Scheduler:
  or recv_req.session_params.id is None
  or recv_req.session_params.id not in self.sessions
  ):
-
  if recv_req.input_embeds is not None:
  # Generate fake input_ids based on the length of input_embeds
  seq_length = len(recv_req.input_embeds)
@@ -618,10 +632,12 @@ class Scheduler:
  recv_req.sampling_params,
  return_logprob=recv_req.return_logprob,
  top_logprobs_num=recv_req.top_logprobs_num,
+ token_ids_logprob=recv_req.token_ids_logprob,
  stream=recv_req.stream,
  lora_path=recv_req.lora_path,
  input_embeds=recv_req.input_embeds,
  custom_logit_processor=custom_logit_processor,
+ return_hidden_states=recv_req.return_hidden_states,
  eos_token_ids=self.model_config.hf_eos_token_id,
  )
  req.tokenizer = self.tokenizer
@@ -633,14 +649,14 @@ class Scheduler:
  req.finished_reason = FINISH_ABORT(
  f"Invalid request: session id {recv_req.session_params.id} does not exist"
  )
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
  return
  else:
  # Create a new request from a previous session
  session = self.sessions[recv_req.session_params.id]
  req = session.create_req(recv_req, self.tokenizer)
  if isinstance(req.finished_reason, FINISH_ABORT):
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
  return

  # Handle multimodal inputs
@@ -664,7 +680,7 @@ class Scheduler:
  req.finished_reason = FINISH_ABORT(
  error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
  )
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
  return

  # Validate prompts length
@@ -674,16 +690,28 @@ class Scheduler:
  self.server_args.allow_auto_truncate,
  )
  if error_msg:
- self.waiting_queue.append(req)
+ req.origin_input_ids = [0]
+ req.sampling_params.max_new_tokens = 0
+ self._add_request_to_queue(req)
  return

  # Copy more attributes
- if recv_req.logprob_start_len == -1:
+ if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
  # By default, only return the logprobs for output tokens
  req.logprob_start_len = len(req.origin_input_ids) - 1
  else:
  req.logprob_start_len = recv_req.logprob_start_len

+ if req.logprob_start_len >= len(req.origin_input_ids):
+ req.finished_reason = FINISH_ABORT(
+ f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
+ HTTPStatus.BAD_REQUEST,
+ "BadRequestError",
+ )
+ req.logprob_start_len = len(req.origin_input_ids) - 1
+ self._add_request_to_queue(req)
+ return
+
  req.sampling_params.max_new_tokens = min(
  (
  req.sampling_params.max_new_tokens
@@ -699,6 +727,7 @@ class Scheduler:
  req.sampling_params.json_schema is not None
  or req.sampling_params.regex is not None
  or req.sampling_params.ebnf is not None
+ or req.sampling_params.structural_tag is not None
  ):
  assert self.grammar_backend is not None
  if req.sampling_params.json_schema is not None:
@@ -707,6 +736,8 @@ class Scheduler:
  key = ("regex", req.sampling_params.regex)
  elif req.sampling_params.ebnf is not None:
  key = ("ebnf", req.sampling_params.ebnf)
+ elif req.sampling_params.structural_tag:
+ key = ("structural_tag", req.sampling_params.structural_tag)

  req.grammar = self.grammar_backend.get_cached_value(key)
  if not req.grammar:
@@ -716,7 +747,13 @@ class Scheduler:
  if add_to_grammar_queue:
  self.grammar_queue.append(req)
  else:
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
+
+ def _add_request_to_queue(self, req: Req):
+ self.waiting_queue.append(req)
+
+ def _extend_requests_to_queue(self, reqs: List[Req]):
+ self.waiting_queue.extend(reqs)

  def handle_embedding_request(
  self,
@@ -737,61 +774,64 @@ class Scheduler:
  self.server_args.allow_auto_truncate,
  )
  if error_msg:
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)
  return

  # Copy more attributes
  req.logprob_start_len = len(req.origin_input_ids) - 1
- self.waiting_queue.append(req)
+ self._add_request_to_queue(req)

  def log_prefill_stats(
  self,
  adder: PrefillAdder,
  can_run_list: List[Req],
- running_bs: ScheduleBatch,
- has_being_chunked: bool,
+ running_bs: int,
  ):
- self.tree_cache_metrics["total"] += (
- adder.log_input_tokens + adder.log_hit_tokens
- ) / 10**9
- self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
- tree_cache_hit_rate = (
- self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
- )
-
  num_used = self.max_total_num_tokens - (
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+ self.token_to_kv_pool_allocator.available_size()
+ + self.tree_cache.evictable_size()
+ )
+ self._largest_prefill_len = max(
+ self._largest_prefill_len, adder.log_input_tokens
  )

- logger.info(
+ f = (
  f"Prefill batch. "
  f"#new-seq: {len(can_run_list)}, "
  f"#new-token: {adder.log_input_tokens}, "
  f"#cached-token: {adder.log_hit_tokens}, "
- f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
  f"#running-req: {running_bs}, "
- f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
+ f"#queue-req: {len(self.waiting_queue)}, "
  )
+ logger.info(f)

  if self.enable_metrics:
+ cache_hit_rate = adder.log_hit_tokens / (
+ adder.log_input_tokens + adder.log_hit_tokens
+ )
  self.stats.num_running_reqs = running_bs
  self.stats.num_used_tokens = num_used
  self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
- self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
- self.stats.cache_hit_rate = tree_cache_hit_rate
+ self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.stats.cache_hit_rate = cache_hit_rate
  self.metrics_collector.log_stats(self.stats)

  def log_decode_stats(self):
- num_used = self.max_total_num_tokens - (
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
- )
- gen_throughput = self.num_generated_tokens / (
- time.time() - self.last_decode_stats_tic
- )
- self.num_generated_tokens = 0
+ gap_latency = time.time() - self.last_decode_stats_tic
  self.last_decode_stats_tic = time.time()
+ self.last_gen_throughput = self.num_generated_tokens / gap_latency
+ self.num_generated_tokens = 0
  num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+ num_used = self.max_total_num_tokens - (
+ self.token_to_kv_pool_allocator.available_size()
+ + self.tree_cache.evictable_size()
+ )
+
+ if RECORD_STEP_TIME:
+ self.step_time_dict[num_running_reqs].append(
+ gap_latency / self.server_args.decode_log_interval
+ )

  if self.spec_algorithm.is_none():
  msg = (
@@ -799,14 +839,17 @@ class Scheduler:
  f"#running-req: {num_running_reqs}, "
  f"#token: {num_used}, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
- f"gen throughput (token/s): {gen_throughput:.2f}, "
- f"#queue-req: {len(self.waiting_queue)}"
+ f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+ f"largest-len: {self._largest_prefill_decode_len}, "
+ f"#queue-req: {len(self.waiting_queue)}, "
  )
  spec_accept_length = 0
  else:
  spec_accept_length = (
  self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
  )
+ self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
+ self.cum_spec_accept_count += self.spec_num_total_forward_ct
  self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
  msg = (
  f"Decode batch. "
@@ -814,8 +857,9 @@ class Scheduler:
  f"#token: {num_used}, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
  f"accept len: {spec_accept_length:.2f}, "
- f"gen throughput (token/s): {gen_throughput:.2f}, "
- f"#queue-req: {len(self.waiting_queue)}"
+ f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+ f"largest-len: {self._largest_prefill_decode_len}, "
+ f"#queue-req: {len(self.waiting_queue)}, "
  )

  logger.info(msg)
@@ -823,14 +867,16 @@ class Scheduler:
  self.stats.num_running_reqs = num_running_reqs
  self.stats.num_used_tokens = num_used
  self.stats.token_usage = num_used / self.max_total_num_tokens
- self.stats.gen_throughput = gen_throughput
+ self.stats.cache_hit_rate = 0.0
+ self.stats.gen_throughput = self.last_gen_throughput
  self.stats.num_queue_reqs = len(self.waiting_queue)
  self.stats.spec_accept_length = spec_accept_length
  self.metrics_collector.log_stats(self.stats)

  def check_memory(self):
  available_size = (
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+ self.token_to_kv_pool_allocator.available_size()
+ + self.tree_cache.evictable_size()
  )
  protected_size = self.tree_cache.protected_size()
  memory_leak = available_size != (
@@ -857,21 +903,42 @@ class Scheduler:
  if crash_on_warnings():
  raise ValueError(msg)

+ if (
+ self.enable_metrics
+ and self.attn_tp_rank == 0
+ and time.time() > self.metrics_collector.last_log_time + 30
+ ):
+ # During idle time, also collect metrics every 30 seconds.
+ num_used = self.max_total_num_tokens - (
+ self.token_to_kv_pool_allocator.available_size()
+ + self.tree_cache.evictable_size()
+ )
+ num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+ self.stats.num_running_reqs = num_running_reqs
+ self.stats.num_used_tokens = num_used
+ self.stats.token_usage = num_used / self.max_total_num_tokens
+ self.stats.gen_throughput = 0
+ self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.metrics_collector.log_stats(self.stats)
+
  def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
  # Merge the prefill batch into the running batch
  if self.last_batch and self.last_batch.forward_mode.is_extend():
- if self.being_chunked_req:
- # Move the chunked request out of the batch
- self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
- self.tree_cache.cache_unfinished_req(self.being_chunked_req)
- # being chunked request keeps its rid but will get a new req_pool_idx
- self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
+ if self.chunked_req:
+ # Move the chunked request out of the batch so that we can merge
+ # only finished requests to running_batch.
+ self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
+ self.tree_cache.cache_unfinished_req(self.chunked_req)
+ # chunked request keeps its rid but will get a new req_pool_idx
+ self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
  self.batch_is_full = False

+ self.last_batch.filter_batch()
  if not self.last_batch.is_empty():
  if self.running_batch is None:
  self.running_batch = self.last_batch
  else:
+ # merge running_batch with prefill batch
  self.running_batch.merge_batch(self.last_batch)

  new_batch = self.get_new_batch_prefill()
@@ -900,7 +967,7 @@ class Scheduler:
  # Handle the cases where prefill is not allowed
  if (
  self.batch_is_full or len(self.waiting_queue) == 0
- ) and self.being_chunked_req is None:
+ ) and self.chunked_req is None:
  return None

  running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -914,7 +981,7 @@ class Scheduler:
  # Prefill policy
  adder = PrefillAdder(
  self.tree_cache,
- self.token_to_kv_pool,
+ self.token_to_kv_pool_allocator,
  self.running_batch,
  self.new_token_ratio,
  self.max_prefill_tokens,
@@ -922,10 +989,10 @@ class Scheduler:
  running_bs if self.is_mixed_chunk else 0,
  )

- has_being_chunked = self.being_chunked_req is not None
- if has_being_chunked:
- self.being_chunked_req.init_next_round_input()
- self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)
+ is_chunked = self.chunked_req is not None
+ if is_chunked:
+ self.chunked_req.init_next_round_input()
+ self.chunked_req = adder.add_chunked_req(self.chunked_req)

  if self.lora_paths:
  lora_set = (
@@ -933,7 +1000,6 @@ class Scheduler:
  if self.running_batch is not None
  else set([])
  )
-
  # Get requests from the waiting queue to a new prefill batch
  for req in self.waiting_queue:
  if (
@@ -953,7 +1019,31 @@ class Scheduler:
  break

  req.init_next_round_input(None if prefix_computed else self.tree_cache)
- res = adder.add_one_req(req)
+
+ if self.enable_hierarchical_cache and req.last_node is not None:
+ if req.last_node.evicted:
+ # loading KV cache for the request
+ req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
+ req.last_node,
+ req.prefix_indices,
+ adder.rem_total_tokens,
+ )
+ if req.last_node.loading:
+ # to prevent frequent cache invalidation
+ if req.rid in self.staging_reqs:
+ self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
+ self.tree_cache.inc_lock_ref(req.last_node)
+ self.staging_reqs[req.rid] = req.last_node
+ continue
+ elif req.last_node.loading:
+ if not self.tree_cache.loading_complete(req.last_node):
+ continue
+
+ if req.rid in self.staging_reqs:
+ self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
+ del self.staging_reqs[req.rid]
+
+ res = adder.add_one_req(req, self.chunked_req)
  if res != AddReqResult.CONTINUE:
  if res == AddReqResult.NO_TOKEN:
  if self.enable_hierarchical_cache:
@@ -965,39 +1055,36 @@ class Scheduler:
  else:
  self.batch_is_full = True
  break
- if self.server_args.prefill_only_one_req:
- break

  # Update waiting queue
- can_run_list = adder.can_run_list
+ can_run_list: List[Req] = adder.can_run_list
  if len(can_run_list) == 0:
  return None
  self.waiting_queue = [
  x for x in self.waiting_queue if x not in set(can_run_list)
  ]

- if adder.new_being_chunked_req is not None:
- assert self.being_chunked_req is None
- self.being_chunked_req = adder.new_being_chunked_req
+ if adder.new_chunked_req is not None:
+ assert self.chunked_req is None
+ self.chunked_req = adder.new_chunked_req

- if self.being_chunked_req:
- self.being_chunked_req.is_being_chunked += 1
+ if self.chunked_req:
+ self.chunked_req.is_chunked += 1

  # Print stats
  if self.attn_tp_rank == 0:
- self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
+ self.log_prefill_stats(adder, can_run_list, running_bs)

  # Create a new batch
  new_batch = ScheduleBatch.init_new(
  can_run_list,
  self.req_to_token_pool,
- self.token_to_kv_pool,
+ self.token_to_kv_pool_allocator,
  self.tree_cache,
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
  self.server_args.enable_custom_logit_processor,
- self.server_args.return_hidden_states,
  )
  new_batch.prepare_for_extend()

@@ -1021,8 +1108,6 @@ class Scheduler:

  def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
  """Update the current running decoding batch."""
- global test_retract
-
  initial_bs = batch.batch_size()

  batch.filter_batch()
@@ -1032,35 +1117,25 @@ class Scheduler:

  # Check if decode out of memory
  if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
- test_retract and batch.batch_size() > 10
+ TEST_RETRACT and batch.batch_size() > 10
  ):
  old_ratio = self.new_token_ratio

- retracted_reqs, new_token_ratio = batch.retract_decode()
+ retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
  self.new_token_ratio = new_token_ratio
- if self.draft_worker:
- self.draft_worker.finish_request(retracted_reqs)

  logger.info(
  "Decode out of memory happened. "
  f"#retracted_reqs: {len(retracted_reqs)}, "
  f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
  )
- self.waiting_queue.extend(retracted_reqs)
+ self._extend_requests_to_queue(retracted_reqs)
  else:
  self.new_token_ratio = max(
  self.new_token_ratio - self.new_token_ratio_decay,
  self.min_new_token_ratio,
  )

- # Check for jump-forward
- if not self.disable_jump_forward and batch.has_grammar:
- jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
- self.waiting_queue.extend(jump_forward_reqs)
- if batch.is_empty():
- self.batch_is_full = False
- return None
-
  if batch.batch_size() < initial_bs:
  self.batch_is_full = False

@@ -1074,17 +1149,26 @@ class Scheduler:
  """Run a batch."""
  self.forward_ct += 1

+ # Check profiler
+ if (
+ self.profiler_target_forward_ct
+ and self.profiler_target_forward_ct <= self.forward_ct
+ ):
+ self.stop_profile()
+
+ # Run forward
  if self.is_generation:
  if self.spec_algorithm.is_none():
  model_worker_batch = batch.get_model_worker_batch()
  logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
  model_worker_batch
  )
+ bid = model_worker_batch.bid
  else:
  (
  logits_output,
  next_token_ids,
- model_worker_batch,
+ bid,
  num_accepted_tokens,
  ) = self.draft_worker.forward_batch_speculative_generation(batch)
  self.spec_num_total_accepted_tokens += (
@@ -1094,10 +1178,24 @@ class Scheduler:
  self.num_generated_tokens += num_accepted_tokens
  batch.output_ids = next_token_ids

+ # These 2 values are needed for processing the output, but the values can be
+ # modified by overlap schedule. So we have to copy them here so that
+ # we can use the correct values in output processing.
+ if batch.return_logprob:
+ extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+ extend_logprob_start_len_per_req = [
+ req.extend_logprob_start_len for req in batch.reqs
+ ]
+ else:
+ extend_input_len_per_req = None
+ extend_logprob_start_len_per_req = None
+
  ret = GenerationBatchResult(
  logits_output=logits_output,
  next_token_ids=next_token_ids,
- bid=model_worker_batch.bid,
+ extend_input_len_per_req=extend_input_len_per_req,
+ extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+ bid=bid,
  )
  else: # embedding or reward model
  model_worker_batch = batch.get_model_worker_batch()
@@ -1121,11 +1219,22 @@ class Scheduler:
  elif batch.forward_mode.is_idle():
  if self.enable_overlap:
  self.tp_worker.resolve_batch_result(result.bid)
+ if batch.next_batch_sampling_info:
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
+ self.current_stream.synchronize()
+ batch.next_batch_sampling_info.sampling_info_done.set()
  elif batch.forward_mode.is_dummy_first():
  batch.next_batch_sampling_info.update_regex_vocab_mask()
  self.current_stream.synchronize()
  batch.next_batch_sampling_info.sampling_info_done.set()

+ if self.return_health_check_ct:
+ # Return some signal for the health check.
+ # This is used to prevent the health check signal being blocked by long context prefill.
+ # However, one minor issue is that this code path does not check the status of detokenizer manager.
+ self.return_health_check_ct -= 1
+ self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
+
  def process_batch_result_prefill(
  self,
  batch: ScheduleBatch,
@@ -1137,10 +1246,14 @@ class Scheduler:
  (
  logits_output,
  next_token_ids,
+ extend_input_len_per_req,
+ extend_logprob_start_len_per_req,
  bid,
  ) = (
  result.logits_output,
  result.next_token_ids,
+ result.extend_input_len_per_req,
+ result.extend_logprob_start_len_per_req,
  result.bid,
  )

@@ -1150,12 +1263,14 @@ class Scheduler:
  # Move next_token_ids and logprobs to cpu
  next_token_ids = next_token_ids.tolist()
  if batch.return_logprob:
- logits_output.next_token_logprobs = (
- logits_output.next_token_logprobs.tolist()
- )
- logits_output.input_token_logprobs = (
- logits_output.input_token_logprobs.tolist()
- )
+ if logits_output.next_token_logprobs is not None:
+ logits_output.next_token_logprobs = (
+ logits_output.next_token_logprobs.tolist()
+ )
+ if logits_output.input_token_logprobs is not None:
+ logits_output.input_token_logprobs = tuple(
+ logits_output.input_token_logprobs.tolist()
+ )

  hidden_state_offset = 0

@@ -1168,25 +1283,38 @@ class Scheduler:
  if self.is_mixed_chunk and self.enable_overlap and req.finished():
  # Free the one delayed token for the mixed decode batch
  j = len(batch.out_cache_loc) - len(batch.reqs) + i
- self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+ self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
  continue

- if req.is_being_chunked <= 0:
+ if req.is_chunked <= 0:
+ # req output_ids are set here
  req.output_ids.append(next_token_id)
  req.check_finished()

  if req.finished():
  self.tree_cache.cache_finished_req(req)
  elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+ # This updates radix so others can match
  self.tree_cache.cache_unfinished_req(req)

  if req.return_logprob:
- logprob_pt += self.add_logprob_return_values(
- i, req, logprob_pt, next_token_ids, logits_output
+ assert extend_logprob_start_len_per_req is not None
+ assert extend_input_len_per_req is not None
+ extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+ extend_input_len = extend_input_len_per_req[i]
+ num_input_logprobs = extend_input_len - extend_logprob_start_len
+ self.add_logprob_return_values(
+ i,
+ req,
+ logprob_pt,
+ next_token_ids,
+ num_input_logprobs,
+ logits_output,
  )
+ logprob_pt += num_input_logprobs

  if (
- self.server_args.return_hidden_states
+ req.return_hidden_states
  and logits_output.hidden_states is not None
  ):
  req.hidden_states.append(
@@ -1205,12 +1333,31 @@ class Scheduler:
  req.grammar.finished = req.finished()
  else:
  # being chunked reqs' prefill is not finished
- req.is_being_chunked -= 1
+ req.is_chunked -= 1
  # There is only at most one request being currently chunked.
  # Because this request does not finish prefill,
  # we don't want to stream the request currently being chunked.
  skip_stream_req = req

+ # Incrementally update input logprobs.
+ if req.return_logprob:
+ extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+ extend_input_len = extend_input_len_per_req[i]
+ if extend_logprob_start_len < extend_input_len:
+ # Update input logprobs.
+ num_input_logprobs = (
+ extend_input_len - extend_logprob_start_len
+ )
+ self.add_input_logprob_return_values(
+ i,
+ req,
+ logits_output,
+ logprob_pt,
+ num_input_logprobs,
+ last_prefill_chunk=False,
+ )
+ logprob_pt += num_input_logprobs
+
  if batch.next_batch_sampling_info:
  batch.next_batch_sampling_info.update_regex_vocab_mask()
  self.current_stream.synchronize()
@@ -1226,7 +1373,7 @@ class Scheduler:
1226
1373
  continue
1227
1374
 
1228
1375
  req.embedding = embeddings[i]
1229
- if req.is_being_chunked <= 0:
1376
+ if req.is_chunked <= 0:
1230
1377
  # Dummy output token for embedding models
1231
1378
  req.output_ids.append(0)
1232
1379
  req.check_finished()
@@ -1237,7 +1384,7 @@ class Scheduler:
1237
1384
  self.tree_cache.cache_unfinished_req(req)
1238
1385
  else:
1239
1386
  # being chunked reqs' prefill is not finished
1240
- req.is_being_chunked -= 1
1387
+ req.is_chunked -= 1
1241
1388
 
1242
1389
  self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
1243
1390
 
@@ -1254,23 +1401,27 @@ class Scheduler:
1254
1401
  self.num_generated_tokens += len(batch.reqs)
1255
1402
 
1256
1403
  if self.enable_overlap:
1404
+ assert batch.spec_algorithm.is_none()
1257
1405
  logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
1258
1406
  next_token_logprobs = logits_output.next_token_logprobs
1259
- else:
1407
+ elif batch.spec_algorithm.is_none():
1408
+ # Speculative decoding handles output logprobs inside the verify process.
1260
1409
  next_token_ids = next_token_ids.tolist()
1261
1410
  if batch.return_logprob:
1262
1411
  next_token_logprobs = logits_output.next_token_logprobs.tolist()
1263
1412
 
1264
- self.token_to_kv_pool.free_group_begin()
1413
+ self.token_to_kv_pool_allocator.free_group_begin()
1265
1414
 
1266
1415
  # Check finish condition
1416
+ # NOTE: the lengths of reqs and next_token_ids do not match under speculative decoding.
1417
+ # So next_token_ids should not be relied on in speculative decoding cases.
1267
1418
  for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
1268
1419
  if req.is_retracted:
1269
1420
  continue
1270
1421
 
1271
1422
  if self.enable_overlap and req.finished():
1272
1423
  # Free the one delayed token
1273
- self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
1424
+ self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
1274
1425
  continue
1275
1426
 
1276
1427
  if batch.spec_algorithm.is_none():
@@ -1278,11 +1429,11 @@ class Scheduler:
1278
1429
  req.output_ids.append(next_token_id)
1279
1430
 
1280
1431
  req.check_finished()
1281
-
1282
1432
  if req.finished():
1283
1433
  self.tree_cache.cache_finished_req(req)
1284
1434
 
1285
- if req.return_logprob:
1435
+ if req.return_logprob and batch.spec_algorithm.is_none():
1436
+ # The speculative worker handles logprobs during speculative decoding.
1286
1437
  req.output_token_logprobs_val.append(next_token_logprobs[i])
1287
1438
  req.output_token_logprobs_idx.append(next_token_id)
1288
1439
  if req.top_logprobs_num > 0:
@@ -1292,14 +1443,18 @@ class Scheduler:
1292
1443
  req.output_top_logprobs_idx.append(
1293
1444
  logits_output.next_token_top_logprobs_idx[i]
1294
1445
  )
1446
+ if req.token_ids_logprob is not None:
1447
+ req.output_token_ids_logprobs_val.append(
1448
+ logits_output.next_token_token_ids_logprobs_val[i]
1449
+ )
1450
+ req.output_token_ids_logprobs_idx.append(
1451
+ logits_output.next_token_token_ids_logprobs_idx[i]
1452
+ )
1295
1453
 
1296
- if (
1297
- self.server_args.return_hidden_states
1298
- and logits_output.hidden_states is not None
1299
- ):
1454
+ if req.return_hidden_states and logits_output.hidden_states is not None:
1300
1455
  req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
1301
1456
 
1302
- if req.grammar is not None:
1457
+ if req.grammar is not None and batch.spec_algorithm.is_none():
1303
1458
  req.grammar.accept_token(next_token_id)
1304
1459
  req.grammar.finished = req.finished()
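Review note: in the decode hunks above, per-token bookkeeping (output logprobs, top-k entries, the newly added `token_ids_logprob` values, and grammar acceptance) only runs when `batch.spec_algorithm.is_none()`, because the speculative worker produces these values inside its verify step. A compact sketch of that gating is below; the record container and field names are illustrative, not sglang's classes.

```python
# Sketch of the decode-step gating above: logprob updates are skipped when
# speculative decoding is active, since the speculative worker records them
# during verification. All names here are illustrative.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ReqRecord:
    return_logprob: bool = False
    token_ids_logprob: Optional[List[int]] = None
    output_token_logprobs_val: List[float] = field(default_factory=list)
    output_token_ids_logprobs_val: List[List[float]] = field(default_factory=list)


def record_decode_step(
    req: ReqRecord,
    spec_decoding: bool,
    next_token_logprob: float,
    token_ids_logprobs: List[float],
) -> None:
    if spec_decoding or not req.return_logprob:
        # Either not requested, or handled inside the speculative verify step.
        return
    req.output_token_logprobs_val.append(next_token_logprob)
    if req.token_ids_logprob is not None:
        req.output_token_ids_logprobs_val.append(token_ids_logprobs)


req = ReqRecord(return_logprob=True, token_ids_logprob=[42, 7])
record_decode_step(req, spec_decoding=False,
                   next_token_logprob=-0.8, token_ids_logprobs=[-3.1, -5.6])
record_decode_step(req, spec_decoding=True,
                   next_token_logprob=-0.2, token_ids_logprobs=[-2.0, -4.0])
assert req.output_token_logprobs_val == [-0.8]
```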
1305
1460
 
@@ -1310,7 +1465,7 @@ class Scheduler:
1310
1465
 
1311
1466
  self.stream_output(batch.reqs, batch.return_logprob)
1312
1467
 
1313
- self.token_to_kv_pool.free_group_end()
1468
+ self.token_to_kv_pool_allocator.free_group_end()
1314
1469
 
1315
1470
  self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
1316
1471
  if (
@@ -1319,86 +1474,169 @@ class Scheduler:
1319
1474
  ):
1320
1475
  self.log_decode_stats()
1321
1476
 
1322
- def add_logprob_return_values(
1477
+ def add_input_logprob_return_values(
1323
1478
  self,
1324
1479
  i: int,
1325
1480
  req: Req,
1326
- pt: int,
1327
- next_token_ids: List[int],
1328
1481
  output: LogitsProcessorOutput,
1482
+ logprob_pt: int,
1483
+ num_input_logprobs: int,
1484
+ last_prefill_chunk: bool, # If True, it means prefill is finished.
1329
1485
  ):
1330
- """Attach logprobs to the return values."""
1331
- req.output_token_logprobs_val.append(output.next_token_logprobs[i])
1332
- req.output_token_logprobs_idx.append(next_token_ids[i])
1486
+ """Incrementally add input logprobs to `req`.
1487
+
1488
+ Args:
1489
+ i: The request index in a batch.
1490
+ req: The request. Input logprobs inside req are modified as a
1491
+ consequence of the API
1492
+ fill_ids: The prefill ids processed.
1493
+ output: Logit processor output that's used to compute input logprobs
1494
+ last_prefill_chunk: True if it is the last prefill (when chunked).
1495
+ Some of input logprob operation should only happen at the last
1496
+ prefill (e.g., computing input token logprobs).
1497
+ """
1498
+ assert output.input_token_logprobs is not None
1499
+ if req.input_token_logprobs is None:
1500
+ req.input_token_logprobs = []
1501
+ if req.temp_input_top_logprobs_val is None:
1502
+ req.temp_input_top_logprobs_val = []
1503
+ if req.temp_input_top_logprobs_idx is None:
1504
+ req.temp_input_top_logprobs_idx = []
1505
+ if req.temp_input_token_ids_logprobs_val is None:
1506
+ req.temp_input_token_ids_logprobs_val = []
1507
+ if req.temp_input_token_ids_logprobs_idx is None:
1508
+ req.temp_input_token_ids_logprobs_idx = []
1509
+
1510
+ if req.input_token_logprobs_val is not None:
1511
+ # The input logprob has been already computed. It only happens
1512
+ # upon retract.
1513
+ if req.top_logprobs_num > 0:
1514
+ assert req.input_token_logprobs_val is not None
1515
+ return
1333
1516
 
1334
- # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
1335
- num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
1517
+ # Keeping input_token_logprobs as a tuple is important for performance.
1518
+ assert isinstance(output.input_token_logprobs, tuple)
1519
+ input_token_logprobs: Tuple[int] = output.input_token_logprobs
1520
+ input_token_logprobs = input_token_logprobs[
1521
+ logprob_pt : logprob_pt + num_input_logprobs
1522
+ ]
1523
+ req.input_token_logprobs.extend(input_token_logprobs)
1336
1524
 
1337
- if req.input_token_logprobs_val is None:
1338
- input_token_logprobs_val = output.input_token_logprobs[
1339
- pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
1340
- ]
1525
+ if req.top_logprobs_num > 0:
1526
+ req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
1527
+ req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
1341
1528
 
1342
- input_token_logprobs_idx = req.fill_ids[
1343
- len(req.fill_ids)
1344
- - num_input_logprobs
1345
- + 1 : len(req.fill_ids)
1346
- - req.last_update_decode_tokens
1347
- ]
1529
+ if req.token_ids_logprob is not None:
1530
+ req.temp_input_token_ids_logprobs_val.append(
1531
+ output.input_token_ids_logprobs_val[i]
1532
+ )
1533
+ req.temp_input_token_ids_logprobs_idx.append(
1534
+ output.input_token_ids_logprobs_idx[i]
1535
+ )
1536
+
1537
+ if last_prefill_chunk:
1538
+ input_token_logprobs = req.input_token_logprobs
1539
+ req.input_token_logprobs = None
1540
+ assert req.input_token_logprobs_val is None
1541
+ assert req.input_token_logprobs_idx is None
1542
+ assert req.input_top_logprobs_val is None
1543
+ assert req.input_top_logprobs_idx is None
1544
+
1545
+ # Compute input_token_logprobs_val
1546
+ # Always pad the first one with None.
1547
+ req.input_token_logprobs_val = [None]
1548
+ req.input_token_logprobs_val.extend(input_token_logprobs)
1549
+ # The last input logprob is for sampling, so just pop it out.
1550
+ req.input_token_logprobs_val.pop()
1551
+
1552
+ # Compute input_token_logprobs_idx
1553
+ input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
1348
1554
  # Clip the padded hash values from image tokens.
1349
1555
  # Otherwise, it will lead to detokenization errors.
1350
1556
  input_token_logprobs_idx = [
1351
1557
  x if x < self.model_config.vocab_size - 1 else 0
1352
1558
  for x in input_token_logprobs_idx
1353
1559
  ]
1560
+ req.input_token_logprobs_idx = input_token_logprobs_idx
1354
1561
 
1355
- if (
1356
- req.logprob_start_len == 0
1357
- ): # The first token does not have logprob, pad it.
1358
- input_token_logprobs_val = [None] + input_token_logprobs_val
1359
- input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
1562
+ if req.top_logprobs_num > 0:
1563
+ req.input_top_logprobs_val = [None]
1564
+ req.input_top_logprobs_idx = [None]
1565
+ assert len(req.temp_input_top_logprobs_val) == len(
1566
+ req.temp_input_top_logprobs_idx
1567
+ )
1568
+ for val, idx in zip(
1569
+ req.temp_input_top_logprobs_val,
1570
+ req.temp_input_top_logprobs_idx,
1571
+ strict=True,
1572
+ ):
1573
+ req.input_top_logprobs_val.extend(val)
1574
+ req.input_top_logprobs_idx.extend(idx)
1575
+
1576
+ # Last token is a sample token.
1577
+ req.input_top_logprobs_val.pop()
1578
+ req.input_top_logprobs_idx.pop()
1579
+ req.temp_input_top_logprobs_idx = None
1580
+ req.temp_input_top_logprobs_val = None
1581
+
1582
+ if req.token_ids_logprob is not None:
1583
+ req.input_token_ids_logprobs_val = [None]
1584
+ req.input_token_ids_logprobs_idx = [None]
1585
+
1586
+ for val, idx in zip(
1587
+ req.temp_input_token_ids_logprobs_val,
1588
+ req.temp_input_token_ids_logprobs_idx,
1589
+ strict=True,
1590
+ ):
1591
+ req.input_token_ids_logprobs_val.extend(val)
1592
+ req.input_token_ids_logprobs_idx.extend(idx)
1360
1593
 
1361
- req.input_token_logprobs_val = input_token_logprobs_val
1362
- req.input_token_logprobs_idx = input_token_logprobs_idx
1594
+ # The last entry belongs to the sampled token, so drop it.
1595
+ req.input_token_ids_logprobs_val.pop()
1596
+ req.input_token_ids_logprobs_idx.pop()
1597
+ req.temp_input_token_ids_logprobs_idx = None
1598
+ req.temp_input_token_ids_logprobs_val = None
1363
1599
 
1364
- if req.last_update_decode_tokens != 0:
1365
- # Some decode tokens are re-computed in an extend batch
1366
- req.output_token_logprobs_val.extend(
1367
- output.input_token_logprobs[
1368
- pt
1369
- + num_input_logprobs
1370
- - 1
1371
- - req.last_update_decode_tokens : pt
1372
- + num_input_logprobs
1373
- - 1
1374
- ],
1375
- )
1376
- req.output_token_logprobs_idx.extend(
1377
- req.fill_ids[
1378
- len(req.fill_ids)
1379
- - req.last_update_decode_tokens : len(req.fill_ids)
1380
- ]
1381
- )
1600
+ if req.return_logprob:
1601
+ relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
1602
+ assert len(req.input_token_logprobs_val) == relevant_tokens_len
1603
+ assert len(req.input_token_logprobs_idx) == relevant_tokens_len
1604
+ if req.top_logprobs_num > 0:
1605
+ assert len(req.input_top_logprobs_val) == relevant_tokens_len
1606
+ assert len(req.input_top_logprobs_idx) == relevant_tokens_len
1607
+ if req.token_ids_logprob is not None:
1608
+ assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
1609
+ assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
1382
1610
 
1383
- if req.top_logprobs_num > 0:
1384
- if req.input_top_logprobs_val is None:
1385
- req.input_top_logprobs_val = output.input_top_logprobs_val[i]
1386
- req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
1387
- if req.logprob_start_len == 0:
1388
- req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
1389
- req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx
1390
-
1391
- if req.last_update_decode_tokens != 0:
1392
- req.output_top_logprobs_val.extend(
1393
- output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
1394
- )
1395
- req.output_top_logprobs_idx.extend(
1396
- output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
1397
- )
1611
+ def add_logprob_return_values(
1612
+ self,
1613
+ i: int,
1614
+ req: Req,
1615
+ pt: int,
1616
+ next_token_ids: List[int],
1617
+ num_input_logprobs: int,
1618
+ output: LogitsProcessorOutput,
1619
+ ):
1620
+ """Attach logprobs to the return values."""
1621
+ req.output_token_logprobs_val.append(output.next_token_logprobs[i])
1622
+ req.output_token_logprobs_idx.append(next_token_ids[i])
1623
+
1624
+ self.add_input_logprob_return_values(
1625
+ i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
1626
+ )
1398
1627
 
1628
+ if req.top_logprobs_num > 0:
1399
1629
  req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
1400
1630
  req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
1401
1631
 
1632
+ if req.token_ids_logprob is not None:
1633
+ req.output_token_ids_logprobs_val.append(
1634
+ output.next_token_token_ids_logprobs_val[i]
1635
+ )
1636
+ req.output_token_ids_logprobs_idx.append(
1637
+ output.next_token_token_ids_logprobs_idx[i]
1638
+ )
1639
+
1402
1640
  return num_input_logprobs
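Review note: the rewritten `add_input_logprob_return_values` above accumulates input logprobs chunk by chunk and only finalizes them on the last prefill chunk: it pads the first position with `None` (the first prompt token has no logprob), drops the trailing entry that belongs to sampling, and clips placeholder ids (e.g., image-token hashes) that fall outside the vocabulary. The sketch below reproduces just that finalize step under those assumptions; names and values are illustrative.

```python
# Illustrative finalize step for chunk-accumulated input logprobs, mirroring
# the pattern above: pad the first token with None, drop the last entry
# (it is the sampling logprob, not an input logprob), and clip out-of-vocab
# placeholder token ids.
from typing import List, Optional, Sequence


def finalize_input_logprobs(
    chunk_logprobs: Sequence[Sequence[float]],
    prompt_ids: Sequence[int],
    logprob_start_len: int,
    vocab_size: int,
) -> tuple:
    vals: List[Optional[float]] = [None]
    for chunk in chunk_logprobs:
        vals.extend(chunk)
    vals.pop()  # the last logprob belongs to the sampled token

    idx = [
        tok if tok < vocab_size - 1 else 0
        for tok in prompt_ids[logprob_start_len:]
    ]
    assert len(vals) == len(idx) == len(prompt_ids) - logprob_start_len
    return vals, idx


vals, idx = finalize_input_logprobs(
    chunk_logprobs=[(-0.2, -1.1), (-0.7, -0.4)],
    prompt_ids=[5, 9, 2, 7],
    logprob_start_len=0,
    vocab_size=32000,
)
assert vals == [None, -0.2, -1.1, -0.7]
assert idx == [5, 9, 2, 7]
```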
1403
1641
 
1404
1642
  def stream_output(
@@ -1409,7 +1647,6 @@ class Scheduler:
1409
1647
  finished_reasons: List[BaseFinishReason] = []
1410
1648
 
1411
1649
  if self.is_generation:
1412
- vids = []
1413
1650
  decoded_texts = []
1414
1651
  decode_ids_list = []
1415
1652
  read_offsets = []
@@ -1422,7 +1659,7 @@ class Scheduler:
1422
1659
  completion_tokens = []
1423
1660
  cached_tokens = []
1424
1661
  spec_verify_ct = []
1425
- hidden_states = []
1662
+ output_hidden_states = None
1426
1663
 
1427
1664
  if return_logprob:
1428
1665
  input_token_logprobs_val = []
@@ -1433,33 +1670,46 @@ class Scheduler:
1433
1670
  input_top_logprobs_idx = []
1434
1671
  output_top_logprobs_val = []
1435
1672
  output_top_logprobs_idx = []
1673
+ input_token_ids_logprobs_val = []
1674
+ input_token_ids_logprobs_idx = []
1675
+ output_token_ids_logprobs_val = []
1676
+ output_token_ids_logprobs_idx = []
1436
1677
  else:
1437
1678
  input_token_logprobs_val = input_token_logprobs_idx = (
1438
1679
  output_token_logprobs_val
1439
1680
  ) = output_token_logprobs_idx = input_top_logprobs_val = (
1440
1681
  input_top_logprobs_idx
1441
- ) = output_top_logprobs_val = output_top_logprobs_idx = None
1682
+ ) = output_top_logprobs_val = output_top_logprobs_idx = (
1683
+ input_token_ids_logprobs_val
1684
+ ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
1685
+ output_token_ids_logprobs_idx
1686
+ ) = None
1442
1687
 
1443
1688
  for req in reqs:
1444
1689
  if req is skip_req:
1445
1690
  continue
1446
1691
 
1447
- # TODO(lianmin): revisit this for overlap + retract + stream
1692
+ # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
1693
+ if self.model_config.is_multimodal_gen and req.to_abort:
1694
+ continue
1695
+
1448
1696
  if (
1449
1697
  req.finished()
1450
1698
  # If stream, follow the given stream_interval
1451
1699
  or (req.stream and len(req.output_ids) % self.stream_interval == 0)
1452
1700
  # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
1453
- or (not req.stream and len(req.output_ids) % 50 == 0)
1701
+ # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
1702
+ # always increase one-by-one.
1703
+ or (
1704
+ not req.stream
1705
+ and len(req.output_ids) % 50 == 0
1706
+ and not self.model_config.is_multimodal_gen
1707
+ )
1454
1708
  ):
1455
- if self.draft_worker and req.finished():
1456
- self.draft_worker.finish_request(req)
1457
-
1458
1709
  rids.append(req.rid)
1459
1710
  finished_reasons.append(
1460
1711
  req.finished_reason.to_json() if req.finished_reason else None
1461
1712
  )
1462
- vids.append(req.vid)
1463
1713
  decoded_texts.append(req.decoded_text)
1464
1714
  decode_ids, read_offset = req.init_incremental_detokenize()
1465
1715
  decode_ids_list.append(decode_ids)
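Review note: the condition rewritten above decides when a request's partial output is flushed to the detokenizer: always when the request is finished, every `stream_interval` tokens for streaming requests, and every 50 tokens otherwise, with the periodic fallback skipped for multimodal generation. A small predicate capturing that rule is sketched below (the TODO's caveat still applies: token counts do not grow one-by-one under speculative decoding); everything besides the interval and the fixed 50-token fallback is illustrative.

```python
# Sketch of the streaming decision above.
def should_emit_output(
    finished: bool,
    stream: bool,
    num_output_tokens: int,
    stream_interval: int,
    is_multimodal_gen: bool,
) -> bool:
    if finished:
        return True
    if stream:
        return num_output_tokens % stream_interval == 0
    # Non-streaming requests still flush periodically so incremental
    # detokenization keeps making progress (skipped for multimodal gen).
    return num_output_tokens % 50 == 0 and not is_multimodal_gen


assert should_emit_output(False, True, 8, 4, False) is True
assert should_emit_output(False, False, 100, 4, False) is True
assert should_emit_output(False, False, 100, 4, True) is False
```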
@@ -1488,16 +1738,32 @@ class Scheduler:
1488
1738
  input_top_logprobs_idx.append(req.input_top_logprobs_idx)
1489
1739
  output_top_logprobs_val.append(req.output_top_logprobs_val)
1490
1740
  output_top_logprobs_idx.append(req.output_top_logprobs_idx)
1741
+ input_token_ids_logprobs_val.append(
1742
+ req.input_token_ids_logprobs_val
1743
+ )
1744
+ input_token_ids_logprobs_idx.append(
1745
+ req.input_token_ids_logprobs_idx
1746
+ )
1747
+ output_token_ids_logprobs_val.append(
1748
+ req.output_token_ids_logprobs_val
1749
+ )
1750
+ output_token_ids_logprobs_idx.append(
1751
+ req.output_token_ids_logprobs_idx
1752
+ )
1491
1753
 
1492
- hidden_states.append(req.hidden_states)
1754
+ if req.return_hidden_states:
1755
+ if output_hidden_states is None:
1756
+ output_hidden_states = []
1757
+ output_hidden_states.append(req.hidden_states)
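Review note: hidden states are now gathered per request rather than unconditionally, and `output_hidden_states` stays `None` unless at least one request opted in, which keeps the detokenizer payload small in the common case. A tiny sketch of that lazy-list pattern follows; the plain-dict request objects are stand-ins.

```python
# Sketch of the lazy hidden-state collection above: the output list is only
# materialized if some request opted in via return_hidden_states.
from typing import List, Optional


def collect_hidden_states(reqs) -> Optional[List[list]]:
    output_hidden_states: Optional[List[list]] = None
    for req in reqs:
        if req.get("return_hidden_states"):
            if output_hidden_states is None:
                output_hidden_states = []
            output_hidden_states.append(req["hidden_states"])
    return output_hidden_states


assert collect_hidden_states([{"return_hidden_states": False}]) is None
assert collect_hidden_states(
    [{"return_hidden_states": True, "hidden_states": [[0.1, 0.2]]}]
) == [[[0.1, 0.2]]]
```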
1493
1758
 
1494
1759
  # Send to detokenizer
1495
1760
  if rids:
1761
+ if self.model_config.is_multimodal_gen:
1762
+ raise NotImplementedError()
1496
1763
  self.send_to_detokenizer.send_pyobj(
1497
1764
  BatchTokenIDOut(
1498
1765
  rids,
1499
1766
  finished_reasons,
1500
- vids,
1501
1767
  decoded_texts,
1502
1768
  decode_ids_list,
1503
1769
  read_offsets,
@@ -1517,20 +1783,28 @@ class Scheduler:
1517
1783
  input_top_logprobs_idx,
1518
1784
  output_top_logprobs_val,
1519
1785
  output_top_logprobs_idx,
1520
- hidden_states,
1786
+ input_token_ids_logprobs_val,
1787
+ input_token_ids_logprobs_idx,
1788
+ output_token_ids_logprobs_val,
1789
+ output_token_ids_logprobs_idx,
1790
+ output_hidden_states,
1521
1791
  )
1522
1792
  )
1523
1793
  else: # embedding or reward model
1524
1794
  embeddings = []
1525
1795
  prompt_tokens = []
1796
+ cached_tokens = []
1526
1797
  for req in reqs:
1527
1798
  if req.finished():
1528
1799
  rids.append(req.rid)
1529
1800
  finished_reasons.append(req.finished_reason.to_json())
1530
1801
  embeddings.append(req.embedding)
1531
1802
  prompt_tokens.append(len(req.origin_input_ids))
1803
+ cached_tokens.append(req.cached_tokens)
1532
1804
  self.send_to_detokenizer.send_pyobj(
1533
- BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
1805
+ BatchEmbeddingOut(
1806
+ rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
1807
+ )
1534
1808
  )
1535
1809
 
1536
1810
  def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
@@ -1575,13 +1849,12 @@ class Scheduler:
1575
1849
  idle_batch = ScheduleBatch.init_new(
1576
1850
  [],
1577
1851
  self.req_to_token_pool,
1578
- self.token_to_kv_pool,
1852
+ self.token_to_kv_pool_allocator,
1579
1853
  self.tree_cache,
1580
1854
  self.model_config,
1581
1855
  self.enable_overlap,
1582
1856
  self.spec_algorithm,
1583
1857
  self.server_args.enable_custom_logit_processor,
1584
- self.server_args.return_hidden_states,
1585
1858
  )
1586
1859
  idle_batch.prepare_for_idle()
1587
1860
  return idle_batch
@@ -1596,20 +1869,58 @@ class Scheduler:
1596
1869
  except futures._base.TimeoutError:
1597
1870
  break
1598
1871
 
1599
- if self.tp_size > 1:
1872
+ if self.server_args.enable_dp_attention:
1873
+ tp_size = self.attn_tp_size
1874
+ tp_group = self.attn_tp_cpu_group
1875
+ else:
1876
+ tp_size = self.tp_size
1877
+ tp_group = self.tp_cpu_group
1878
+
1879
+ if tp_size > 1:
1600
1880
  # Sync across TP ranks to make sure they have the same number of ready requests
1601
1881
  tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
1602
1882
  torch.distributed.all_reduce(
1603
- tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
1883
+ tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
1604
1884
  )
1605
1885
  num_ready_reqs_max = tensor.item()
1606
1886
  for i in range(num_ready_reqs, num_ready_reqs_max):
1607
1887
  self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
1608
1888
  num_ready_reqs = num_ready_reqs_max
1609
1889
 
1610
- self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
1890
+ self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
1611
1891
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]
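Review note: the grammar queue above is drained in lock-step across tensor-parallel ranks. Each rank counts how many grammar compilations are ready, all ranks take the MAX via an all_reduce on the CPU TP group (or the attention-TP group when DP attention is enabled), and slower ranks block on the remaining futures so every rank moves the same number of requests. Below is a hedged, single-process sketch of that reduction using `torch.distributed` with the gloo backend; the rank/world-size setup and port are illustrative.

```python
# Hedged sketch of the rank synchronization above: every rank reports its
# local count of ready requests and adopts the MAX, so faster ranks wait for
# the stragglers' futures rather than diverging. Runs standalone with a
# world size of 1; in sglang the group would be the TP or attention-TP
# CPU group.
import os

import torch
import torch.distributed as dist


def sync_num_ready(num_ready_local: int, group=None) -> int:
    tensor = torch.tensor(num_ready_local, dtype=torch.int32)
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX, group=group)
    return int(tensor.item())


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29533")
    dist.init_process_group("gloo", rank=0, world_size=1)
    # With a single rank the MAX is trivially the local value.
    assert sync_num_ready(3) == 3
    dist.destroy_process_group()
```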
1612
1892
 
1893
+ def watchdog_thread(self):
1894
+ """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
1895
+ self.watchdog_last_forward_ct = 0
1896
+ self.watchdog_last_time = time.time()
1897
+
1898
+ while True:
1899
+ current = time.time()
1900
+ if self.cur_batch is not None:
1901
+ if self.watchdog_last_forward_ct == self.forward_ct:
1902
+ if current > self.watchdog_last_time + self.watchdog_timeout:
1903
+ logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
1904
+ break
1905
+ else:
1906
+ self.watchdog_last_forward_ct = self.forward_ct
1907
+ self.watchdog_last_time = current
1908
+ time.sleep(self.watchdog_timeout // 2)
1909
+
1910
+ # Print batch size and memory pool info to check whether there are de-sync issues.
1911
+ logger.error(
1912
+ f"{self.cur_batch.batch_size()=}, "
1913
+ f"{self.cur_batch.reqs=}, "
1914
+ f"{self.token_to_kv_pool_allocator.available_size()=}, "
1915
+ f"{self.tree_cache.evictable_size()=}, "
1916
+ )
1917
+ # Wait for some time so that the parent process can print the error.
1918
+ pyspy_dump_schedulers()
1919
+ print(file=sys.stderr, flush=True)
1920
+ print(file=sys.stdout, flush=True)
1921
+ time.sleep(5)
1922
+ self.parent_process.send_signal(signal.SIGQUIT)
1923
+
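Review note: the new `watchdog_thread` above compares `forward_ct` against the value it saw on the previous wake-up; if a batch is in flight and the counter has not advanced within `watchdog_timeout` seconds, it logs diagnostics, dumps stacks, and sends SIGQUIT to the parent process. A simplified, self-contained version of that stall detector is sketched below; it raises a flag instead of signaling, and the attribute names merely mirror the diff.

```python
# Simplified stall detector in the spirit of watchdog_thread above: flag a
# stall (rather than SIGQUIT the parent) if the forward counter stops
# advancing while a batch is supposedly running. Timing values are
# intentionally short for demonstration.
import threading
import time


class ForwardWatchdog:
    def __init__(self, timeout_s: float) -> None:
        self.timeout_s = timeout_s
        self.forward_ct = 0          # incremented by the "scheduler"
        self.batch_in_flight = False
        self.stalled = threading.Event()

    def _run(self) -> None:
        last_ct, last_time = 0, time.time()
        while not self.stalled.is_set():
            now = time.time()
            if self.batch_in_flight:
                if self.forward_ct == last_ct:
                    if now > last_time + self.timeout_s:
                        self.stalled.set()  # real code would signal the parent
                        return
                else:
                    last_ct, last_time = self.forward_ct, now
            time.sleep(self.timeout_s / 2)

    def start(self) -> threading.Thread:
        t = threading.Thread(target=self._run, daemon=True)
        t.start()
        return t


if __name__ == "__main__":
    wd = ForwardWatchdog(timeout_s=0.2)
    wd.batch_in_flight = True
    thread = wd.start()
    thread.join(timeout=2.0)   # no forward progress -> stall detected
    assert wd.stalled.is_set()
```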
1613
1924
  def flush_cache_wrapped(self, recv_req: FlushCacheReq):
1614
1925
  self.flush_cache()
1615
1926
 
@@ -1618,21 +1929,24 @@ class Scheduler:
1618
1929
  if len(self.waiting_queue) == 0 and (
1619
1930
  self.running_batch is None or len(self.running_batch.reqs) == 0
1620
1931
  ):
1932
+ self.cur_batch = None
1933
+ self.last_batch = None
1621
1934
  self.tree_cache.reset()
1622
- self.tree_cache_metrics = {"total": 0, "hit": 0}
1623
1935
  if self.grammar_backend:
1624
1936
  self.grammar_backend.reset()
1625
1937
  self.req_to_token_pool.clear()
1626
- self.token_to_kv_pool.clear()
1938
+ self.token_to_kv_pool_allocator.clear()
1627
1939
 
1628
1940
  if not self.spec_algorithm.is_none():
1629
1941
  self.draft_worker.model_runner.req_to_token_pool.clear()
1630
- self.draft_worker.model_runner.token_to_kv_pool.clear()
1942
+ self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
1631
1943
 
1632
1944
  self.num_generated_tokens = 0
1633
1945
  self.forward_ct_decode = 0
1634
1946
  self.spec_num_total_accepted_tokens = 0
1635
1947
  self.spec_num_total_forward_ct = 0
1948
+ self.cum_spec_accept_length = 0
1949
+ self.cum_spec_accept_count = 0
1636
1950
  torch.cuda.empty_cache()
1637
1951
  logger.info("Cache flushed successfully!")
1638
1952
  if_success = True
@@ -1645,6 +1959,49 @@ class Scheduler:
1645
1959
  if_success = False
1646
1960
  return if_success
1647
1961
 
1962
+ def get_internal_state(self, recv_req: GetInternalStateReq):
1963
+ ret = dict(global_server_args_dict)
1964
+ ret["last_gen_throughput"] = self.last_gen_throughput
1965
+ if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
1966
+ ret["avg_spec_accept_length"] = (
1967
+ self.cum_spec_accept_length / self.cum_spec_accept_count
1968
+ )
1969
+
1970
+ if RECORD_STEP_TIME:
1971
+ ret["step_time_dict"] = self.step_time_dict
1972
+ return GetInternalStateReqOutput(
1973
+ internal_state=ret,
1974
+ )
1975
+
1976
+ def set_internal_state(self, recv_req: SetInternalStateReq):
1977
+ server_args_dict = recv_req.server_args
1978
+ args_allow_update = set(
1979
+ [
1980
+ "speculative_accept_threshold_single",
1981
+ "speculative_accept_threshold_acc",
1982
+ ]
1983
+ )
1984
+ if_success = True
1985
+ for k, v in server_args_dict.items():
1986
+ if k not in args_allow_update:
1987
+ logging.warning(f"Updating {k} is not supported.")
1988
+ if_success = False
1989
+ break
1990
+ if if_success:
1991
+ if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
1992
+ avg_spec_accept_length = (
1993
+ self.cum_spec_accept_length / self.cum_spec_accept_count
1994
+ )
1995
+ logger.info(f"{avg_spec_accept_length=}")
1996
+ self.cum_spec_accept_length = self.cum_spec_accept_count = 0
1997
+ for k, v in server_args_dict.items():
1998
+ global_server_args_dict[k] = v
1999
+ logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
2000
+ return SetInternalStateReqOutput(
2001
+ updated=True,
2002
+ server_args=global_server_args_dict,
2003
+ )
2004
+
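Review note: `set_internal_state` above only applies updates whose keys appear in an explicit allow-list (currently the two speculative-decoding acceptance thresholds) and rejects the request otherwise. The sketch below shows that all-or-nothing allow-list pattern with a plain dict standing in for `global_server_args_dict`; the function signature is illustrative.

```python
# All-or-nothing allow-listed update, as in set_internal_state above.
from typing import Dict, Tuple

ALLOWED_KEYS = {
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
}


def set_internal_state(
    global_args: Dict[str, object], updates: Dict[str, object]
) -> Tuple[bool, Dict[str, object]]:
    if any(k not in ALLOWED_KEYS for k in updates):
        return False, global_args  # reject the whole request
    global_args.update(updates)
    return True, global_args


args = {"speculative_accept_threshold_single": 1.0, "other": 3}
ok, args = set_internal_state(args, {"speculative_accept_threshold_single": 0.9})
assert ok and args["speculative_accept_threshold_single"] == 0.9
ok, _ = set_internal_state(args, {"other": 5})
assert not ok and args["other"] == 3
```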
1648
2005
  def abort_request(self, recv_req: AbortReq):
1649
2006
  # Delete requests in the waiting queue
1650
2007
  to_del = None
@@ -1666,6 +2023,9 @@ class Scheduler:
1666
2023
  req.to_abort = True
1667
2024
  break
1668
2025
 
2026
+ def _pause_engine(self) -> Tuple[List[Req], int]:
2027
+ raise NotImplementedError()
2028
+
1669
2029
  def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
1670
2030
  """In-place update of the weights from disk."""
1671
2031
  success, message = self.tp_worker.update_weights_from_disk(recv_req)
@@ -1674,7 +2034,7 @@ class Scheduler:
1674
2034
  assert flash_cache_success, "Cache flush failed after updating weights"
1675
2035
  else:
1676
2036
  logger.error(message)
1677
- return UpdateWeightFromDiskReqOutput(success, message)
2037
+ return UpdateWeightFromDiskReqOutput(success, message, 0)
1678
2038
 
1679
2039
  def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
1680
2040
  """Initialize the online model parameter update group."""
@@ -1699,8 +2059,9 @@ class Scheduler:
1699
2059
  success, message = self.tp_worker.update_weights_from_tensor(recv_req)
1700
2060
  # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
1701
2061
  if success:
1702
- flash_cache_success = self.flush_cache()
1703
- assert flash_cache_success, "Cache flush failed after updating weights"
2062
+ if recv_req.flush_cache:
2063
+ flash_cache_success = self.flush_cache()
2064
+ assert flash_cache_success, "Cache flush failed after updating weights"
1704
2065
  else:
1705
2066
  logger.error(message)
1706
2067
  return UpdateWeightsFromTensorReqOutput(success, message)
@@ -1709,7 +2070,7 @@ class Scheduler:
1709
2070
  parameter = self.tp_worker.get_weights_by_name(recv_req)
1710
2071
  return GetWeightsByNameReqOutput(parameter)
1711
2072
 
1712
- def release_memory_occupation(self):
2073
+ def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
1713
2074
  self.stashed_model_static_state = _export_static_state(
1714
2075
  self.tp_worker.worker.model_runner.model
1715
2076
  )
@@ -1717,7 +2078,7 @@ class Scheduler:
1717
2078
  self.flush_cache()
1718
2079
  return ReleaseMemoryOccupationReqOutput()
1719
2080
 
1720
- def resume_memory_occupation(self):
2081
+ def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
1721
2082
  self.memory_saver_adapter.resume()
1722
2083
  _import_static_state(
1723
2084
  self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1726,24 +2087,96 @@ class Scheduler:
1726
2087
  return ResumeMemoryOccupationReqOutput()
1727
2088
 
1728
2089
  def profile(self, recv_req: ProfileReq):
1729
- if recv_req == ProfileReq.START_PROFILE:
1730
- self.start_profile()
2090
+ if recv_req.type == ProfileReqType.START_PROFILE:
2091
+ return self.start_profile(
2092
+ recv_req.output_dir, recv_req.num_steps, recv_req.activities
2093
+ )
1731
2094
  else:
1732
- self.stop_profile()
2095
+ return self.stop_profile()
2096
+
2097
+ def start_profile(
2098
+ self,
2099
+ output_dir: Optional[str],
2100
+ num_steps: Optional[int],
2101
+ activities: Optional[List[str]],
2102
+ ) -> ProfileReqOutput:
2103
+ if self.torch_profiler_activities:
2104
+ return ProfileReqOutput(
2105
+ success=False,
2106
+ message="Profiling is already in progress. Call /stop_profile first.",
2107
+ )
2108
+
2109
+ if output_dir is None:
2110
+ output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
2111
+ if activities is None:
2112
+ activities = ["CPU", "GPU"]
2113
+
2114
+ self.torch_profiler_output_dir = output_dir
2115
+ self.torch_profiler_activities = activities
2116
+ logger.info(
2117
+ "Profiling starts. Traces will be saved to: %s",
2118
+ self.torch_profiler_output_dir,
2119
+ )
2120
+
2121
+ activity_map = {
2122
+ "CPU": torch.profiler.ProfilerActivity.CPU,
2123
+ "GPU": torch.profiler.ProfilerActivity.CUDA,
2124
+ }
2125
+ torchprof_activities = [
2126
+ activity_map[a] for a in activities if a in activity_map
2127
+ ]
2128
+
2129
+ if torchprof_activities:
2130
+ self.torch_profiler = torch.profiler.profile(
2131
+ activities=torchprof_activities,
2132
+ with_stack=True,
2133
+ )
2134
+ self.torch_profiler.start()
2135
+
2136
+ if "MEM" in activities:
2137
+ torch.cuda.memory._record_memory_history(max_entries=100000)
1733
2138
 
1734
- def start_profile(self) -> None:
1735
- if self.profiler is None:
1736
- raise RuntimeError("Profiler is not enabled.")
1737
- self.profiler.start()
2139
+ if num_steps:
2140
+ self.profiler_target_forward_ct = self.forward_ct + num_steps
2141
+ # The caller will be notified when reaching profiler_target_forward_ct
2142
+ else:
2143
+ self.profiler_target_forward_ct = None
2144
+ return ProfileReqOutput(success=True, message="Succeeded")
1738
2145
 
1739
2146
  def stop_profile(self) -> None:
1740
- if self.profiler is None:
1741
- raise RuntimeError("Profiler is not enabled.")
1742
- self.profiler.stop()
1743
- self.profiler.export_chrome_trace(
1744
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
2147
+ if self.torch_profiler_activities is None:
2148
+ return
2149
+
2150
+ logger.info("Stop profiling...")
2151
+ if self.torch_profiler is not None:
2152
+ self.torch_profiler.stop()
2153
+ self.torch_profiler.export_chrome_trace(
2154
+ os.path.join(
2155
+ self.torch_profiler_output_dir,
2156
+ str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
2157
+ )
2158
+ )
2159
+
2160
+ if "MEM" in self.torch_profiler_activities:
2161
+ memory_profile_path = os.path.join(
2162
+ self.torch_profiler_output_dir,
2163
+ str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
2164
+ )
2165
+ torch.cuda.memory._dump_snapshot(memory_profile_path)
2166
+ torch.cuda.memory._record_memory_history(enabled=None)
2167
+
2168
+ logger.info(
2169
+ "Profiling done. Traces are saved to: %s",
2170
+ self.torch_profiler_output_dir,
1745
2171
  )
1746
- logger.info("Profiler is done")
2172
+ self.torch_profiler = None
2173
+ self.torch_profiler_output_dir = None
2174
+ self.torch_profiler_activities = None
2175
+
2176
+ if self.profiler_target_forward_ct:
2177
+ self.send_to_tokenizer.send_pyobj(
2178
+ ProfileReqOutput(success=True, message="Succeeded.")
2179
+ )
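Review note: the profiler handlers above map the requested activity strings ("CPU", "GPU", and optionally "MEM") onto `torch.profiler` activities, keep the profiler object on the scheduler, and export a Chrome trace on stop. A condensed, standalone sketch of that start/export cycle is below; it profiles a trivial matmul and writes the trace to a temporary directory, which is an illustrative choice rather than sglang's default output location.

```python
# Condensed start/stop profiling cycle in the style of the handlers above:
# map activity strings to torch.profiler activities, profile some work, and
# export a Chrome trace. Output location and workload are illustrative.
import os
import tempfile
import time

import torch

ACTIVITY_MAP = {
    "CPU": torch.profiler.ProfilerActivity.CPU,
    "GPU": torch.profiler.ProfilerActivity.CUDA,
}


def profile_once(activities, output_dir: str) -> str:
    selected = [ACTIVITY_MAP[a] for a in activities if a in ACTIVITY_MAP]
    prof = torch.profiler.profile(activities=selected, with_stack=True)
    prof.start()
    torch.randn(256, 256) @ torch.randn(256, 256)  # stand-in workload
    prof.stop()
    trace_path = os.path.join(output_dir, f"{time.time()}.trace.json.gz")
    prof.export_chrome_trace(trace_path)
    return trace_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = profile_once(["CPU"], tmp)
        assert os.path.exists(path)
```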
1747
2180
 
1748
2181
  def open_session(self, recv_req: OpenSessionReqInput):
1749
2182
  # handle error
@@ -1752,7 +2185,7 @@ class Scheduler:
1752
2185
  logger.warning(f"session id {session_id} already exist, cannot open.")
1753
2186
  return OpenSessionReqOutput(session_id, False)
1754
2187
  elif session_id is None:
1755
- logger.warning(f"session id is None, cannot open.")
2188
+ logger.warning("session id is None, cannot open.")
1756
2189
  return OpenSessionReqOutput(session_id, False)
1757
2190
  else:
1758
2191
  self.sessions[session_id] = Session(
@@ -1769,6 +2202,10 @@ class Scheduler:
1769
2202
  del self.sessions[session_id]
1770
2203
 
1771
2204
 
2205
+ def is_health_check_generate_req(recv_req):
2206
+ return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
2207
+
2208
+
1772
2209
  def _export_static_state(model):
1773
2210
  return dict(
1774
2211
  buffers=[
@@ -1791,26 +2228,28 @@ def run_scheduler_process(
1791
2228
  dp_rank: Optional[int],
1792
2229
  pipe_writer,
1793
2230
  ):
1794
- setproctitle.setproctitle("sglang::scheduler")
2231
+ # Configure the process
2232
+ # kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
2233
+ setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
1795
2234
  faulthandler.enable()
2235
+ parent_process = psutil.Process().parent()
1796
2236
 
1797
2237
  # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
1798
2238
  if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
1799
2239
  dp_rank = int(os.environ["SGLANG_DP_RANK"])
1800
2240
 
1801
- # Configue the logger
2241
+ # Configure the logger
1802
2242
  if dp_rank is None:
1803
- configure_logger(server_args, prefix=f" TP{tp_rank}")
2243
+ prefix = f" TP{tp_rank}"
1804
2244
  else:
1805
- configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
2245
+ prefix = f" DP{dp_rank} TP{tp_rank}"
2246
+ configure_logger(server_args, prefix=prefix)
1806
2247
  suppress_other_loggers()
1807
2248
 
1808
2249
  # Set cpu affinity to this gpu process
1809
2250
  if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
1810
2251
  set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
1811
2252
 
1812
- parent_process = psutil.Process().parent()
1813
-
1814
2253
  # Create a scheduler and run the event loop
1815
2254
  try:
1816
2255
  scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)