sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0

sglang/srt/mem_cache/radix_cache.py

@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 import torch
 
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         disable: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.disable = disable
         self.reset()
 
@@ -111,14 +112,12 @@ class RadixCache(BasePrefixCache):
         if self.disable:
             return [], self.root_node
 
-        value = []
-        last_node = [self.root_node]
-        self._match_prefix_helper(self.root_node, key, value, last_node)
+        value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
             value = torch.concat(value)
         else:
             value = torch.tensor([], dtype=torch.int32)
-        return value, last_node[0]
+        return value, last_node
 
     def insert(self, key: List, value=None):
         if self.disable:
@@ -139,7 +138,7 @@
             kv_indices = self.req_to_token_pool.req_to_token[
                 req.req_pool_idx, :token_ids_len
             ]
-            self.token_to_kv_pool.free(kv_indices)
+            self.token_to_kv_pool_allocator.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
             return
 
@@ -151,7 +150,9 @@
 
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
-        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        self.token_to_kv_pool_allocator.free(
+            kv_indices[len(req.prefix_indices) : new_prefix_len]
+        )
 
         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
@@ -171,7 +172,9 @@
 
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
-        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        self.token_to_kv_pool_allocator.free(
+            kv_indices[len(req.prefix_indices) : new_prefix_len]
+        )
 
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
@@ -191,7 +194,7 @@
         print(f"#tokens: {self.total_size()}")
 
     def total_size(self):
-        return self._total_size_helper(self.root_node)
+        return self._total_size_helper()
 
     def evict(self, num_tokens: int, evict_callback: Callable):
         if self.disable:
@@ -253,24 +256,23 @@
 
     ##### Internal Helper Functions #####
 
-    def _match_prefix_helper(
-        self, node: TreeNode, key: List, value, last_node: TreeNode
-    ):
+    def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
-        if len(key) == 0:
-            return
-
-        if key[0] in node.children.keys():
+        value = []
+        while len(key) > 0 and key[0] in node.children.keys():
             child = node.children[key[0]]
+            child.last_access_time = time.time()
             prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 value.append(new_node.value)
-                last_node[0] = new_node
+                node = new_node
+                break
             else:
                 value.append(child.value)
-                last_node[0] = child
-                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+                node = child
+                key = key[prefix_len:]
+        return value, node
 
     def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
@@ -291,22 +293,18 @@
         if len(key) == 0:
             return 0
 
-        if key[0] in node.children.keys():
-            child = node.children[key[0]]
-            prefix_len = _key_match(child.key, key)
+        total_prefix_length = 0
+        while len(key) > 0 and key[0] in node.children.keys():
+            node = node.children[key[0]]
+            node.last_access_time = time.time()
+            prefix_len = _key_match(node.key, key)
+            total_prefix_length += prefix_len
+            key = key[prefix_len:]
+            value = value[prefix_len:]
 
-            if prefix_len == len(child.key):
-                if prefix_len == len(key):
-                    return prefix_len
-                else:
-                    key = key[prefix_len:]
-                    value = value[prefix_len:]
-                    return prefix_len + self._insert_helper(child, key, value)
-
-            new_node = self._split_node(child.key, child, prefix_len)
-            return prefix_len + self._insert_helper(
-                new_node, key[prefix_len:], value[prefix_len:]
-            )
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
 
         if len(key):
             new_node = TreeNode()
@@ -315,12 +313,21 @@
             new_node.value = value
             node.children[key[0]] = new_node
             self.evictable_size_ += len(value)
-        return 0
+        return total_prefix_length
 
     def _print_helper(self, node: TreeNode, indent: int):
-        for _, child in node.children.items():
-            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
-            self._print_helper(child, indent=indent + 2)
+        """Prints the radix tree in a human-readable format."""
+        stack = [(node, indent)]
+        while stack:
+            current_node, current_indent = stack.pop()
+            print(
+                " " * current_indent,
+                len(current_node.key),
+                current_node.key[:10],
+                f"r={current_node.lock_ref}",
+            )
+            for _, child in current_node.children.items():
+                stack.append((child, current_indent + 2))
 
     def _delete_leaf(self, node):
         for k, v in node.parent.children.items():
@@ -329,13 +336,17 @@
         del node.parent.children[k]
         self.evictable_size_ -= len(node.key)
 
-    def _total_size_helper(self, node: TreeNode):
-        if node.evicted:
-            return 0
-        x = len(node.value)
-        for child in node.children.values():
-            x += self._total_size_helper(child)
-        return x
+    def _total_size_helper(self):
+        total_size = 0
+        stack = [self.root_node]
+        while stack:
+            current_node = stack.pop()
+            total_size += len(current_node.value)
+            for child in current_node.children.values():
+                if child.evicted:
+                    continue
+                stack.append(child)
+        return total_size
 
     def _collect_leaves(self):
         ret_list = []
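
The radix_cache.py hunks above swap the KV-pool handle for a TokenToKVPoolAllocator and rewrite the recursive tree helpers (_match_prefix_helper, _insert_helper, _print_helper, _total_size_helper) as iterative loops, which sidesteps Python's recursion depth limit and per-call overhead on long shared prefixes. Below is a minimal sketch of that iterative prefix-match pattern, using simplified stand-ins (Node, key_match) rather than the actual sglang TreeNode and _key_match:

# Iterative prefix match over a toy radix tree (illustrative only).
from typing import Dict, List, Optional, Tuple


class Node:
    def __init__(self, key: Optional[List[int]] = None):
        self.key = key or []
        self.children: Dict[int, "Node"] = {}


def key_match(a: List[int], b: List[int]) -> int:
    # Length of the common prefix of two token lists.
    i = 0
    while i < len(a) and i < len(b) and a[i] == b[i]:
        i += 1
    return i


def match_prefix(root: Node, key: List[int]) -> Tuple[List[List[int]], Node]:
    # Walk down the tree with a loop instead of recursion, collecting the
    # matched segments and returning the deepest node reached.
    node, matched = root, []
    while key and key[0] in node.children:
        child = node.children[key[0]]
        prefix_len = key_match(child.key, key)
        matched.append(child.key[:prefix_len])
        node = child
        if prefix_len < len(child.key):
            break  # partial match: the real cache splits the node here
        key = key[prefix_len:]
    return matched, node


if __name__ == "__main__":
    root = Node()
    root.children[1] = Node([1, 2, 3])
    root.children[1].children[4] = Node([4, 5])
    print(match_prefix(root, [1, 2, 3, 4, 9]))  # ([[1, 2, 3], [4]], <Node ...>)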

sglang/srt/metrics/collector.py

@@ -13,6 +13,7 @@
 # ==============================================================================
 """Utilities for Prometheus Metrics Collection."""
 
+import time
 from dataclasses import dataclass
 from typing import Dict, Union
 
@@ -35,19 +36,20 @@ class SchedulerMetricsCollector:
         from prometheus_client import Gauge
 
         self.labels = labels
+        self.last_log_time = time.time()
 
         self.num_running_reqs = Gauge(
             name="sglang:num_running_reqs",
             documentation="The number of running requests.",
             labelnames=labels.keys(),
-            multiprocess_mode="sum",
+            multiprocess_mode="mostrecent",
         )
 
         self.num_used_tokens = Gauge(
             name="sglang:num_used_tokens",
             documentation="The number of used tokens.",
             labelnames=labels.keys(),
-            multiprocess_mode="sum",
+            multiprocess_mode="mostrecent",
         )
 
         self.token_usage = Gauge(
@@ -61,14 +63,14 @@ class SchedulerMetricsCollector:
             name="sglang:gen_throughput",
             documentation="The generation throughput (token/s).",
             labelnames=labels.keys(),
-            multiprocess_mode="sum",
+            multiprocess_mode="mostrecent",
         )
 
         self.num_queue_reqs = Gauge(
             name="sglang:num_queue_reqs",
             documentation="The number of requests in the waiting queue.",
             labelnames=labels.keys(),
-            multiprocess_mode="sum",
+            multiprocess_mode="mostrecent",
         )
 
         self.cache_hit_rate = Gauge(
@@ -97,6 +99,7 @@ class SchedulerMetricsCollector:
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
+        self.last_log_time = time.time()
 
 
 class TokenizerMetricsCollector:
@@ -118,6 +121,12 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
+        self.cached_tokens_total = Counter(
+            name="sglang:cached_tokens_total",
+            documentation="Number of cached prompt tokens.",
+            labelnames=labels.keys(),
+        )
+
         self.num_requests_total = Counter(
             name="sglang:num_requests_total",
             documentation="Number of requests processed.",
@@ -130,12 +139,15 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
             buckets=[
                 0.1,
-                0.25,
+                0.3,
                 0.5,
-                0.75,
+                0.7,
+                0.9,
                 1,
                 2,
-                5,
+                4,
+                6,
+                8,
                 10,
                 20,
                 40,
@@ -151,24 +163,56 @@ class TokenizerMetricsCollector:
             documentation="Histogram of time per output token in seconds.",
             labelnames=labels.keys(),
             buckets=[
+                0.002,
                 0.005,
-                0.01,
+                0.010,
+                0.020,
+                0.030,
+                0.040,
+                0.050,
+                0.060,
+                0.070,
+                0.080,
+                0.090,
+                0.100,
+                0.150,
+                0.200,
+                0.300,
+                0.400,
+                0.600,
+                0.800,
+                1.000,
+                2.000,
+            ],
+        )
+
+        self.histogram_inter_token_latency_seconds = Histogram(
+            name="sglang:inter_token_latency_seconds",
+            documentation="Histogram of inter-token latency in seconds.",
+            labelnames=labels.keys(),
+            buckets=[
+                0.002,
+                0.004,
+                0.006,
+                0.008,
+                0.010,
                 0.015,
-                0.02,
+                0.020,
                 0.025,
-                0.03,
-                0.04,
-                0.05,
+                0.030,
+                0.035,
+                0.040,
+                0.050,
                 0.075,
-                0.1,
-                0.15,
-                0.2,
-                0.3,
-                0.4,
-                0.5,
-                0.75,
-                1.0,
-                2.5,
+                0.100,
+                0.150,
+                0.200,
+                0.300,
+                0.400,
+                0.500,
+                0.750,
+                1.000,
+                2.000,
             ],
         )
 
@@ -178,8 +222,9 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
             buckets=[
                 0.1,
-                0.25,
-                0.5,
+                0.2,
+                0.4,
+                0.8,
                 1,
                 2,
                 5,
@@ -188,28 +233,49 @@ class TokenizerMetricsCollector:
                 40,
                 60,
                 80,
-                120,
-                160,
+                100,
+                150,
+                200,
+                250,
+                300,
+                350,
+                500,
+                1000,
             ],
         )
 
     def _log_histogram(self, histogram, data: Union[int, float]) -> None:
         histogram.labels(**self.labels).observe(data)
 
-    def _log_counter(self, counter, data: Union[int, float]) -> None:
-        # Convenience function for logging to counter.
-        counter.labels(**self.labels).inc(data)
-
-    def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+    def observe_one_finished_request(
+        self,
+        prompt_tokens: int,
+        generation_tokens: int,
+        cached_tokens: int,
+        e2e_latency: float,
+    ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
+        self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+        if generation_tokens >= 1:
+            self.histogram_time_per_output_token.labels(**self.labels).observe(
+                e2e_latency / generation_tokens
+            )
+
+    def observe_time_to_first_token(self, value: float):
+        self.histogram_time_to_first_token.labels(**self.labels).observe(value)
 
-    def observe_time_to_first_token(self, value: Union[float, int]):
-        self._log_histogram(self.histogram_time_to_first_token, value)
+    def observe_inter_token_latency(self, internval: float, num_new_tokens: int):
+        adjusted_interval = internval / num_new_tokens
 
-    def observe_time_per_output_token(self, value: Union[float, int]):
-        self._log_histogram(self.histogram_time_per_output_token, value)
+        # A faster version of the Histogram::observe which observes multiple values at the same time.
+        # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
+        his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
+        his._sum.inc(internval)
 
-    def observe_e2e_request_latency(self, value: Union[float, int]):
-        self._log_histogram(self.histogram_e2e_request_latency, value)
+        for i, bound in enumerate(his._upper_bounds):
+            if adjusted_interval <= bound:
+                his._buckets[i].inc(num_new_tokens)
+                break
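
The collector.py hunks above switch the scheduler gauges to multiprocess_mode="mostrecent", add a cached_tokens_total counter, fold TTFT/TPOT/e2e-latency observation into observe_one_finished_request, and add observe_inter_token_latency, which records num_new_tokens latency samples with a single bucket increment by reaching into prometheus_client's private _sum/_buckets/_upper_bounds fields. Below is a hedged sketch of that batched-observe trick on a plain unlabeled Histogram; it depends on prometheus_client internals (the 0.21.x layout the diff's own comment references), and batch_observe is an illustrative name, not an sglang or prometheus_client API:

# Record N samples of (interval / N) seconds in one shot instead of N observe() calls.
from prometheus_client import Histogram

itl = Histogram(
    "demo_inter_token_latency_seconds",
    "Inter-token latency in seconds.",
    buckets=[0.01, 0.02, 0.05, 0.1, 0.5],
)


def batch_observe(histogram: Histogram, interval: float, num_new_tokens: int) -> None:
    per_token = interval / num_new_tokens
    histogram._sum.inc(interval)  # total observed time across all samples
    for i, bound in enumerate(histogram._upper_bounds):  # list ends with +Inf
        if per_token <= bound:
            histogram._buckets[i].inc(num_new_tokens)  # count all N samples at once
            break


batch_observe(itl, interval=0.24, num_new_tokens=8)  # 8 samples of 30 ms each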

sglang/srt/model_executor/cuda_graph_runner.py

@@ -109,14 +109,22 @@ def set_torch_compile_config():
 def get_batch_sizes_to_capture(model_runner: ModelRunner):
     server_args = model_runner.server_args
     capture_bs = server_args.cuda_graph_bs
+
     if capture_bs is None:
-        if server_args.disable_cuda_graph_padding:
-            capture_bs = list(range(1, 33)) + [64, 128]
+        if server_args.speculative_algorithm is None:
+            if server_args.disable_cuda_graph_padding:
+                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+            else:
+                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         else:
-            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+            capture_bs = list(range(1, 33))
+
+    if is_hip_:
+        capture_bs += [i * 8 for i in range(21, 33)]
+
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
-        # is very samll. We add more values here to make sure we capture the maximum bs.
+        # is very small. We add more values here to make sure we capture the maximum bs.
         capture_bs = list(
             sorted(
                 set(
@@ -126,14 +134,13 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
                 )
             )
         )
+
     capture_bs = [
         bs
         for bs in capture_bs
         if bs <= model_runner.req_to_token_pool.size
         and bs <= server_args.cuda_graph_max_bs
     ]
-    if is_hip_:
-        capture_bs += [i * 8 for i in range(21, 33)]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
@@ -173,6 +180,7 @@ class CudaGraphRunner:
         # Batch sizes to capture
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.capture_forward_mode = ForwardMode.DECODE
+        self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
@@ -192,6 +200,9 @@ class CudaGraphRunner:
         )
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
+        self.seq_lens_cpu = torch.full(
+            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+        )
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -230,6 +241,9 @@ class CudaGraphRunner:
                 ),
                 dtype=self.model_runner.dtype,
             )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
 
         # Capture
         try:
@@ -258,9 +272,9 @@
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
-                forward_batch.global_num_tokens
-            )
+            min_num_tokens, max_num_tokens = min(
+                forward_batch.global_num_tokens_cpu
+            ), max(forward_batch.global_num_tokens_cpu)
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
                 (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
                 if self.disable_padding
@@ -333,6 +347,10 @@
             gathered_buffer = None
 
         spec_info = self.get_spec_info(num_tokens)
+        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
+            self.capture_hidden_mode = (
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            )
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -348,20 +366,12 @@
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens=global_num_tokens,
+            global_num_tokens_cpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
-            capture_hidden_mode=(
-                CaptureHiddenMode.FULL
-                if self.model_runner.server_args.return_hidden_states
-                else (
-                    spec_info.capture_hidden_mode
-                    if spec_info
-                    else CaptureHiddenMode.NULL
-                )
-            ),
+            capture_hidden_mode=self.capture_hidden_mode,
         )
 
         # Attention backend
@@ -386,31 +396,41 @@
 
             run_once()
 
-            torch.cuda.synchronize()
-            self.model_runner.tp_group.barrier()
-
-            torch.cuda.synchronize()
-            self.model_runner.tp_group.barrier()
-
         global global_graph_memory_pool
         with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
             out = run_once()
 
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         global_graph_memory_pool = graph.pool()
         return graph, out
 
-    def replay(self, forward_batch: ForwardBatch):
-        assert forward_batch.out_cache_loc is not None
+    def recapture_if_needed(self, forward_batch: ForwardBatch):
+        # If the capture_hidden_mode changes, we need to recapture the graph
+        hidden_mode_from_spec_info = getattr(
+            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+        )
+        if (
+            forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
+            and self.capture_hidden_mode != CaptureHiddenMode.FULL
+        ):
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+            self.capture()
+        elif (
+            forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
+            and self.capture_hidden_mode != hidden_mode_from_spec_info
+        ):
+            self.capture_hidden_mode = hidden_mode_from_spec_info
+            self.capture()
+
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+        self.recapture_if_needed(forward_batch)
+
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens)
+                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -425,6 +445,10 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
+        if forward_batch.decode_seq_lens_cpu is not None:
+            if bs != raw_bs:
+                self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
 
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
@@ -443,6 +467,7 @@ class CudaGraphRunner:
             self.encoder_lens,
             forward_batch.forward_mode,
             forward_batch.spec_info,
+            seq_lens_cpu=self.seq_lens_cpu,
         )
 
         # Replay
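
The cuda_graph_runner.py hunks above track capture_hidden_mode on the runner (recapturing graphs when it changes), keep a CPU-side copy of decode sequence lengths that is passed along at replay time, and key DP-attention padding on global_num_tokens_cpu. The padding itself rounds the incoming batch up to the nearest captured size with bisect.bisect_left over the sorted capture_bs list. A small sketch of that rounding step; pick_padded_bs is an illustrative name, not an sglang API:

import bisect
from typing import List


def pick_padded_bs(capture_bs: List[int], raw_bs: int) -> int:
    # capture_bs must be sorted ascending; returns the smallest captured
    # batch size that can hold raw_bs.
    index = bisect.bisect_left(capture_bs, raw_bs)
    if index == len(capture_bs):
        raise ValueError(f"batch size {raw_bs} exceeds the largest captured size")
    return capture_bs[index]


capture_bs = sorted([1, 2, 4] + [i * 8 for i in range(1, 21)])
assert pick_padded_bs(capture_bs, 3) == 4    # padded up to the next captured size
assert pick_padded_bs(capture_bs, 8) == 8    # exact hit reuses that graph
assert pick_padded_bs(capture_bs, 50) == 56  # rounded up to the next multiple of 8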