sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0

sglang/srt/managers/tp_worker.py

@@ -12,10 +12,11 @@
  # limitations under the License.
  # ==============================================================================
  """A tensor parallel worker."""
+ from __future__ import annotations

  import logging
  import threading
- from typing import Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Optional, Tuple, Union

  import torch

@@ -29,8 +30,10 @@ from sglang.srt.hf_transformers_utils import (
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.io_struct import (
  GetWeightsByNameReqInput,
+ InitWeightsSendGroupForRemoteInstanceReqInput,
  InitWeightsUpdateGroupReqInput,
  LoadLoRAAdapterReqInput,
+ SendWeightsToRemoteInstanceReqInput,
  UnloadLoRAAdapterReqInput,
  UpdateWeightFromDiskReqInput,
  UpdateWeightsFromDistributedReqInput,
@@ -45,6 +48,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

+ if TYPE_CHECKING:
+ from sglang.srt.managers.cache_controller import LayerDoneCounter
+
  logger = logging.getLogger(__name__)


@@ -78,7 +84,13 @@ class TpModelWorker:
  if not is_draft_worker
  else server_args.speculative_draft_model_path
  ),
+ model_revision=(
+ server_args.revision
+ if not is_draft_worker
+ else server_args.speculative_draft_model_revision
+ ),
  is_draft_model=is_draft_worker,
+ tp_rank=tp_rank,
  )

  self.model_runner = ModelRunner(
@@ -137,7 +149,7 @@
  assert self.max_running_requests > 0, "max_running_request is zero"
  self.max_queued_requests = server_args.max_queued_requests
  assert (
- self.max_running_requests > 0
+ self.max_queued_requests > 0
  ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
  self.max_req_len = min(
  self.model_config.context_len - 1,
@@ -162,10 +174,10 @@

  self.hicache_layer_transfer_counter = None

- def register_hicache_layer_transfer_counter(self, counter):
+ def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
  self.hicache_layer_transfer_counter = counter

- def set_hicache_consumer(self, consumer_index):
+ def set_hicache_consumer(self, consumer_index: int):
  if self.hicache_layer_transfer_counter is not None:
  self.hicache_layer_transfer_counter.set_consumer(consumer_index)

@@ -225,6 +237,9 @@
  ) -> Tuple[
  Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
  ]:
+ # update the consumer index of hicache to the running batch
+ self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

  pp_proxy_tensors = None
@@ -244,6 +259,15 @@

  if skip_sample:
  next_token_ids = None
+ # For prefill-only requests, we still need to compute logprobs even when sampling is skipped
+ if (
+ model_worker_batch.is_prefill_only
+ and model_worker_batch.return_logprob
+ ):
+ # Compute logprobs without full sampling
+ self.model_runner.compute_logprobs_only(
+ logits_output, model_worker_batch
+ )
  else:
  next_token_ids = self.model_runner.sample(
  logits_output, model_worker_batch
@@ -280,6 +304,31 @@
  )
  return success, message

+ def init_weights_send_group_for_remote_instance(
+ self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+ ):
+ success, message = (
+ self.model_runner.init_weights_send_group_for_remote_instance(
+ recv_req.master_address,
+ recv_req.ports,
+ recv_req.group_rank,
+ recv_req.world_size,
+ recv_req.group_name,
+ recv_req.backend,
+ )
+ )
+ return success, message
+
+ def send_weights_to_remote_instance(
+ self, recv_req: SendWeightsToRemoteInstanceReqInput
+ ):
+ success, message = self.model_runner.send_weights_to_remote_instance(
+ recv_req.master_address,
+ recv_req.ports,
+ recv_req.group_name,
+ )
+ return success, message
+
  def update_weights_from_distributed(
  self, recv_req: UpdateWeightsFromDistributedReqInput
  ):
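
The prefill-only branch above calls the new `ModelRunner.compute_logprobs_only`, so scoring-style requests get log-probabilities without drawing next tokens. As a rough, standalone illustration of that idea (not the actual ModelRunner code, whose signature is only visible here through this call site):

```python
import torch

def prompt_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Gather per-token log-probabilities without any sampling step.

    logits:    [seq_len, vocab_size] prefill logits
    token_ids: [seq_len] tokens whose log-probabilities are requested
    """
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    return logprobs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
```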

sglang/srt/managers/tp_worker_overlap_thread.py

@@ -12,21 +12,24 @@
  # limitations under the License.
  # ==============================================================================
  """A tensor parallel worker."""
+ from __future__ import annotations

  import dataclasses
  import logging
  import signal
  import threading
  from queue import Queue
- from typing import Optional, Tuple
+ from typing import TYPE_CHECKING, List, Optional, Tuple

  import psutil
  import torch

  from sglang.srt.managers.io_struct import (
  GetWeightsByNameReqInput,
+ InitWeightsSendGroupForRemoteInstanceReqInput,
  InitWeightsUpdateGroupReqInput,
  LoadLoRAAdapterReqInput,
+ SendWeightsToRemoteInstanceReqInput,
  UnloadLoRAAdapterReqInput,
  UpdateWeightFromDiskReqInput,
  UpdateWeightsFromDistributedReqInput,
@@ -38,6 +41,9 @@ from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import DynamicGradMode, get_compiler_backend
  from sglang.utils import get_exception_traceback

+ if TYPE_CHECKING:
+ from sglang.srt.managers.cache_controller import LayerDoneCounter
+
  logger = logging.getLogger(__name__)


@@ -79,7 +85,7 @@ class TpModelWorkerClient:
  )

  # Launch threads
- self.input_queue = Queue()
+ self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
  self.output_queue = Queue()
  self.forward_stream = torch.get_device_module(self.device).Stream()
  self.forward_thread = threading.Thread(
@@ -93,13 +99,9 @@

  self.hicache_layer_transfer_counter = None

- def register_hicache_layer_transfer_counter(self, counter):
+ def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
  self.hicache_layer_transfer_counter = counter

- def set_hicache_consumer(self, consumer_index):
- if self.hicache_layer_transfer_counter is not None:
- self.hicache_layer_transfer_counter.set_consumer(consumer_index)
-
  def get_worker_info(self):
  return self.worker.get_worker_info()

@@ -147,7 +149,7 @@
  @DynamicGradMode()
  def forward_thread_func_(self):
  batch_pt = 0
- batch_lists = [None] * 2
+ batch_lists: List = [None] * 2

  while True:
  model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
@@ -169,26 +171,31 @@
  input_ids = model_worker_batch.input_ids
  resolve_future_token_ids(input_ids, self.future_token_ids_map)

- # update the consumer index of hicache to the running batch
- self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
  # Run forward
  logits_output, next_token_ids, can_run_cuda_graph = (
  self.worker.forward_batch_generation(
- model_worker_batch, model_worker_batch.launch_done
+ model_worker_batch,
+ model_worker_batch.launch_done,
+ # Skip sampling for prefill-only requests
+ skip_sample=model_worker_batch.is_prefill_only,
  )
  )

  # Update the future token ids map
  bs = len(model_worker_batch.seq_lens)
+ if model_worker_batch.is_prefill_only:
+ # For prefill-only requests, create dummy token IDs on CPU
+ next_token_ids = torch.zeros(bs, dtype=torch.long)
  self.future_token_ids_map[
  future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
  ] = next_token_ids

  # Copy results to the CPU
  if model_worker_batch.return_logprob:
- logits_output.next_token_logprobs = (
- logits_output.next_token_logprobs.to("cpu", non_blocking=True)
- )
+ if logits_output.next_token_logprobs is not None:
+ logits_output.next_token_logprobs = (
+ logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+ )
  if logits_output.input_token_logprobs is not None:
  logits_output.input_token_logprobs = (
  logits_output.input_token_logprobs.to("cpu", non_blocking=True)
@@ -197,7 +204,9 @@
  logits_output.hidden_states = logits_output.hidden_states.to(
  "cpu", non_blocking=True
  )
- next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+ # Only copy to CPU if not already on CPU
+ if next_token_ids.device.type != "cpu":
+ next_token_ids = next_token_ids.to("cpu", non_blocking=True)
  copy_done.record()

  self.output_queue.put(
@@ -221,10 +230,10 @@
  logits_output.next_token_logprobs = (
  logits_output.next_token_logprobs.tolist()
  )
- if logits_output.input_token_logprobs is not None:
- logits_output.input_token_logprobs = tuple(
- logits_output.input_token_logprobs.tolist()
- )
+ if logits_output.input_token_logprobs is not None:
+ logits_output.input_token_logprobs = tuple(
+ logits_output.input_token_logprobs.tolist()
+ )
  next_token_ids = next_token_ids.tolist()
  return logits_output, next_token_ids, can_run_cuda_graph

@@ -269,6 +278,20 @@
  success, message = self.worker.init_weights_update_group(recv_req)
  return success, message

+ def init_weights_send_group_for_remote_instance(
+ self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+ ):
+ success, message = self.worker.init_weights_send_group_for_remote_instance(
+ recv_req
+ )
+ return success, message
+
+ def send_weights_to_remote_instance(
+ self, recv_req: SendWeightsToRemoteInstanceReqInput
+ ):
+ success, message = self.worker.send_weights_to_remote_instance(recv_req)
+ return success, message
+
  def update_weights_from_distributed(
  self, recv_req: UpdateWeightsFromDistributedReqInput
  ):
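
Both workers now expose `init_weights_send_group_for_remote_instance` and `send_weights_to_remote_instance`, which pair with the new `remote_instance` connector and `remote_instance_weight_loader_utils.py` to copy weights directly from a running instance to a newly launched one. The `ModelRunner` side is not part of this diff; the snippet below is only a speculative, self-contained sketch of the general pattern such a transfer can follow (temporary process group over TCP, broadcast from the source rank), not sglang's actual implementation:

```python
import torch
import torch.distributed as dist

def broadcast_weights(model: torch.nn.Module, master_address: str, port: int,
                      group_rank: int, world_size: int, backend: str = "nccl") -> None:
    # Form a temporary group spanning the sender (rank 0) and receiver instances,
    # then broadcast every parameter tensor from the sender.
    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://{master_address}:{port}",
        rank=group_rank,
        world_size=world_size,
    )
    for _, param in model.named_parameters():
        dist.broadcast(param.data, src=0)
    dist.destroy_process_group()
```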

sglang/srt/mem_cache/hicache_storage.py

@@ -103,20 +103,6 @@ class HiCacheStorage(ABC):
  """
  pass

- @abstractmethod
- def delete(self, key: str) -> bool:
- """
- Delete the entry associated with the given key.
- """
- pass
-
- @abstractmethod
- def clear(self) -> bool:
- """
- Clear all entries in the storage.
- """
- pass
-
  def batch_exists(self, keys: List[str]) -> int:
  """
  Check if the keys exist in the storage.
@@ -128,6 +114,9 @@ class HiCacheStorage(ABC):
  return i
  return len(keys)

+ def get_stats(self):
+ return None
+

  class HiCacheFile(HiCacheStorage):

@@ -224,15 +213,6 @@ class HiCacheFile(HiCacheStorage):
  tensor_path = os.path.join(self.file_path, f"{key}.bin")
  return os.path.exists(tensor_path)

- def delete(self, key: str) -> None:
- key = self._get_suffixed_key(key)
- tensor_path = os.path.join(self.file_path, f"{key}.bin")
- try:
- os.remove(tensor_path)
- except FileNotFoundError:
- logger.warning(f"Key {key} does not exist. Cannot delete.")
- return
-
  def clear(self) -> bool:
  try:
  for filename in os.listdir(self.file_path):
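
`delete` and `clear` are no longer abstract on `HiCacheStorage`, so third-party backends are not forced to implement them, and `get_stats` gains a default returning `None`. Callers therefore treat these as optional capabilities, which is exactly what the reworked `HiRadixCache.clear_storage_backend` below does via `hasattr`. A tiny illustration of that calling convention (the helper name here is made up, not part of sglang):

```python
def try_clear(storage) -> bool:
    """Clear a HiCacheStorage-like backend only if it actually supports clear()."""
    clear = getattr(storage, "clear", None)
    if not callable(clear):
        return False
    try:
        return bool(clear())
    except Exception:
        return False
```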

sglang/srt/mem_cache/hiradix_cache.py

@@ -20,6 +20,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
  MLATokenToKVPoolHost,
  )
  from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+ from sglang.srt.metrics.collector import StorageMetricsCollector

  logger = logging.getLogger(__name__)

@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
  hicache_write_policy: str,
  hicache_io_backend: str,
  hicache_mem_layout: str,
+ enable_metrics: bool,
  hicache_storage_backend: Optional[str] = None,
  hicache_storage_prefetch_policy: Optional[str] = "best_effort",
  model_name: Optional[str] = None,
@@ -73,6 +75,8 @@ class HiRadixCache(RadixCache):
  self.tp_group = tp_cache_group
  self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
  self.enable_storage = hicache_storage_backend is not None
+ self.enable_storage_metrics = self.enable_storage and enable_metrics
+
  # todo: customizable storage prefetch threshold and timeout
  self.prefetch_threshold = 256
  self.prefetch_timeout = 3 # seconds
@@ -92,6 +96,14 @@ class HiRadixCache(RadixCache):
  model_name=model_name,
  storage_backend_extra_config=storage_backend_extra_config,
  )
+ if self.enable_storage_metrics:
+ # TODO: support pp
+ labels = {
+ "storage_backend": hicache_storage_backend,
+ "tp_rank": self.cache_controller.tp_rank,
+ "dp_rank": self.cache_controller.dp_rank,
+ }
+ self.metrics_collector = StorageMetricsCollector(labels=labels)

  # record the nodes with ongoing write through
  self.ongoing_write_through = {}
@@ -122,11 +134,24 @@ class HiRadixCache(RadixCache):
  height += 1
  return height

- def clear_storage_backend(self):
+ def clear_storage_backend(self) -> bool:
  if self.enable_storage:
- self.cache_controller.storage_backend.clear()
- logger.info("Hierarchical cache storage backend cleared successfully!")
- return True
+ try:
+ # Check if the storage backend has a clear method (for nixl backends)
+ if hasattr(self.cache_controller.storage_backend, "clear"):
+ self.cache_controller.storage_backend.clear()
+ logger.info(
+ "Hierarchical cache storage backend cleared successfully!"
+ )
+ return True
+ else:
+ logger.warning(
+ f"Storage backend {type(self.cache_controller.storage_backend).__name__} does not support clear operation."
+ )
+ return False
+ except Exception as e:
+ logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
+ return False
  else:
  logger.warning("Hierarchical cache storage backend is not enabled.")
  return False
@@ -176,41 +201,57 @@ class HiRadixCache(RadixCache):
  if write_back:
  # blocking till all write back complete
  while len(self.ongoing_write_through) > 0:
- ack_id = self.cache_controller.ack_write_queue.get()
- del self.ongoing_write_through[ack_id]
+ for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+ finish_event.synchronize()
+ for ack_id in ack_list:
+ del self.ongoing_write_through[ack_id]
+ self.cache_controller.ack_write_queue.clear()
+ assert len(self.ongoing_write_through) == 0
  return
- queue_size = torch.tensor(
- self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
- )
+
+ # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
+ if len(self.ongoing_write_through) == 0:
+ return
+
+ finish_count = 0
+ for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+ if not finish_event.query():
+ break
+ finish_count += 1
+ queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
  if self.tp_world_size > 1:
- # synchrnoize TP workers to make the same update to radix cache
+ # synchronize TP workers to make the same update to radix cache
  torch.distributed.all_reduce(
  queue_size,
  op=torch.distributed.ReduceOp.MIN,
  group=self.tp_group,
  )
- for _ in range(queue_size.item()):
- ack_id = self.cache_controller.ack_write_queue.get()
- backuped_node = self.ongoing_write_through[ack_id]
- self.dec_lock_ref(backuped_node)
- del self.ongoing_write_through[ack_id]
- if self.enable_storage:
- self.write_backup_storage(backuped_node)
+
+ finish_count = int(queue_size.item())
+ while finish_count > 0:
+ _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+ finish_event.synchronize()
+ for ack_id in ack_list:
+ backuped_node = self.ongoing_write_through.pop(ack_id)
+ self.dec_lock_ref(backuped_node)
+ if self.enable_storage:
+ self.write_backup_storage(backuped_node)
+ finish_count -= 1

  def loading_check(self):
- while not self.cache_controller.ack_load_queue.empty():
- try:
- ack_id = self.cache_controller.ack_load_queue.get_nowait()
- start_node, end_node = self.ongoing_load_back[ack_id]
- self.dec_lock_ref(end_node)
- while end_node != start_node:
- assert end_node.loading
- end_node.loading = False
- end_node = end_node.parent
- # clear the reference
- del self.ongoing_load_back[ack_id]
- except Exception:
+ finish_count = 0
+ for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
+ if not finish_event.query():
+ # the KV cache loading is still ongoing
  break
+ finish_count += 1
+ # no need to sync across TP workers as batch forwarding is synced
+ for ack_id in ack_list:
+ end_node = self.ongoing_load_back.pop(ack_id)
+ self.dec_lock_ref(end_node)
+
+ # ACK until all events are processed
+ del self.cache_controller.ack_load_queue[:finish_count]

  def evictable_size(self):
  return self.evictable_size_
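
The cache controller's ack queues are now plain lists of `(operation, finish_event, ack_ids)` entries keyed by CUDA events, so both `writing_check` and `loading_check` can poll completion without blocking. A minimal standalone sketch of that polling pattern (names are illustrative, not the cache controller's actual attributes):

```python
import torch

# Each entry pairs a batch of ack ids with the CUDA event recorded after its transfer.
AckEntry = tuple[int, torch.cuda.Event, list[int]]

def count_finished(ack_queue: list[AckEntry]) -> int:
    """Return how many leading entries have completed, without blocking."""
    finished = 0
    for _, finish_event, _ in ack_queue:
        if not finish_event.query():  # non-blocking; False while the copy is in flight
            break
        finished += 1
    return finished
```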
@@ -335,12 +376,11 @@ class HiRadixCache(RadixCache):
  # no sufficient GPU memory to load back KV caches
  return None

- self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+ self.ongoing_load_back[last_hit_node.id] = last_hit_node
  offset = 0
  for node in nodes_to_load:
  node.value = device_indices[offset : offset + len(node.host_value)]
  offset += len(node.host_value)
- node.loading = True
  self.evictable_size_ += len(device_indices)
  self.inc_lock_ref(last_hit_node)

@@ -369,16 +409,22 @@ class HiRadixCache(RadixCache):
  last_node,
  )

- def ready_to_load_host_cache(self):
- producer_index = self.cache_controller.layer_done_counter.next_producer()
- self.load_cache_event.set()
- return producer_index
+ def ready_to_load_host_cache(self) -> int:
+ """
+ Notify the cache controller to start the KV cache loading.
+ Return the consumer index for the schedule batch manager to track.
+ """
+ return self.cache_controller.start_loading()

  def check_hicache_events(self):
  self.writing_check()
  self.loading_check()
  if self.enable_storage:
  self.drain_storage_control_queues()
+ if self.enable_storage_metrics:
+ self.metrics_collector.log_storage_metrics(
+ self.cache_controller.storage_backend.get_stats()
+ )

  def drain_storage_control_queues(self):
  """
@@ -414,10 +460,13 @@ class HiRadixCache(RadixCache):

  # process backup acks
  for _ in range(n_backup):
- ack_id = cc.ack_backup_queue.get()
+ operation = cc.ack_backup_queue.get()
+ ack_id = operation.id
  entry = self.ongoing_backup.pop(ack_id, None)
  if entry is not None:
  entry.release_host()
+ if self.enable_storage_metrics:
+ self.metrics_collector.log_backuped_tokens(operation.completed_tokens)

  # release host memory
  host_indices_list = []
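
The new `StorageMetricsCollector` (part of the much larger `sglang/srt/metrics/collector.py` change) is only visible here through its call sites: it is constructed with a `labels` dict and fed via `log_storage_metrics`, `log_backuped_tokens`, and `log_prefetched_tokens`. A hypothetical, heavily simplified stand-in built on `prometheus_client` (the real class is richer and its metric names are not shown in this diff):

```python
from prometheus_client import Counter

class MiniStorageMetrics:
    """Illustrative stand-in for StorageMetricsCollector; not the real implementation."""

    def __init__(self, labels: dict):
        self.labels = labels
        names = list(labels.keys())
        self.prefetched = Counter(
            "hicache_storage_prefetched_tokens", "Tokens prefetched from storage", names
        )
        self.backuped = Counter(
            "hicache_storage_backuped_tokens", "Tokens backed up to storage", names
        )

    def log_prefetched_tokens(self, n: int) -> None:
        self.prefetched.labels(**self.labels).inc(n)

    def log_backuped_tokens(self, n: int) -> None:
        self.backuped.labels(**self.labels).inc(n)
```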
@@ -450,15 +499,22 @@ class HiRadixCache(RadixCache):
  # unknown prefetch stop policy, just return True
  return True

+ operation_terminated = operation.is_terminated()
  if self.tp_world_size > 1:
- can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+ states = torch.tensor(
+ [1 - int(can_terminate), int(operation_terminated)],
+ dtype=torch.int,
+ )
  torch.distributed.all_reduce(
- can_terminate,
- op=torch.distributed.ReduceOp.MIN,
+ states,
+ op=torch.distributed.ReduceOp.MAX,
  group=self.tp_group,
  )
- can_terminate = bool(can_terminate.item())
-
+ can_terminate = states[0].item() == 0
+ operation_terminated = states[1].item() == 1
+ # the operation should be terminated if it is already terminated on any TP worker
+ # or it meets the termination condition on all TP workers
+ can_terminate = can_terminate or operation_terminated
  return can_terminate

  def check_prefetch_progress(self, req_id: str) -> bool:
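
The termination vote above packs two booleans into one all-reduce: taking MAX over `1 - can_terminate` realizes an AND across ranks (every rank must be ready to stop), while MAX over `operation_terminated` realizes an OR (a single terminated rank is enough). A standalone sketch of that encoding, assuming an already-initialized process group:

```python
import torch
import torch.distributed as dist

def vote_terminate(can_terminate: bool, operation_terminated: bool, group=None) -> bool:
    # states[0]: 1 means "this rank is NOT ready to stop"; MAX-reduce -> AND over ranks.
    # states[1]: 1 means "the operation already terminated here"; MAX-reduce -> OR over ranks.
    states = torch.tensor(
        [1 - int(can_terminate), int(operation_terminated)], dtype=torch.int
    )
    dist.all_reduce(states, op=dist.ReduceOp.MAX, group=group)
    return states[0].item() == 0 or states[1].item() == 1
```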
@@ -485,7 +541,7 @@ class HiRadixCache(RadixCache):
  logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

  min_completed_tokens = completed_tokens
- if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
+ if self.tp_world_size > 1:
  # synchrnoize TP workers to make the same update to hiradix cache
  completed_tokens_tensor = torch.tensor(
  min_completed_tokens, dtype=torch.int
@@ -515,6 +571,11 @@ class HiRadixCache(RadixCache):
  del self.ongoing_prefetch[req_id]
  self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

+ if self.enable_storage_metrics:
+ self.metrics_collector.log_prefetched_tokens(
+ min_completed_tokens - matched_length
+ )
+
  return True

  def match_prefix(self, key: List[int], **kwargs):
@@ -658,7 +719,6 @@ class HiRadixCache(RadixCache):
  new_node.parent = child.parent
  new_node.lock_ref = child.lock_ref
  new_node.key = child.key[:split_len]
- new_node.loading = child.loading
  new_node.hit_count = child.hit_count

  # split value and host value if exists