sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282):
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,12 @@ import torch
30
30
  from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
31
31
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
32
32
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
33
+ from sglang.srt.mem_cache.radix_cache import (
34
+ RadixKey,
35
+ _key_match_page_size1,
36
+ _key_match_paged,
37
+ get_child_key,
38
+ )
33
39
 
34
40
  if TYPE_CHECKING:
35
41
  from sglang.srt.managers.schedule_batch import Req
@@ -47,7 +53,7 @@ class TreeNode:
47
53
  def __init__(self, id: Optional[int] = None):
48
54
  self.children = defaultdict(TreeNode)
49
55
  self.parent: TreeNode = None
50
- self.key: List[int] = None
56
+ self.key: RadixKey = None
51
57
  self.value: Optional[torch.Tensor] = None
52
58
  # swa_tombstone is used to indicate the kv indices have been freed for swa layers
53
59
  self.swa_tombstone = False
@@ -87,27 +93,6 @@ class TreeNode:
87
93
  return self.last_access_time < other.last_access_time
88
94
 
89
95
 
90
- def _key_match_page_size1(key0: List, key1: List):
91
- i = 0
92
- for k0, k1 in zip(key0, key1):
93
- if k0 != k1:
94
- break
95
- i += 1
96
- return i
97
-
98
-
99
- def _key_match_paged(key0: List, key1: List, page_size: int):
100
- min_len = min(len(key0), len(key1))
101
-
102
- i = 0
103
- while i < min_len:
104
- if key0[i : i + page_size] != key1[i : i + page_size]:
105
- break
106
- i += page_size
107
-
108
- return i
109
-
110
-
111
96
  def gen_swa_uuid() -> int:
112
97
  TreeNode.swa_uuid_counter += 1
113
98
  return TreeNode.swa_uuid_counter
@@ -356,10 +341,10 @@ class SWARadixCache(BasePrefixCache):
356
341
 
357
342
  if self.page_size == 1:
358
343
  self.key_match_fn = _key_match_page_size1
359
- self.get_child_key_fn = lambda key: key[0]
344
+ self.get_child_key_fn = get_child_key
360
345
  else:
361
346
  self.key_match_fn = partial(_key_match_paged, page_size=page_size)
362
- self.get_child_key_fn = lambda key: tuple(key[:page_size])
347
+ self.get_child_key_fn = partial(get_child_key, page_size=page_size)
363
348
 
364
349
  self.sliding_window_size = sliding_window_size
365
350
  self.reset()
@@ -380,10 +365,10 @@ class SWARadixCache(BasePrefixCache):
380
365
  self.full_lru_list = LRUList(swa=False)
381
366
  self.swa_lru_list = LRUList(swa=True)
382
367
 
383
- def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
368
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
384
369
  """Find the matching prefix from the radix tree.
385
370
  Args:
386
- key: A list of token IDs to find a matching prefix.
371
+ key: A RadixKey contains token IDs to find a matching prefix.
387
372
  Returns:
388
373
  A tuple of a tensor of matching prefix token IDs and
389
374
  the last node that contains the prefix values. Note that
@@ -417,12 +402,12 @@ class SWARadixCache(BasePrefixCache):
417
402
  last_host_node=last_node,
418
403
  )
419
404
 
420
- def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int:
405
+ def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
421
406
  if self.disable:
422
407
  return 0
423
408
 
424
409
  if value is None:
425
- value = [x for x in key]
410
+ value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
426
411
  return self._insert_helper(self.root_node, key, value, prev_prefix_len)
427
412
 
428
413
  def cache_finished_req(self, req: Req) -> None:
@@ -453,7 +438,7 @@ class SWARadixCache(BasePrefixCache):
453
438
  # insert the token_ids and kv_indices into the radix tree
454
439
  # Note: the insert function already frees the overlapped kv_indices
455
440
  new_prefix_len = self.insert(
456
- token_ids[:page_aligned_len],
441
+ RadixKey(token_ids[:page_aligned_len], req.extra_key),
457
442
  page_aligned_kv_indices,
458
443
  len(req.prefix_indices),
459
444
  )
@@ -489,11 +474,15 @@ class SWARadixCache(BasePrefixCache):
489
474
  # Radix Cache takes one ref in memory pool
490
475
  # Note: the insert function already frees the overlapped kv_indices
491
476
  new_prefix_len = self.insert(
492
- page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices)
477
+ RadixKey(page_aligned_token_ids, req.extra_key),
478
+ page_aligned_kv_indices,
479
+ len(req.prefix_indices),
493
480
  )
494
481
 
495
482
  # The prefix indices could be updated, reuse it
496
- new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
483
+ new_indices, new_last_node, _, _ = self.match_prefix(
484
+ RadixKey(page_aligned_token_ids, req.extra_key)
485
+ )
497
486
  assert len(req.prefix_indices) <= len(
498
487
  new_indices
499
488
  ), f"{req.prefix_indices=}, {new_indices=}"
@@ -732,7 +721,9 @@ class SWARadixCache(BasePrefixCache):
732
721
 
733
722
  ##### Internal Helper Functions #####
734
723
 
735
- def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]:
724
+ def _match_prefix_helper(
725
+ self, key: RadixKey
726
+ ) -> Tuple[List[torch.Tensor], TreeNode]:
736
727
  """
