sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/radix_cache.py
@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 import torch
 
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         disable: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.disable = disable
         self.reset()
 
@@ -111,14 +112,12 @@ class RadixCache(BasePrefixCache):
         if self.disable:
             return [], self.root_node
 
-        value = []
-        last_node = [self.root_node]
-        self._match_prefix_helper(self.root_node, key, value, last_node)
+        value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
             value = torch.concat(value)
         else:
             value = torch.tensor([], dtype=torch.int32)
-        return value, last_node[0]
+        return value, last_node
 
     def insert(self, key: List, value=None):
         if self.disable:
@@ -139,7 +138,7 @@ class RadixCache(BasePrefixCache):
             kv_indices = self.req_to_token_pool.req_to_token[
                 req.req_pool_idx, :token_ids_len
             ]
-            self.token_to_kv_pool.free(kv_indices)
+            self.token_to_kv_pool_allocator.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
             return
 
@@ -151,7 +150,9 @@ class RadixCache(BasePrefixCache):
 
         # Radix Cache takes one ref in memory pool
        new_prefix_len = self.insert(token_ids, kv_indices.clone())
-        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        self.token_to_kv_pool_allocator.free(
+            kv_indices[len(req.prefix_indices) : new_prefix_len]
+        )
 
         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
@@ -171,7 +172,9 @@ class RadixCache(BasePrefixCache):
 
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
-        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        self.token_to_kv_pool_allocator.free(
+            kv_indices[len(req.prefix_indices) : new_prefix_len]
+        )
 
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
@@ -191,7 +194,7 @@ class RadixCache(BasePrefixCache):
             print(f"#tokens: {self.total_size()}")
 
     def total_size(self):
-        return self._total_size_helper(self.root_node)
+        return self._total_size_helper()
 
     def evict(self, num_tokens: int, evict_callback: Callable):
         if self.disable:
@@ -253,24 +256,23 @@ class RadixCache(BasePrefixCache):
 
     ##### Internal Helper Functions #####
 
-    def _match_prefix_helper(
-        self, node: TreeNode, key: List, value, last_node: TreeNode
-    ):
+    def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
-        if len(key) == 0:
-            return
-
-        if key[0] in node.children.keys():
+        value = []
+        while len(key) > 0 and key[0] in node.children.keys():
             child = node.children[key[0]]
+            child.last_access_time = time.time()
             prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 value.append(new_node.value)
-                last_node[0] = new_node
+                node = new_node
+                break
             else:
                 value.append(child.value)
-                last_node[0] = child
-                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
+                node = child
+                key = key[prefix_len:]
+        return value, node
 
     def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
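
The hunk above rewrites the recursive prefix walk as a loop, so deep trees no longer consume a Python stack frame per matched edge. A minimal standalone sketch of the same pattern, using a hypothetical Node class rather than sglang's TreeNode:

    import time

    class Node:
        # Hypothetical stand-in for sglang's TreeNode: `key` is the edge's
        # token list, `children` is keyed by the first token of each edge.
        def __init__(self, key=None):
            self.key = key or []
            self.children = {}
            self.last_access_time = time.time()

    def common_prefix_len(a, b):
        # Shared-prefix length of two token lists (the role _key_match
        # plays in the diff above).
        i = 0
        while i < len(a) and i < len(b) and a[i] == b[i]:
            i += 1
        return i

    def match_prefix_iterative(root, key):
        # Walk edges while the next token has a matching child; collect
        # matched segments and return them with the last node touched.
        node, value = root, []
        while key and key[0] in node.children:
            child = node.children[key[0]]
            child.last_access_time = time.time()
            prefix_len = common_prefix_len(child.key, key)
            if prefix_len < len(child.key):
                # Partial edge match: the real code splits the child here.
                value.append(child.key[:prefix_len])
                node = child
                break
            value.append(child.key)
            node, key = child, key[prefix_len:]
        return value, node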
@@ -291,22 +293,18 @@ class RadixCache(BasePrefixCache):
         if len(key) == 0:
             return 0
 
-        if key[0] in node.children.keys():
-            child = node.children[key[0]]
-            prefix_len = _key_match(child.key, key)
+        total_prefix_length = 0
+        while len(key) > 0 and key[0] in node.children.keys():
+            node = node.children[key[0]]
+            node.last_access_time = time.time()
+            prefix_len = _key_match(node.key, key)
+            total_prefix_length += prefix_len
+            key = key[prefix_len:]
+            value = value[prefix_len:]
 
-            if prefix_len == len(child.key):
-                if prefix_len == len(key):
-                    return prefix_len
-                else:
-                    key = key[prefix_len:]
-                    value = value[prefix_len:]
-                    return prefix_len + self._insert_helper(child, key, value)
-
-            new_node = self._split_node(child.key, child, prefix_len)
-            return prefix_len + self._insert_helper(
-                new_node, key[prefix_len:], value[prefix_len:]
-            )
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
 
         if len(key):
             new_node = TreeNode()
@@ -315,12 +313,21 @@ class RadixCache(BasePrefixCache):
             new_node.value = value
             node.children[key[0]] = new_node
             self.evictable_size_ += len(value)
-        return 0
+        return total_prefix_length
 
     def _print_helper(self, node: TreeNode, indent: int):
-        for _, child in node.children.items():
-            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
-            self._print_helper(child, indent=indent + 2)
+        """Prints the radix tree in a human-readable format."""
+        stack = [(node, indent)]
+        while stack:
+            current_node, current_indent = stack.pop()
+            print(
+                " " * current_indent,
+                len(current_node.key),
+                current_node.key[:10],
+                f"r={current_node.lock_ref}",
+            )
+            for _, child in current_node.children.items():
+                stack.append((child, current_indent + 2))
 
     def _delete_leaf(self, node):
         for k, v in node.parent.children.items():
@@ -329,13 +336,17 @@ class RadixCache(BasePrefixCache):
                 del node.parent.children[k]
         self.evictable_size_ -= len(node.key)
 
-    def _total_size_helper(self, node: TreeNode):
-        if node.evicted:
-            return 0
-        x = len(node.value)
-        for child in node.children.values():
-            x += self._total_size_helper(child)
-        return x
+    def _total_size_helper(self):
+        total_size = 0
+        stack = [self.root_node]
+        while stack:
+            current_node = stack.pop()
+            total_size += len(current_node.value)
+            for child in current_node.children.values():
+                if child.evicted:
+                    continue
+                stack.append(child)
+        return total_size
 
     def _collect_leaves(self):
         ret_list = []
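
Both `_print_helper` and `_total_size_helper` above apply the same recursion-to-explicit-stack conversion, which sidesteps Python's recursion limit on very deep trees. The pattern in isolation, assuming hypothetical node objects with `value`, `children`, and `evicted` attributes:

    def total_size_iterative(root):
        # Depth-first traversal with an explicit stack instead of
        # recursion; evicted children are skipped entirely, matching
        # the rewritten helper above.
        total, stack = 0, [root]
        while stack:
            node = stack.pop()
            total += len(node.value)
            stack.extend(c for c in node.children.values() if not c.evicted)
        return total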
sglang/srt/metrics/collector.py
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Utilities for Prometheus Metrics Collection."""
 
+import time
 from dataclasses import dataclass
 from typing import Dict, Union
 
@@ -35,19 +36,20 @@ class SchedulerMetricsCollector:
         from prometheus_client import Gauge
 
         self.labels = labels
+        self.last_log_time = time.time()
 
         self.num_running_reqs = Gauge(
             name="sglang:num_running_reqs",
             documentation="The number of running requests.",
             labelnames=labels.keys(),
-            multiprocess_mode="sum",
+            multiprocess_mode="mostrecent",
         )
 
         self.num_used_tokens = Gauge(
             name="sglang:num_used_tokens",
             documentation="The number of used tokens.",
             labelnames=labels.keys(),
-            multiprocess_mode="sum",
+            multiprocess_mode="mostrecent",
         )
 
         self.token_usage = Gauge(
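
Switching `multiprocess_mode` from "sum" to "mostrecent" changes how prometheus_client's multiprocess exporter merges samples from worker processes: "sum" adds every process's sample together, while "mostrecent" exports only the latest value written. A minimal sketch of such a declaration ("mostrecent" requires a recent prometheus_client release; the metric name here is illustrative, not sglang's):

    from prometheus_client import Gauge

    num_running_reqs = Gauge(
        name="example:num_running_reqs",
        documentation="The number of running requests.",
        labelnames=["model_name"],
        # Only meaningful under multiprocess collection: export the latest
        # written sample instead of summing one sample per worker process,
        # which would double-count a scheduler-wide quantity.
        multiprocess_mode="mostrecent",
    )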
@@ -61,14 +63,14 @@ class SchedulerMetricsCollector:
             name="sglang:gen_throughput",
             documentation="The generation throughput (token/s).",
             labelnames=labels.keys(),
-            multiprocess_mode="sum",
+            multiprocess_mode="mostrecent",
         )
 
         self.num_queue_reqs = Gauge(
             name="sglang:num_queue_reqs",
             documentation="The number of requests in the waiting queue.",
             labelnames=labels.keys(),
-            multiprocess_mode="sum",
+            multiprocess_mode="mostrecent",
         )
 
         self.cache_hit_rate = Gauge(
@@ -97,6 +99,7 @@ class SchedulerMetricsCollector:
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
+        self.last_log_time = time.time()
 
 
 class TokenizerMetricsCollector:
@@ -130,12 +133,15 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
             buckets=[
                 0.1,
-                0.25,
+                0.3,
                 0.5,
-                0.75,
+                0.7,
+                0.9,
                 1,
                 2,
-                5,
+                4,
+                6,
+                8,
                 10,
                 20,
                 40,
@@ -151,24 +157,56 @@ class TokenizerMetricsCollector:
             documentation="Histogram of time per output token in seconds.",
             labelnames=labels.keys(),
             buckets=[
+                0.002,
                 0.005,
-                0.01,
+                0.010,
+                0.020,
+                0.030,
+                0.040,
+                0.050,
+                0.060,
+                0.070,
+                0.080,
+                0.090,
+                0.100,
+                0.150,
+                0.200,
+                0.300,
+                0.400,
+                0.600,
+                0.800,
+                1.000,
+                2.000,
+            ],
+        )
+
+        self.histogram_inter_token_latency_seconds = Histogram(
+            name="sglang:inter_token_latency_seconds",
+            documentation="Histogram of inter-token latency in seconds.",
+            labelnames=labels.keys(),
+            buckets=[
+                0.002,
+                0.004,
+                0.006,
+                0.008,
+                0.010,
                 0.015,
-                0.02,
+                0.020,
                 0.025,
-                0.03,
-                0.04,
-                0.05,
+                0.030,
+                0.035,
+                0.040,
+                0.050,
                 0.075,
-                0.1,
-                0.15,
-                0.2,
-                0.3,
-                0.4,
-                0.5,
-                0.75,
-                1.0,
-                2.5,
+                0.100,
+                0.150,
+                0.200,
+                0.300,
+                0.400,
+                0.500,
+                0.750,
+                1.000,
+                2.000,
             ],
         )
 
@@ -178,8 +216,9 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
             buckets=[
                 0.1,
-                0.25,
-                0.5,
+                0.2,
+                0.4,
+                0.8,
                 1,
                 2,
                 5,
@@ -188,28 +227,47 @@ class TokenizerMetricsCollector:
                 40,
                 60,
                 80,
-                120,
-                160,
+                100,
+                150,
+                200,
+                250,
+                300,
+                350,
+                500,
+                1000,
             ],
         )
 
     def _log_histogram(self, histogram, data: Union[int, float]) -> None:
         histogram.labels(**self.labels).observe(data)
 
-    def _log_counter(self, counter, data: Union[int, float]) -> None:
-        # Convenience function for logging to counter.
-        counter.labels(**self.labels).inc(data)
-
-    def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
+    def observe_one_finished_request(
+        self,
+        prompt_tokens: int,
+        generation_tokens: int,
+        e2e_latency: float,
+    ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
+        self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
+        if generation_tokens >= 1:
+            self.histogram_time_per_output_token.labels(**self.labels).observe(
+                e2e_latency / generation_tokens
+            )
+
+    def observe_time_to_first_token(self, value: float):
+        self.histogram_time_to_first_token.labels(**self.labels).observe(value)
 
-    def observe_time_to_first_token(self, value: Union[float, int]):
-        self._log_histogram(self.histogram_time_to_first_token, value)
+    def observe_inter_token_latency(self, internval: float, num_new_tokens: int):
+        adjusted_interval = internval / num_new_tokens
 
-    def observe_time_per_output_token(self, value: Union[float, int]):
-        self._log_histogram(self.histogram_time_per_output_token, value)
+        # A faster version of the Histogram::observe which observes multiple values at the same time.
+        # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
+        his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
+        his._sum.inc(internval)
 
-    def observe_e2e_request_latency(self, value: Union[float, int]):
-        self._log_histogram(self.histogram_e2e_request_latency, value)
+        for i, bound in enumerate(his._upper_bounds):
+            if adjusted_interval <= bound:
+                his._buckets[i].inc(num_new_tokens)
+                break
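
The new `observe_inter_token_latency` records `num_new_tokens` observations with a single bucket increment instead of calling `observe` in a loop; since every token in one interval shares the same per-token latency, they all land in the same bucket. A standalone sketch of the trick; note that `_sum`, `_upper_bounds`, and `_buckets` are prometheus_client internals (the diff pins its reference to v0.21.1) and may change between versions:

    from prometheus_client import Histogram

    itl_hist = Histogram(
        "example_inter_token_latency_seconds",
        "Inter-token latency in seconds.",
        buckets=[0.01, 0.05, 0.1, 0.5, 1.0],  # +Inf is appended automatically
    )

    def observe_batch(hist, interval, num_new_tokens):
        # One inc(n) on the matching bucket replaces n observe() calls;
        # the sum still advances by the full interval.
        per_token = interval / num_new_tokens
        hist._sum.inc(interval)
        for i, bound in enumerate(hist._upper_bounds):
            if per_token <= bound:
                hist._buckets[i].inc(num_new_tokens)
                break

    observe_batch(itl_hist, interval=0.12, num_new_tokens=4)  # 4 samples of 0.03s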
sglang/srt/model_executor/cuda_graph_runner.py
@@ -109,14 +109,22 @@ def set_torch_compile_config():
 def get_batch_sizes_to_capture(model_runner: ModelRunner):
     server_args = model_runner.server_args
     capture_bs = server_args.cuda_graph_bs
+
     if capture_bs is None:
-        if server_args.disable_cuda_graph_padding:
-            capture_bs = list(range(1, 33)) + [64, 128]
+        if server_args.speculative_algorithm is None:
+            if server_args.disable_cuda_graph_padding:
+                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+            else:
+                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         else:
-            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+            capture_bs = list(range(1, 33))
+
+    if is_hip_:
+        capture_bs += [i * 8 for i in range(21, 33)]
+
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
-        # is very samll. We add more values here to make sure we capture the maximum bs.
+        # is very small. We add more values here to make sure we capture the maximum bs.
         capture_bs = list(
             sorted(
                 set(
@@ -126,14 +134,13 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
                 )
             )
         )
+
     capture_bs = [
         bs
         for bs in capture_bs
         if bs <= model_runner.req_to_token_pool.size
         and bs <= server_args.cuda_graph_max_bs
     ]
-    if is_hip_:
-        capture_bs += [i * 8 for i in range(21, 33)]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
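
The effect of the new branching in the two hunks above is easiest to see in isolation. A sketch reproducing just the selection logic (the flag names mirror the server args in the diff; treat this as illustrative, not the actual helper):

    def batch_sizes_to_capture(speculative_algorithm=None,
                               disable_cuda_graph_padding=False,
                               is_hip=False):
        if speculative_algorithm is None:
            if disable_cuda_graph_padding:
                # Dense small sizes plus a few large ones: without padding,
                # every batch size that will run must be captured exactly.
                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
            else:
                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 8..160
        else:
            # Speculative decoding: capture every batch size from 1 to 32.
            capture_bs = list(range(1, 33))
        if is_hip:
            capture_bs += [i * 8 for i in range(21, 33)]  # 168..256 on ROCm
        return capture_bs

    print(batch_sizes_to_capture()[:8])  # [1, 2, 4, 8, 16, 24, 32, 40]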
@@ -173,6 +180,7 @@ class CudaGraphRunner:
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.capture_forward_mode = ForwardMode.DECODE
+        self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
@@ -192,6 +200,9 @@ class CudaGraphRunner:
         )
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
+        self.seq_lens_cpu = torch.full(
+            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+        )
 
         if self.enable_torch_compile:
             set_torch_compile_config()
@@ -230,6 +241,9 @@ class CudaGraphRunner:
                 ),
                 dtype=self.model_runner.dtype,
             )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
 
         # Capture
         try:
@@ -258,9 +272,9 @@ class CudaGraphRunner:
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
-                forward_batch.global_num_tokens
-            )
+            min_num_tokens, max_num_tokens = min(
+                forward_batch.global_num_tokens_cpu
+            ), max(forward_batch.global_num_tokens_cpu)
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
                 (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
                 if self.disable_padding
@@ -333,6 +347,10 @@ class CudaGraphRunner:
             gathered_buffer = None
 
         spec_info = self.get_spec_info(num_tokens)
+        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
+            self.capture_hidden_mode = (
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            )
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -348,20 +366,12 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens=global_num_tokens,
+            global_num_tokens_cpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
-            capture_hidden_mode=(
-                CaptureHiddenMode.FULL
-                if self.model_runner.server_args.return_hidden_states
-                else (
-                    spec_info.capture_hidden_mode
-                    if spec_info
-                    else CaptureHiddenMode.NULL
-                )
-            ),
+            capture_hidden_mode=self.capture_hidden_mode,
         )
 
         # Attention backend
@@ -386,9 +396,6 @@ class CudaGraphRunner:
 
         run_once()
 
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()
 
@@ -402,15 +409,34 @@ class CudaGraphRunner:
         global_graph_memory_pool = graph.pool()
         return graph, out
 
+    def recapture_if_needed(self, forward_batch: ForwardBatch):
+        # If the capture_hidden_mode changes, we need to recapture the graph
+        hidden_mode_from_spec_info = getattr(
+            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+        )
+        if (
+            forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
+            and self.capture_hidden_mode != CaptureHiddenMode.FULL
+        ):
+            self.capture_hidden_mode = CaptureHiddenMode.FULL
+            self.capture()
+        elif (
+            forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
+            and self.capture_hidden_mode != hidden_mode_from_spec_info
+        ):
+            self.capture_hidden_mode = hidden_mode_from_spec_info
+            self.capture()
+
     def replay(self, forward_batch: ForwardBatch):
-        assert forward_batch.out_cache_loc is not None
+        self.recapture_if_needed(forward_batch)
+
         raw_bs = forward_batch.batch_size
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens)
+                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
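
`replay` above pads a live batch up to the nearest captured graph size with `bisect.bisect_left`, which relies on `capture_bs` being sorted in ascending order. The lookup in isolation:

    import bisect

    capture_bs = [1, 2, 4, 8, 16, 24, 32]  # sorted sizes with captured graphs

    def padded_batch_size(raw_bs):
        # Index of the smallest captured size >= raw_bs; the extra padded
        # slots run through the graph as dummy requests.
        index = bisect.bisect_left(capture_bs, raw_bs)
        return capture_bs[index]

    assert padded_batch_size(5) == 8    # pad 5 requests up to the bs=8 graph
    assert padded_batch_size(16) == 16  # exact hit, no padding needed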
@@ -425,6 +451,10 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
+        if forward_batch.decode_seq_lens_cpu is not None:
+            if bs != raw_bs:
+                self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
 
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
@@ -443,6 +473,7 @@ class CudaGraphRunner:
             self.encoder_lens,
             forward_batch.forward_mode,
             forward_batch.spec_info,
+            seq_lens_cpu=self.seq_lens_cpu,
         )
 
         # Replay