sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/disaggregation/decode.py
+++ b/sglang/srt/disaggregation/decode.py
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
+import time
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
@@ -45,7 +46,7 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
@@ -253,6 +254,7 @@ class DecodePreallocQueue:
             prefill_dp_rank=req.data_parallel_rank,
         )
 
+        req.add_latency(RequestStage.DECODE_PREPARE)
         self.queue.append(
             DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
         )
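The `RequestStage` / `add_latency` calls threaded through this file add per-stage latency accounting to disaggregated decode requests. A minimal sketch of the pattern, with hypothetical `Req` and `RequestStage` stand-ins (the real definitions live in `sglang/srt/managers/schedule_batch.py` and are not shown in this diff):

import time
from enum import Enum

class RequestStage(Enum):  # hypothetical subset of the real enum
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"

class Req:  # illustrative stand-in for the scheduler's request object
    def __init__(self):
        self._last_tic = time.perf_counter()
        self.stage_latencies = {}

    def add_latency(self, stage: RequestStage):
        # Record time spent since the previous stage boundary, then
        # reset the reference point for the next stage.
        now = time.perf_counter()
        self.stage_latencies[stage] = now - self._last_tic
        self._last_tic = now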
@@ -421,8 +423,13 @@ class DecodePreallocQueue:
                 kv_indices, self.token_to_kv_pool_allocator.page_size
             )
             decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
+
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
+            decode_req.req.time_stats.decode_transfer_queue_entry_time = (
+                time.perf_counter()
+            )
+            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
 
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -516,11 +523,19 @@ class DecodePreallocQueue:
                 dtype=torch.int64,
                 device=self.token_to_kv_pool_allocator.device,
             ),
+            prefix_lens_cpu=torch.tensor(
+                [0],
+                dtype=torch.int64,
+            ),
             seq_lens=torch.tensor(
                 [num_tokens],
                 dtype=torch.int64,
                 device=self.token_to_kv_pool_allocator.device,
            ),
+            seq_lens_cpu=torch.tensor(
+                [num_tokens],
+                dtype=torch.int64,
+            ),
             last_loc=torch.tensor(
                 [-1],
                 dtype=torch.int64,
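The new `prefix_lens_cpu` / `seq_lens_cpu` arguments keep host-side copies next to the device tensors, presumably so downstream code can read lengths without a blocking device-to-host copy. A toy illustration of the cost being avoided (the rationale is an assumption, not stated in the diff):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([128], dtype=torch.int64, device=device)
seq_lens_cpu = torch.tensor([128], dtype=torch.int64)  # host mirror

# seq_lens[0].item() on a CUDA tensor forces a synchronizing copy;
# reading the CPU mirror never touches the device.
n = seq_lens_cpu[0].item()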
@@ -607,16 +622,23 @@ class DecodeTransferQueue:
                 idx = decode_req.metadata_buffer_index
                 (
                     output_id,
+                    cached_tokens,
                     output_token_logprobs_val,
                     output_token_logprobs_idx,
                     output_top_logprobs_val,
                     output_top_logprobs_idx,
+                    output_topk_p,
+                    output_topk_index,
                     output_hidden_states,
                 ) = self.metadata_buffers.get_buf(idx)
 
                 decode_req.req.output_ids.append(output_id[0].item())
+                decode_req.req.cached_tokens = cached_tokens[0].item()
                 if not self.spec_algorithm.is_none():
+                    decode_req.req.output_topk_p = output_topk_p
+                    decode_req.req.output_topk_index = output_topk_index
                     decode_req.req.hidden_states_tensor = output_hidden_states
+
                 if decode_req.req.return_logprob:
                     decode_req.req.output_token_logprobs_val.append(
                         output_token_logprobs_val[0].item()
@@ -637,10 +659,17 @@ class DecodeTransferQueue:
 
                 if hasattr(decode_req.kv_receiver, "clear"):
                     decode_req.kv_receiver.clear()
+                decode_req.kv_receiver = None
+
+                indices_to_remove.add(i)
+                decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
 
                 # special handling for sampling_params.max_new_tokens == 1
                 if decode_req.req.sampling_params.max_new_tokens == 1:
                     # finish immediately
+                    decode_req.req.time_stats.forward_entry_time = (
+                        decode_req.req.time_stats.completion_time
+                    ) = time.perf_counter()
                     decode_req.req.check_finished()
                     self.scheduler.stream_output(
                         [decode_req.req], decode_req.req.return_logprob
@@ -648,8 +677,6 @@ class DecodeTransferQueue:
                     self.tree_cache.cache_finished_req(decode_req.req)
                 else:
                     transferred_reqs.append(decode_req.req)
-
-                indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
                 KVPoll.WaitingForInput,
@@ -662,6 +689,7 @@ class DecodeTransferQueue:
         for i in indices_to_remove:
             idx = self.queue[i].metadata_buffer_index
             assert idx != -1
+            self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
             self.req_to_metadata_buffer_idx_allocator.free(idx)
 
         self.queue = [
@@ -704,12 +732,15 @@ class SchedulerDisaggregationDecodeMixin:
         elif prepare_mlp_sync_flag:
             batch, _ = self._prepare_idle_batch_and_run(None)
 
-        if batch is None and (
+        queue_size = (
             len(self.waiting_queue)
             + len(self.disagg_decode_transfer_queue.queue)
             + len(self.disagg_decode_prealloc_queue.queue)
-            == 0
-        ):
+        )
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+        if batch is None and queue_size == 0:
             self.self_check_during_idle()
 
         self.last_batch = batch
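This hunk (and its near-identical twin in the next hunk) folds in-flight KV cache offloads into the idle check, so the scheduler does not declare itself idle while `decode_offload_manager` still has copies outstanding. The same logic, restated as a hedged standalone sketch with simplified names:

def is_truly_idle(batch, waiting_queue, transfer_queue, prealloc_queue,
                  offload_manager=None):
    # A decode scheduler is idle only when no batch is running, every queue
    # is drained, and (if offloading is enabled) no KV cache offload is
    # still in flight.
    queue_size = len(waiting_queue) + len(transfer_queue) + len(prealloc_queue)
    if offload_manager is not None:
        queue_size += len(offload_manager.ongoing_offload)
    return batch is None and queue_size == 0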
@@ -778,12 +809,15 @@ class SchedulerDisaggregationDecodeMixin:
             )
             self.process_batch_result(tmp_batch, tmp_result)
 
-        if batch is None and (
+        queue_size = (
             len(self.waiting_queue)
             + len(self.disagg_decode_transfer_queue.queue)
             + len(self.disagg_decode_prealloc_queue.queue)
-            == 0
-        ):
+        )
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+        if batch is None and queue_size == 0:
             self.self_check_during_idle()
 
         self.last_batch = batch
@@ -853,6 +887,7 @@ class SchedulerDisaggregationDecodeMixin:
             # we can only add at least `num_not_used_batch` new batch to the running queue
             if i < num_not_used_batch:
                 can_run_list.append(req)
+                req.add_latency(RequestStage.DECODE_WAITING)
                 req.init_next_round_input(self.tree_cache)
             else:
                 waiting_queue.append(req)
@@ -861,6 +896,9 @@ class SchedulerDisaggregationDecodeMixin:
         if len(can_run_list) == 0:
             return None
 
+        for req in can_run_list:
+            req.time_stats.forward_entry_time = time.perf_counter()
+
         # construct a schedule batch with those requests and mark as decode
         new_batch = ScheduleBatch.init_new(
             can_run_list,
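Across these hunks the diff stamps `req.time_stats` fields at queue transitions (`decode_transfer_queue_entry_time`, `wait_queue_entry_time`, `forward_entry_time`, `completion_time`). The container itself is not shown in this diff; a hypothetical sketch of the shape it might take:

from dataclasses import dataclass

@dataclass
class RequestTimeStats:  # illustrative, not the real definition
    # time.perf_counter() timestamps taken as a request crosses queue
    # boundaries; differences between consecutive fields give the dwell
    # time in each queue.
    decode_transfer_queue_entry_time: float = 0.0
    wait_queue_entry_time: float = 0.0
    forward_entry_time: float = 0.0
    completion_time: float = 0.0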
@@ -901,3 +939,6 @@ class SchedulerDisaggregationDecodeMixin:
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
         self.waiting_queue.extend(alloc_reqs)
+
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            self.decode_offload_manager.check_offload_progress()
--- /dev/null
+++ b/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
@@ -0,0 +1,185 @@
+import logging
+import threading
+import time
+
+import torch
+
+from sglang import ServerArgs
+from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.mem_cache.memory_pool_host import (
+    MHATokenToKVPoolHost,
+    MLATokenToKVPoolHost,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DecodeKVCacheOffloadManager:
+    """Manage decode-side KV cache offloading lifecycle and operations."""
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+        tp_group: torch.distributed.ProcessGroup,
+        tree_cache: BasePrefixCache,
+        server_args: ServerArgs,
+    ) -> None:
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = server_args.page_size
+        self.server_args = server_args
+        self.request_counter = 0
+        self.tree_cache = tree_cache
+        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
+        if isinstance(kv_cache, MHATokenToKVPool):
+            self.decode_host_mem_pool = MHATokenToKVPoolHost(
+                kv_cache,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+            )
+        elif isinstance(kv_cache, MLATokenToKVPool):
+            self.decode_host_mem_pool = MLATokenToKVPoolHost(
+                kv_cache,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+            )
+        else:
+            raise ValueError("Unsupported KV cache type for decode offload")
+
+        self.tp_group = tp_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
+        self.cache_controller = HiCacheController(
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            mem_pool_host=self.decode_host_mem_pool,
+            page_size=self.page_size,
+            tp_group=tp_group,
+            io_backend=server_args.hicache_io_backend,
+            load_cache_event=threading.Event(),
+            storage_backend=server_args.hicache_storage_backend,
+            model_name=server_args.served_model_name,
+            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+        )
+
+        self.ongoing_offload = {}
+        self.ongoing_backup = {}
+        logger.info("Enable offload kv cache for decode side")
+
+    def offload_kv_cache(self, req) -> bool:
+        """Offload a finished request's KV cache to storage."""
+
+        if self.cache_controller is None or self.decode_host_mem_pool is None:
+            return False
+
+        if req.req_pool_idx == -1:
+            return False
+
+        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
+        if token_indices.dim() == 0 or token_indices.numel() == 0:
+            logger.debug(
+                f"Request {req.rid} has invalid token_indices: {token_indices}"
+            )
+            return False
+
+        tokens = req.origin_input_ids + req.output_ids
+        aligned_len = (len(tokens) // self.page_size) * self.page_size
+        if aligned_len == 0:
+            return False
+
+        token_indices = token_indices[:aligned_len]
+        tokens = tokens[:aligned_len]
+
+        # Asynchronously offload KV cache from device to host by cache controller
+        self.request_counter += 1
+        ack_id = self.request_counter
+        host_indices = self.cache_controller.write(
+            device_indices=token_indices.long(),
+            node_id=ack_id,
+        )
+        if host_indices is None:
+            logger.error(f"Not enough host memory for request {req.rid}")
+            return False
+
+        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
+        return True
+
+    def check_offload_progress(self):
+        """Check the progress of offload from device to host and backup from host to storage."""
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                len(cc.ack_write_queue),
+                cc.ack_backup_queue.qsize(),
+            ],
+            dtype=torch.int,
+        )
+        if self.tp_world_size > 1:
+            torch.distributed.all_reduce(
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
+            )
+
+        n_write, n_backup = map(int, qsizes.tolist())
+        self._check_offload_progress(n_write)
+        self._check_backup_progress(n_backup)
+
+    def _check_offload_progress(self, finish_count):
+        """Check the progress of offload from device to host."""
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
+
+                # Release device
+                self.tree_cache.cache_finished_req(req)
+
+                # Trigger async backup from host to storage by cache controller
+                self._trigger_backup(req.rid, host_indices, tokens, start_time)
+            finish_count -= 1
+
+    def _check_backup_progress(self, finish_count):
+        """Check the progress of backup from host to storage."""
+        for _ in range(finish_count):
+            storage_operation = self.cache_controller.ack_backup_queue.get()
+            ack_id = storage_operation.id
+            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
+
+            # Release host memory
+            self.decode_host_mem_pool.free(host_indices)
+
+            logger.debug(
+                f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
+            )
+
+    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
+        """Trigger async backup from host to storage by cache controller."""
+
+        # Generate page hashes and write to storage
+        page_hashes = self._compute_prefix_hash(tokens)
+        ack_id = self.cache_controller.write_storage(
+            host_indices,
+            tokens,
+            hash_value=page_hashes,
+        )
+        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
+
+    def _compute_prefix_hash(self, tokens):
+        last_hash = ""
+        page_hashes = []
+        for offset in range(0, len(tokens), self.page_size):
+            page_tokens = tokens[offset : offset + self.page_size]
+            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
+            page_hashes.append(last_hash)
+        return page_hashes
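`_compute_prefix_hash` chains page hashes so each page's key encodes its entire prefix, which is what makes the storage entries prefix-addressable. The production hashing is delegated to `HiCacheController.get_hash_str`; a self-contained sketch of the chaining idea using `hashlib` (an assumption, not the real hash function):

import hashlib

def chained_page_hashes(tokens, page_size):
    """Hash each page together with the previous page's hash, so a page's
    key is valid only under its exact prefix."""
    last_hash, page_hashes = "", []
    for offset in range(0, len(tokens), page_size):
        page = tokens[offset : offset + page_size]
        payload = (last_hash + ",".join(map(str, page))).encode()
        last_hash = hashlib.sha256(payload).hexdigest()
        page_hashes.append(last_hash)
    return page_hashes

# Identical prefixes yield identical leading hashes; a divergence changes
# every hash from that page onward.
a = chained_page_hashes(list(range(8)), 4)
b = chained_page_hashes(list(range(4)) + [9, 9, 9, 9], 4)
assert a[0] == b[0] and a[1] != b[1]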
--- a/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
+++ b/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         self.orig_seq_lens = torch.tensor(
             seq_lens, dtype=torch.int32, device=self.device
         )
@@ -125,31 +126,39 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req.grammar.finished = req.finished()
         self.output_ids = torch.tensor(self.output_ids, device=self.device)
 
-        # Simulate the eagle run. We add mock data to hidden states for the
-        # ease of implementation now meaning the first token will have acc rate
-        # of 0.
-        if not self.spec_algorithm.is_none():
+        # Simulate the eagle run.
+        if self.spec_algorithm.is_eagle():
 
             b = len(self.reqs)
-            topk_p = torch.arange(
-                b * server_args.speculative_eagle_topk,
-                0,
-                -1,
-                device=self.device,
-                dtype=torch.float32,
+            topk = server_args.speculative_eagle_topk
+            topk_p = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_p[:topk],
+                        device=self.device,
+                        dtype=torch.float32,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
-            topk_p /= b * server_args.speculative_eagle_topk
-            topk_index = torch.arange(
-                b * server_args.speculative_eagle_topk, device=self.device
+            topk_index = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_index[:topk],
+                        device=self.device,
+                        dtype=torch.int64,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
 
             hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
             hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
 
             # local import to avoid circular import
-            from sglang.srt.speculative.eagle_utils import EagleDraftInput
+            from sglang.srt.speculative.eagle_info import EagleDraftInput
 
             spec_info = EagleDraftInput(
                 topk_p=topk_p,
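Here the mock `torch.arange` data is replaced by the real per-request top-k values carried over from the prefill side (`output_topk_p` / `output_topk_index`, added to the metadata-buffer unpacking earlier in this diff), so the first decode step no longer starts from a zero acceptance rate. A small sketch of the stacking shape math, with made-up values:

import torch

# Assume two requests and speculative_eagle_topk = 4; each request carries
# the top-k draft probabilities produced by the prefill instance.
reqs_topk_p = [[0.5, 0.2, 0.2, 0.1], [0.7, 0.1, 0.1, 0.1]]
topk = 4
topk_p = torch.stack(
    [torch.as_tensor(p[:topk], dtype=torch.float32) for p in reqs_topk_p],
    dim=0,
)
print(topk_p.shape)  # torch.Size([2, 4]) -> (batch, topk), as EagleDraftInput expects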