737
728
  SWA prefix matching helper. It factors in the sliding window size such that
738
729
  the matched node is guaranteed to either 1. connected to root without swa tombstone,
@@ -796,7 +787,7 @@ class SWARadixCache(BasePrefixCache):
796
787
 
797
788
  return value[:best_value_len], best_last_node
798
789
 
799
- def _split_node(self, key: List[int], child: TreeNode, split_len: int) -> TreeNode:
790
+ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
800
791
  # new_node -> child
801
792
  new_node = TreeNode()
802
793
  new_node.children = {self.get_child_key_fn(key[split_len:]): child}
@@ -831,7 +822,7 @@ class SWARadixCache(BasePrefixCache):
831
822
  return new_node
832
823
 
833
824
  def _insert_helper(
834
- self, node: TreeNode, key: List, value, update_kv_after_len: int
825
+ self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int
835
826
  ) -> int:
836
827
  # Update the last access time from root to leaf, so that
837
828
  # swa will tombstone the node closer to root first
@@ -14,10 +14,10 @@
14
14
  """Utilities for Prometheus Metrics Collection."""
15
15
  import time
16
16
  from dataclasses import dataclass, field
17
- from enum import Enum
18
17
  from typing import Dict, List, Optional, Union
19
18
 
20
- from sglang.srt.metrics.utils import generate_buckets
19
+ from sglang.srt.disaggregation.utils import DisaggregationMode
20
+ from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
21
21
  from sglang.srt.server_args import ServerArgs
22
22
  from sglang.srt.utils import get_bool_env_var
23
23
 
@@ -34,6 +34,7 @@ class TimeStats:
34
34
  Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
