sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ from enum import Enum, auto
 from typing import Any, List, Optional
 
 from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
-from sglang.srt.poll_based_barrier import PollBasedBarrier
+from sglang.srt.utils.poll_based_barrier import PollBasedBarrier
 
 logger = logging.getLogger(__name__)
 
@@ -12,7 +12,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
-from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var
 
@@ -47,8 +46,11 @@ class SchedulerMetricsMixin:
         self.spec_num_total_forward_ct = 0
         self.cum_spec_accept_length = 0
         self.cum_spec_accept_count = 0
-        self.total_retracted_reqs = 0
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+
         self.stats = SchedulerStats()
+
         if self.enable_metrics:
             engine_type = "unified"
             labels = {
@@ -61,33 +63,30 @@
                 labels["dp_rank"] = dp_rank
             self.metrics_collector = SchedulerMetricsCollector(labels=labels)
 
-    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
-        self.balance_meta = dp_balance_meta
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-        self.recv_dp_balance_id_this_term = []
-
     def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )
 
+    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_total_forward_ct += bs
+        self.num_generated_tokens += num_accepted_tokens
+
     def log_prefill_stats(
         self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
+        running_bs_offline_batch: int,
     ):
         gap_latency = time.perf_counter() - self.last_prefill_stats_tic
         self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens
 
+        # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,
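The `udpate_spec_metrics` helper added above (the spelling follows the release) centralizes speculative-decoding bookkeeping: each forward pass emits one sampled token per request plus however many draft tokens were accepted. A worked example of the accounting, with illustrative numbers:

    # One decode step over a batch of bs = 4 requests verifies 12 draft tokens.
    bs, num_accepted_tokens = 4, 12
    spec_num_total_accepted_tokens = num_accepted_tokens + bs  # 16: each req also emits its own token
    spec_num_total_forward_ct = bs                             # 4 forward passes
    num_generated_tokens = num_accepted_tokens                 # 12
    print(spec_num_total_accepted_tokens / spec_num_total_forward_ct)  # 4.00

That ratio (accepted tokens per forward pass) is what feeds the `accept len:` field logged in the decode-stats hunks further down.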
@@ -101,51 +100,53 @@
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
+            token_usage_msg = (
                 f"full token usage: {full_token_usage:.2f}, "
                 f"swa token usage: {swa_token_usage:.2f}, "
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"token usage: {token_usage:.2f}, "
+            token_usage_msg = f"token usage: {token_usage:.2f}, "
 
-        num_new_seq = len(can_run_list)
         f = (
             f"Prefill batch. "
-            f"#new-seq: {num_new_seq}, "
+            f"#new-seq: {len(can_run_list)}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"{token_msg}"
+            f"{token_usage_msg}"
+            f"#running-req: {running_bs}, "
+            f"#queue-req: {len(self.waiting_queue)}, "
         )
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
-            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
-        else:
-            f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, "
 
         logger.info(f)
 
         if self.enable_metrics:
+            # Basics
             total_tokens = adder.log_input_tokens + adder.log_hit_tokens
-
             cache_hit_rate = (
                 adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
             )
+
             self.stats.num_running_reqs = running_bs
+            self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.cache_hit_rate = cache_hit_rate
 
-            total_queue_latency = 0
-            for req in can_run_list:
-                total_queue_latency += req.queue_time_end - req.queue_time_start
-            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0
 
+            # PD disaggregation
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 self.stats.num_prefill_prealloc_queue_reqs = len(
                     self.disagg_prefill_bootstrap_queue.queue
@@ -153,7 +154,18 @@
                 self.stats.num_prefill_inflight_queue_reqs = len(
                     self.disagg_prefill_inflight_queue
                 )
+                self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
+                self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )
 
+            # Others
+            self.calculate_utilization()
             self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
             self._publish_kv_events()
@@ -166,8 +178,12 @@
         gap_latency = time.perf_counter() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.perf_counter()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
+
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
+        num_running_reqs_offline_batch = 0
+
+        # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,
@@ -181,7 +197,7 @@
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-            token_msg = (
+            token_usage_msg = (
                 f"#full token: {full_num_used}, "
                 f"full token usage: {full_token_usage:.2f}, "
                 f"#swa token: {swa_num_used}, "
@@ -189,14 +205,14 @@
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
+            token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
 
         if RECORD_STEP_TIME:
             self.step_time_dict[num_running_reqs].append(
                 gap_latency / self.server_args.decode_log_interval
             )
 
-        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"
 
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -208,41 +224,66 @@
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
             msg += f"accept len: {spec_accept_length:.2f}, "
+        cache_hit_rate = 0.0
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, "
+            msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
-            f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
+            f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
 
         logger.info(msg)
         if self.enable_metrics:
+            # Basics
             self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.cache_hit_rate = 0.0
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            self.stats.cache_hit_rate = cache_hit_rate
             self.stats.spec_accept_length = spec_accept_length
-            self.stats.total_retracted_reqs = self.total_retracted_reqs
-            self.stats.avg_request_queue_latency = 0.0
-            if self.disaggregation_mode == DisaggregationMode.DECODE:
+
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0
+
+            # PD disaggregation
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.stats.num_decode_prealloc_queue_reqs = len(
                     self.disagg_decode_prealloc_queue.queue
                 )
                 self.stats.num_decode_transfer_queue_reqs = len(
                     self.disagg_decode_transfer_queue.queue
                 )
+
+            # Others
+            self.calculate_utilization()
             self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
             self._publish_kv_events()
 
     def _emit_kv_metrics(self: Scheduler):
+        if not self.enable_kv_cache_events:
+            return
+
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
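Both the prefill and decode paths above now treat retract/pause counts as interval counters: `num_retracted_reqs` and `num_paused_reqs` accumulate between metric flushes and are zeroed immediately after being copied into `stats`. A minimal sketch of the pattern, with hypothetical names:

    class IntervalCounter:
        """Accumulate events between flushes; each flush reports the window's delta."""
        def __init__(self):
            self.count = 0

        def add(self, n=1):
            self.count += n

        def flush(self):
            snapshot, self.count = self.count, 0  # report, then reset the window
            return snapshot

Gauges such as `token_usage` are sampled rather than accumulated, and the `round(token_usage, 2)` call was dropped, presumably so the collector sees full precision.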
@@ -259,93 +300,24 @@
         self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
 
     def _publish_kv_events(self: Scheduler):
-        if self.enable_kv_cache_events:
-            events = self.tree_cache.take_events()
-            if events:
-                batch = KVEventBatch(ts=time.time(), events=events)
-                self.kv_event_publisher.publish(batch)
+        if not self.enable_kv_cache_events:
+            return
 
-    def maybe_update_dp_balance_data(
-        self: Scheduler, recv_req: TokenizedGenerateReqInput
-    ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
-
-    def maybe_handle_dp_balance_data(self: Scheduler):
-        if (
-            self.server_args.load_balance_method == "minimum_tokens"
-            and self.forward_ct % 40 == 0
-        ):
-            holding_tokens = self.get_load().num_tokens
-
-            new_recv_dp_balance_id_list, holding_token_list = (
-                self.gather_dp_balance_info(holding_tokens)
-            )
+        events = self.tree_cache.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)
 
-            self.recv_dp_balance_id_this_term.clear()
-            if self.tp_rank == 0:  # only first worker write info
-                self.write_shared_dp_balance_info(
-                    new_recv_dp_balance_id_list, holding_token_list
-                )
-
-    def gather_dp_balance_info(
-        self: Scheduler, holding_tokens_list
-    ) -> Union[None, List[List[int]]]:
-        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-        recv_list = self.recv_dp_balance_id_this_term
-        assert len(recv_list) <= 511, (
-            "The number of requests received this round is too large. "
-            "Please increase gather_tensor_size and onfly_info_size."
-        )
-        # The maximum size of the tensor used for gathering data from all workers.
-        gather_tensor_size = 512
-
-        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-        recv_tensor[0] = holding_tokens_list
-        recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
-        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
-
-        if self.tp_rank == 0:
-            gathered_list = [
-                torch.zeros(gather_tensor_size, dtype=torch.int32)
-                for _ in range(self.balance_meta.num_workers)
-            ]
+    def calculate_utilization(self):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.utilization = -1
         else:
-            gathered_list = None
-
-        torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
-
-        gathered_id_list_per_worker = None
-        if self.tp_rank == 0:
-            gathered_id_list_per_worker = []
-            holding_tokens_list = []
-            for tensor in gathered_list:
-                holding_tokens_list.append(tensor[0].item())
-                list_length = tensor[1].item()
-                gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
-
-        return gathered_id_list_per_worker, holding_tokens_list
-
-    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
-        meta = self.balance_meta
-
-        with meta.mutex:
-            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
-            # 1.Check if the rid received by each worker this round is present in onfly.
-            # If it is, remove the corresponding onfly item.
-            worker_id = 0
-            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                for new_recv_rid in new_recv_rids:
-                    assert (
-                        new_recv_rid in on_fly_reqs
-                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                    del on_fly_reqs[new_recv_rid]
-                worker_id += 1
-            # 2. Atomically write local_tokens and onfly into shm under the mutex
-            meta.set_shared_onfly_info(onfly_list)
-            meta.set_shared_local_tokens(local_tokens)
+            if (
+                self.stats.max_running_requests_under_SLO is not None
+                and self.stats.max_running_requests_under_SLO > 0
+            ):
+                self.stats.utilization = max(
+                    self.stats.num_running_reqs
+                    / self.stats.max_running_requests_under_SLO,
+                    self.stats.token_usage / 0.9,
+                )
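The new `calculate_utilization` reports -1 (undefined) for prefill-only workers and otherwise takes the larger of request-slot occupancy and token usage normalized against a 0.9 target. A standalone rendering of the formula with sample numbers (values illustrative; `max_running_requests_under_SLO` must be set and positive):

    num_running_reqs = 48
    max_running_requests_under_SLO = 64
    token_usage = 0.72

    utilization = max(
        num_running_reqs / max_running_requests_under_SLO,  # 0.75 (request slots)
        token_usage / 0.9,                                  # 0.80 (KV-token pool)
    )
    print(utilization)  # 0.8 -> the token pool, not request slots, is the binding constraint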
@@ -9,7 +9,11 @@ import torch
 
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
+)
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
 
 if TYPE_CHECKING:
@@ -91,7 +95,7 @@ class SchedulerOutputProcessorMixin:
 
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                    req.time_stats.completion_time = time.time()
+                    req.time_stats.completion_time = time.perf_counter()
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)
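Switching `completion_time` from `time.time()` to `time.perf_counter()` (here and again in the decode path below) trades a wall-clock timestamp for a monotonic one: `perf_counter` has an arbitrary epoch but never jumps under NTP adjustment, so differences between two readings are reliable. This only works if the other `time_stats` fields it is compared against use the same clock. A minimal illustration:

    import time

    t0 = time.perf_counter()
    time.sleep(0.01)                      # stand-in for request processing
    elapsed = time.perf_counter() - t0    # safe: monotonic, high resolution
    # time.time() deltas can be skewed (even negative) if the wall clock is adjusted.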
@@ -140,7 +144,7 @@
                         logger.error(
                             f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                         )
-                        self.abort_request(AbortReq(req.rid))
+                        self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
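`AbortReq` is now constructed with `rid=` as a keyword (the same change appears in a later hunk). With `io_struct.py` reshuffled heavily in this release (+162/-118), keyword construction keeps call sites correct even if fields are added or reordered. A generic illustration with a hypothetical dataclass, not the real `AbortReq` definition:

    from dataclasses import dataclass

    @dataclass
    class AbortReqDemo:            # hypothetical stand-in for io_struct.AbortReq
        abort_all: bool = False    # imagine a field added before rid
        rid: str = ""

    print(AbortReqDemo("req-1"))       # positional: "req-1" silently lands in abort_all
    print(AbortReqDemo(rid="req-1"))   # keyword: always binds the intended field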
@@ -173,8 +177,7 @@
             self.set_next_batch_sampling_info_done(batch)
 
         else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -250,8 +253,14 @@
 
             req.check_finished()
             if req.finished():
-                self.tree_cache.cache_finished_req(req)
-                req.time_stats.completion_time = time.time()
+                if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                    # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
+                    if not self.decode_offload_manager.offload_kv_cache(req):
+                        self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_finished_req(req)
+
+                req.time_stats.completion_time = time.perf_counter()
 
             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding
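The offload call above implies a boolean contract: `offload_kv_cache(req)` returns whether an asynchronous Device->Host copy was scheduled (in which case `cache_finished_req` runs on completion), and the caller caches synchronously only on failure. A sketch of that contract; the method name comes from the diff, while the body and helper names are hypothetical:

    def offload_kv_cache(self, req) -> bool:
        """Return True iff an async D2H offload was scheduled for req's KV cache."""
        if not self._host_pool_has_room(req):   # hypothetical capacity check
            return False                        # caller falls back to cache_finished_req
        self._enqueue_d2h_copy(                 # hypothetical async transfer
            req, on_done=lambda: self.tree_cache.cache_finished_req(req)
        )
        return True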
@@ -287,7 +296,7 @@
                     logger.error(
                         f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                     )
-                    self.abort_request(AbortReq(req.rid))
+                    self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()
 
         self.set_next_batch_sampling_info_done(batch)
@@ -709,8 +718,7 @@
             return
 
         self.send_to_detokenizer.send_pyobj(
-            BatchTokenIDOut(
-                rids,
+            BatchTokenIDOutput(
                 finished_reasons,
                 decoded_texts,
                 decode_ids_list,
@@ -736,6 +744,7 @@
                 output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx,
                 output_hidden_states,
+                rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
             )
@@ -756,12 +765,12 @@
             prompt_tokens.append(len(req.origin_input_ids))
             cached_tokens.append(req.cached_tokens)
         self.send_to_detokenizer.send_pyobj(
-            BatchEmbeddingOut(
-                rids,
+            BatchEmbeddingOutput(
                 finished_reasons,
                 embeddings,
                 prompt_tokens,
                 cached_tokens,
+                rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
             )
@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
     def start_profile(
         self, stage: Optional[ForwardMode] = None
     ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.__str__()}" if stage else ""
+        stage_str = f" for {stage.name}" if stage else ""
         logger.info(
             f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
         )
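`stage.name` gives just the member name, while `Enum.__str__` includes the class, so profiler log lines and trace-file suffixes drop the `ForwardMode.` prefix. With a minimal stand-in enum:

    from enum import Enum, auto

    class ForwardMode(Enum):        # stand-in; the real one lives in sglang.srt
        EXTEND = auto()
        DECODE = auto()

    print(str(ForwardMode.EXTEND))  # "ForwardMode.EXTEND" -> old suffix "-ForwardMode.EXTEND"
    print(ForwardMode.EXTEND.name)  # "EXTEND"             -> new suffix "-EXTEND"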
@@ -181,7 +181,7 @@
         if not Path(self.torch_profiler_output_dir).exists():
             Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
 
-        stage_suffix = f"-{stage.__str__()}" if stage else ""
+        stage_suffix = f"-{stage.name}" if stage else ""
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
@@ -204,7 +204,7 @@
 
         torch.distributed.barrier(self.tp_cpu_group)
         if self.tp_rank == 0:
-            from sglang.srt.utils import rpd_to_chrome_trace
+            from sglang.srt.utils.rpd_utils import rpd_to_chrome_trace
 
             rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
             self.rpd_profiler = None
@@ -247,7 +247,7 @@
         if self.profiler_decode_ct == 0:
             if self.profile_in_progress:
                 # force trace flush
-                self.stop_profile(ForwardMode.EXTEND)
+                self.stop_profile(stage=ForwardMode.EXTEND)
             self.start_profile(batch.forward_mode)
         self.profiler_decode_ct += 1
         if self.profiler_decode_ct > self.profiler_target_decode_ct:
@@ -294,6 +294,6 @@
                 recv_req.profile_by_stage,
                 recv_req.profile_id,
             )
-            return self.start_profile(True)
+            return self.start_profile()
         else:
             return self.stop_profile()
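This last change is a bug fix rather than a rename: the old call passed `True` positionally, so it bound to the `stage: Optional[ForwardMode] = None` parameter shown earlier, the truthy check produced log text like `Profiling starts for True`, and the new `stage.name` would raise `AttributeError` on a bool. Dropping the argument keeps `stage` at its `None` default:

    # Signature from the earlier hunk:
    #   def start_profile(self, stage: Optional[ForwardMode] = None) -> ...
    self.start_profile(True)   # old: True lands in `stage`; f" for {stage.name}" breaks on bool
    self.start_profile()       # new: stage stays None, no bogus suffix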
@@ -5,6 +5,8 @@ import torch
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
+    DestroyWeightsUpdateGroupReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsUpdateGroupReqInput,
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
         success, message = self.tp_worker.init_weights_update_group(recv_req)
         return InitWeightsUpdateGroupReqOutput(success, message)
 
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        """Destroy the online model parameter update group."""
+        success, message = self.tp_worker.destroy_weights_update_group(recv_req)
+        return DestroyWeightsUpdateGroupReqOutput(success, message)
+
     def update_weights_from_distributed(
         self,
         recv_req: UpdateWeightsFromDistributedReqInput,