35
35
  """
36
36
 
37
+ disagg_mode: DisaggregationMode = DisaggregationMode.NULL
37
38
  lb_entry_time: float = 0.0
38
39
  wait_queue_entry_time: float = 0.0
39
40
  forward_entry_time: float = 0.0
@@ -43,20 +44,11 @@ class TimeStats:
43
44
  decode_prealloc_queue_entry_time: float = 0.0
44
45
  decode_transfer_queue_entry_time: float = 0.0
45
46
 
46
- class RequestType(Enum):
47
- UNIFIED = "unified"
48
- PREFILL = "prefill"
49
- DECODE = "decode"
50
- INVALID = "invalid"
51
-
52
47
  def get_queueing_time(self) -> float:
53
48
  return self.forward_entry_time - self.wait_queue_entry_time
54
49
 
55
- def __str__(self) -> str:
56
- # if unified
57
- _type = self.get_type()
58
-
59
- if _type == self.RequestType.UNIFIED:
50
+ def convert_to_duration(self) -> str:
51
+ if self.disagg_mode == DisaggregationMode.NULL:
60
52
  queue_duration = self.forward_entry_time - self.wait_queue_entry_time
61
53
  forward_duration = self.completion_time - self.forward_entry_time
62
54
 
@@ -65,30 +57,28 @@ class TimeStats:
65
57
  queue_duration >= 0 and forward_duration >= 0
66
58
  ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
67
59
 
68
- return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
69
- elif _type == self.RequestType.PREFILL:
60
+ return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time:.3f}"
61
+ elif self.disagg_mode == DisaggregationMode.PREFILL:
70
62
  bootstrap_duration = (
71
63
  self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
72
64
  )
73
-
74
65
  queue_duration = self.forward_entry_time - self.wait_queue_entry_time
75
-
76
66
  forward_duration = self.completion_time - self.forward_entry_time
77
67
 
78
68
  if SGLANG_TEST_REQUEST_TIME_STATS:
79
- assert (
80
- bootstrap_duration >= 0
81
- and queue_duration >= 0
82
- and forward_duration >= 0
83
- ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
84
- return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
85
- # if decode
86
- elif _type == self.RequestType.DECODE:
69
+ if self.wait_queue_entry_time > 0:
70
+ assert (
71
+ bootstrap_duration >= 0
72
+ and queue_duration >= 0
73
+ and forward_duration >= 0
74
+ ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
75
+
76
+ return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time:.3f}"
77
+ elif self.disagg_mode == DisaggregationMode.DECODE:
87
78
  prealloc_duration = (
88
79
  self.decode_transfer_queue_entry_time
89
80
  - self.decode_prealloc_queue_entry_time
90
81
  )
91
-
92
82
  transfer_duration = (
93
83
  self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
94
84
  )
@@ -96,42 +86,30 @@ class TimeStats:
96
86
  forward_duration = self.completion_time - self.forward_entry_time
97
87
 
98
88
  if SGLANG_TEST_REQUEST_TIME_STATS:
99
- assert (
100
- prealloc_duration >= 0
101
- and transfer_duration >= 0
102
- and queue_duration >= 0
103
- and forward_duration >= 0
104
- ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
105
-
106
- return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
89
+ if self.wait_queue_entry_time > 0:
90
+ assert (
91
+ prealloc_duration >= 0
92
+ and transfer_duration >= 0
93
+ and queue_duration >= 0
94
+ and forward_duration >= 0
95
+ ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. {self=}"
96
+
97
+ return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time:.3f}"
107
98
  else:
108
- return "Invalid Time Stats"
99
+ return "Unknown Time Stats"
109
100
 
110
101
  def format_duration(self, duration: float) -> str:
111
102
  return f"{duration * 1e3:.2f}ms"
112
103
 
113
- def get_type(self) -> RequestType:
114
- """Determine the type of request based on timestamp values."""
115
- if (
116
- self.prefill_bootstrap_queue_entry_time == 0.0
117
- and self.prefill_transfer_queue_entry_time == 0.0
118
- and self.decode_prealloc_queue_entry_time == 0.0
119
- and self.decode_transfer_queue_entry_time == 0.0
120
- ):
121
- return self.RequestType.UNIFIED
122
- elif (
123
- self.prefill_bootstrap_queue_entry_time > 0.0
124
- and self.prefill_transfer_queue_entry_time > 0.0
125
- ):
126
- return self.RequestType.PREFILL
127
- elif (
128
- self.decode_prealloc_queue_entry_time > 0.0
129
- and self.decode_transfer_queue_entry_time > 0.0
130
- and self.wait_queue_entry_time > 0.0
131
- ):
132
- return self.RequestType.DECODE
104
+ def disagg_mode_str(self) -> str:
105
+ if self.disagg_mode == DisaggregationMode.NULL:
106
+ return "unified"
107
+ elif self.disagg_mode == DisaggregationMode.DECODE:
108
+ return "decode"
109
+ elif self.disagg_mode == DisaggregationMode.PREFILL:
110
+ return "prefill"
133
111
  else:
134
- return self.RequestType.INVALID
112
+ return "unknown"
135
113
 
136
114
 
137
115
  @dataclass
@@ -145,12 +123,15 @@ class SchedulerStats:
145
123
  num_queue_reqs: int = 0
146
124
  num_grammar_queue_reqs: int = 0
147
125
  num_running_reqs_offline_batch: int = 0
148
- avg_request_queue_latency: float = 0.0
149
126
  cache_hit_rate: float = 0.0
150
127
 
151
128
  # Speculative decoding
152
129
  spec_accept_length: float = 0.0
153
130
 
131
+ # Retract
132
+ num_retracted_reqs: int = 0
133
+ num_paused_reqs: int = 0
134
+
154
135
  # PD disaggregation
155
136
  num_prefill_prealloc_queue_reqs: int = 0
156
137
  num_prefill_inflight_queue_reqs: int = 0
@@ -159,11 +140,6 @@ class SchedulerStats:
159
140
  kv_transfer_speed_gb_s: float = 0.0
160
141
  kv_transfer_latency_ms: float = 0.0
161
142
 
162
- # Retract
163
- total_retracted_reqs: int = 0
164
- num_retracted_reqs: int = 0
165
- num_paused_reqs: int = 0
166
-
167
143
  # Utilization
168
144
  utilization: float = 0.0
169
145
  max_running_requests_under_SLO: Optional[int] = None
@@ -230,12 +206,6 @@ class SchedulerMetricsCollector:
230
206
  labelnames=labels.keys(),
231
207
  multiprocess_mode="mostrecent",
232
208
  )
233
- self.avg_request_queue_latency = Gauge(
234
- name="sglang:avg_request_queue_latency",
235
- documentation="The average request queue latency for the last batch of requests in seconds.",
236
- labelnames=labels.keys(),
237
- multiprocess_mode="mostrecent",
238
- )
239
209
  self.cache_hit_rate = Gauge(
240
210
  name="sglang:cache_hit_rate",
241
211
  documentation="The prefix cache hit rate.",
@@ -251,6 +221,18 @@ class SchedulerMetricsCollector:
251
221
  multiprocess_mode="mostrecent",
252
222
  )
253
223
 
224
+ # Retract
225
+ self.num_retracted_reqs = Gauge(
226
+ name="sglang:num_retracted_reqs",
227
+ documentation="The number of retracted requests.",
228
+ labelnames=labels.keys(),
229
+ )
230
+ self.num_paused_reqs = Gauge(
231
+ name="sglang:num_paused_reqs",
232
+ documentation="The number of paused requests by async weight sync.",
233
+ labelnames=labels.keys(),
234
+ )
235
+
254
236
  # PD disaggregation
255
237
  self.num_prefill_prealloc_queue_reqs = Gauge(
256
238
  name="sglang:num_prefill_prealloc_queue_reqs",
@@ -299,24 +281,6 @@ class SchedulerMetricsCollector:
299
281
  multiprocess_mode="mostrecent",
300
282
  )
301
283
 
302
- # Retract
303
- self.total_retracted_reqs = Gauge(
304
- name="sglang:total_retracted_reqs",
305
- documentation="The total number of retracted requests due to kvcache full.",
306
- labelnames=labels.keys(),
307
- multiprocess_mode="mostrecent",
308
- )
309
- self.num_retracted_reqs = Gauge(
310
- name="sglang:num_retracted_reqs",
311
- documentation="The number of retracted requests.",
312
- labelnames=labels.keys(),
313
- )
314
- self.num_paused_reqs = Gauge(
315
- name="sglang:num_paused_reqs",
316
- documentation="The number of paused requests by async weight sync.",
317
- labelnames=labels.keys(),
318
- )
319
-
320
284
  # Utilization
321
285
  self.utilization = Gauge(
322
286
  name="sglang:utilization",
@@ -347,7 +311,7 @@ class SchedulerMetricsCollector:
347
311
 
348
312
  # Additional queueing time histogram
349
313
  self.queue_time = Histogram(
350
- name="sglang:queue_time_s",
314
+ name="sglang:queue_time_seconds",
351
315
  documentation="Histogram of queueing time in seconds.",
352
316
  labelnames=labels.keys(),
353
317
  buckets=[
@@ -513,11 +477,19 @@ class SchedulerMetricsCollector:
513
477
  buckets=tree_traversal_time_buckets,
514
478
  )
515
479
 
480
+ self.per_stage_req_latency_seconds = Histogram(
481
+ name="sglang:per_stage_req_latency_seconds",
482
+ documentation="The latency of each stage of requests.",
483
+ # captures latency in range [1ms - ~1191s]
484
+ buckets=exponential_buckets(start=0.001, width=1.62, length=30),
485
+ labelnames=list(labels.keys()) + ["stage"],
486
+ )
487
+
516
488
  def _log_gauge(self, gauge, data: Union[int, float]) -> None:
517
489
  # Convenience function for logging to gauge.
518
490
  gauge.labels(**self.labels).set(data)
519
491
 
520
- def log_histogram(self, histogram, data: Union[int, float]) -> None:
492
+ def _log_histogram(self, histogram, data: Union[int, float]) -> None:
521
493
  histogram.labels(**self.labels).observe(data)
522
494
 
523
495
  def increment_bootstrap_failed_reqs(self) -> None:
@@ -526,6 +498,13 @@ class SchedulerMetricsCollector:
526
498
  def increment_transfer_failed_reqs(self) -> None:
527
499
  self.num_transfer_failed_reqs.labels(**self.labels).inc(1)
528
500
 
501
+ def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
502
+ labels_with_stage = {**self.labels, "stage": stage}
503
+ self.per_stage_req_latency_seconds.labels(**labels_with_stage).observe(latency)
504
+
505
+ def observe_queue_time(self, latency: float) -> None:
506
+ self._log_histogram(self.queue_time, latency)
507
+
529
508
  def log_stats(self, stats: SchedulerStats) -> None:
530
509
  self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
531
510
  self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
@@ -538,7 +517,6 @@ class SchedulerMetricsCollector:
538
517
  self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch
539
518
  )
540
519
  self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
541
- self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
542
520
 
543
521
  # Speculative decoding
544
522
  self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
@@ -560,7 +538,6 @@ class SchedulerMetricsCollector:
560
538
  self._log_gauge(self.kv_transfer_latency_ms, stats.kv_transfer_latency_ms)
561
539
 
562
540
  # Retract
563
- self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs)
564
541
  self._log_gauge(self.num_retracted_reqs, stats.num_retracted_reqs)
565
542
  self._log_gauge(self.num_paused_reqs, stats.num_paused_reqs)
566
543
 
@@ -584,19 +561,19 @@ class SchedulerMetricsCollector:
584
561
  def log_grammar_stats(self, grammar_stats) -> None:
585
562
  # Duck-typed GrammarStats to avoid cross-package dependency
586
563
  if getattr(grammar_stats, "compilation_time", None) is not None:
587
- self.log_histogram(
564
+ self._log_histogram(
588
565
  self.grammar_compilation_time, grammar_stats.compilation_time
589
566
  )
590
567
  if getattr(grammar_stats, "schema_count", None) is not None:
591
- self.log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
568
+ self._log_histogram(self.grammar_schema_count, grammar_stats.schema_count)
592
569
  if getattr(grammar_stats, "ebnf_size", None) is not None:
593
- self.log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
570
+ self._log_histogram(self.grammar_ebnf_size, grammar_stats.ebnf_size)
594
571
  tree_times = getattr(grammar_stats, "tree_traversal_time", None)
595
572
  if tree_times:
596
573
  max_time = max(tree_times)
597
574
  avg_time = sum(tree_times) / len(tree_times)
598
- self.log_histogram(self.grammar_tree_traversal_time_max, max_time)
599
- self.log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
575
+ self._log_histogram(self.grammar_tree_traversal_time_max, max_time)
576
+ self._log_histogram(self.grammar_tree_traversal_time_avg, avg_time)
600
577
  if getattr(grammar_stats, "is_cache_hit", False):
601
578
  self.num_grammar_cache_hit.labels(**self.labels).inc(1)
602
579
  if getattr(grammar_stats, "is_grammar_aborted", False):
@@ -702,7 +679,7 @@ class TokenizerMetricsCollector:
702
679
  )
703
680
 
704
681
  self.num_aborted_requests_total = Counter(
705
- name="sglang:num_aborted_requests",
682
+ name="sglang:num_aborted_requests_total",
706
683
  documentation="Number of requests aborted.",
707
684
  labelnames=labels.keys(),
708
685
  )
@@ -789,7 +766,7 @@ class TokenizerMetricsCollector:
789
766
  buckets=bucket_time_to_first_token,
790
767
  )
791
768
 
792
- self.histogram_inter_token_latency_seconds = Histogram(
769
+ self.histogram_inter_token_latency = Histogram(
793
770
  name="sglang:inter_token_latency_seconds",
794
771
  documentation="Histogram of inter-token latency in seconds.",
795
772
  labelnames=labels.keys(),
@@ -803,14 +780,6 @@ class TokenizerMetricsCollector:
803
780
  buckets=bucket_e2e_request_latency,
804
781
  )
805
782
 
806
- # Offline batch specific TTFB histogram
807
- self.histogram_time_to_first_token_offline_batch = Histogram(
808
- name="sglang:time_to_first_token_seconds_offline_batch",
809
- documentation="Histogram of time to first token in seconds for offline batch requests.",
810
- labelnames=labels.keys(),
811
- buckets=bucket_time_to_first_token,
812
- )
813
-
814
783
  def observe_one_finished_request(
815
784
  self,
816
785
  labels: Dict[str, str],
@@ -834,26 +803,19 @@ class TokenizerMetricsCollector:
834
803
  float(generation_tokens)
835
804
  )
836
805
 
837
- def observe_time_to_first_token(
838
- self, labels: Dict[str, str], value: float, type: str = ""
839
- ):
840
- if type == "batch":
841
- self.histogram_time_to_first_token_offline_batch.labels(**labels).observe(
842
- value
843
- )
844
- else:
845
- self.histogram_time_to_first_token.labels(**labels).observe(value)
806
+ def observe_time_to_first_token(self, labels: Dict[str, str], value: float):
807
+ self.histogram_time_to_first_token.labels(**labels).observe(value)
846
808
 
847
809
  def check_time_to_first_token_straggler(self, value: float) -> bool:
848
810
  his = self.histogram_time_to_first_token.labels(**self.labels)
849
811
  total_observations = sum(bucket._value for bucket in his._buckets)
850
- if total_observations < 100:
812
+ if total_observations < 1000:
851
813
  return False
852
- p99_threshold = total_observations * 0.99
814
+ p999_threshold = total_observations * 0.999
853
815
  cumulative_count = 0
854
816
  for i, bucket in enumerate(his._buckets):
855
817
  cumulative_count += bucket._value
856
- if cumulative_count > p99_threshold:
818
+ if cumulative_count > p999_threshold:
857
819
  return value >= his._upper_bounds[i]
858
820
  return False
859
821
 
@@ -864,7 +826,7 @@ class TokenizerMetricsCollector:
864
826
 
865
827
  # A faster version of the Histogram::observe which observes multiple values at the same time.
866
828
  # reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
867
- his = self.histogram_inter_token_latency_seconds.labels(**labels)
829
+ his = self.histogram_inter_token_latency.labels(**labels)
868
830
  his._sum.inc(internval)
869
831
 
870
832
  for i, bound in enumerate(his._upper_bounds):
@@ -872,8 +834,8 @@ class TokenizerMetricsCollector:
872
834
  his._buckets[i].inc(num_new_tokens)
873
835
  break
874
836
 
875
- def observe_one_aborted_request(self):
876
- self.num_aborted_requests_total.labels(**self.labels).inc(1)
837
+ def observe_one_aborted_request(self, labels: Dict[str, str]):
838
+ self.num_aborted_requests_total.labels(**labels).inc(1)
877
839
 
878
840
 
879
841
  @dataclass
@@ -20,6 +20,8 @@ import time
20
20
  from functools import wraps
21
21
  from typing import Any, Callable, List, Optional
22
22
 
23
+ from sglang.srt.metrics.utils import exponential_buckets
24
+
23
25
  enable_metrics = False
24
26
 
25
27
 
@@ -42,13 +44,6 @@ def enable_func_timer():
42
44
  FUNC_LATENCY = None
43
45
 
44
46
 
45
- def exponential_buckets(start: float, width: float, length: int) -> List[float]:
46
- buckets = []
47
- for i in range(length):
48
- buckets.append(start * (width**i))
49
- return buckets
50
-
51
-
52
47
  def time_func_latency(
53
48
  func: Callable = None, name: Optional[str] = None
54
49
  ) -> Callable[..., Any]:
@@ -44,5 +44,12 @@ def generate_buckets(
44
44
  return two_sides_exponential_buckets(float(middle), float(base), int(count))
45
45
  if rule == "default":
46
46
  return sorted(set(default_buckets))
47
- assert rule == "customer"
47
+ assert rule == "custom"
48
48
  return sorted(set([float(x) for x in buckets_rule[1:]]))
49
+
50
+
51
+ def exponential_buckets(start: float, width: float, length: int) -> List[float]:
52
+ buckets = []
53
+ for i in range(length):
54
+ buckets.append(start * (width**i))
55
+ return buckets
@@ -34,7 +34,6 @@ from sglang.srt.model_executor.forward_batch_info import (
34
34
  ForwardMode,
35
35
  PPProxyTensors,
36
36
  )
37
- from sglang.srt.patch_torch import monkey_patch_torch_compile
38
37
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
39
38
  from sglang.srt.utils import (
40
39
  log_info_on_rank0,
@@ -43,6 +42,7 @@ from sglang.srt.utils import (
43
42
  require_mlp_sync,
44
43
  require_mlp_tp_gather,
45
44
  )
45
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
46
46
 
47
47
  logger = logging.getLogger(__name__)
48
48
 
@@ -607,7 +607,7 @@ class CPUGraphRunner:
607
607
  def get_spec_info(self, num_tokens: int):
608
608
  spec_info = None
609
609
  if self.model_runner.spec_algorithm.is_eagle():
610
- from sglang.srt.speculative.eagle_utils import EagleVerifyInput
610
+ from sglang.srt.speculative.eagle_info import EagleVerifyInput
611
611
 
612
612
  if self.model_runner.is_draft_worker:
613
613
  raise RuntimeError("This should not happen.")