sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -3,22 +3,26 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional

 import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu

 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
+        transfer_kv_all_layer_direct_lf_pf,
         transfer_kv_all_layer_lf_pf,
         transfer_kv_all_layer_mla,
         transfer_kv_all_layer_mla_lf_pf,
         transfer_kv_direct,
         transfer_kv_per_layer,
+        transfer_kv_per_layer_direct_pf_lf,
         transfer_kv_per_layer_mla,
         transfer_kv_per_layer_mla_pf_lf,
         transfer_kv_per_layer_pf_lf,
@@ -76,6 +80,7 @@ class HostKVCache(abc.ABC):
         self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
+        self.page_num = self.size // self.page_size
         self.start_layer = device_pool.start_layer
         self.end_layer = device_pool.end_layer

@@ -168,7 +173,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)

     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
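Editor's note: the alloc() signature above now advertises that it can return None when the host pool cannot satisfy the request. A minimal sketch of the calling pattern this implies; pool, handle_full_pool() and start_offload() are hypothetical placeholders, not code from this release.

# Hypothetical caller of HostKVCache.alloc().
host_indices = pool.alloc(need_size=4 * pool.page_size)
if host_indices is None:
    handle_full_pool()           # pool exhausted: skip or trigger eviction
else:
    start_offload(host_indices)  # proceed with the host-side copy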
@@ -315,6 +320,15 @@ class MHATokenToKVPoolHost(HostKVCache):
             dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
         elif self.layout == "page_first":
             dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        elif self.layout == "page_first_direct":
+            dims = (
+                2,
+                self.page_num,
+                self.layer_num,
+                self.page_size,
+                self.head_num,
+                self.head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
         self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
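For orientation, a small self-contained sketch of how the new "page_first_direct" host layout relates to the existing "page_first" one; the sizes below are invented for illustration and only the shape tuples mirror the code above.

import math

# Invented example sizes (not taken from the diff).
page_size, size = 64, 4096            # tokens per page, page-aligned pool size
layer_num, head_num, head_dim = 32, 8, 128
page_num = size // page_size          # mirrors the new self.page_num field

page_first = (2, size, layer_num, head_num, head_dim)
page_first_direct = (2, page_num, layer_num, page_size, head_num, head_dim)

# Same element count; page_first_direct just groups tokens into whole pages so
# the direct IO backend can move one page per copy.
assert math.prod(page_first) == math.prod(page_first_direct)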
@@ -368,19 +382,31 @@ class MHATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-            assert (
-                self.layout == "layer_first"
-            ), f"Direct IO backend only supports layer_first layout."
-            transfer_kv_direct(
-                src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
-                dst_layers=[
-                    device_pool.k_buffer[layer_id],
-                    device_pool.v_buffer[layer_id],
-                ],
-                src_indices=host_indices,
-                dst_indices=device_indices,
-                page_size=self.page_size,
-            )
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                    dst_layers=[
+                        device_pool.k_buffer[layer_id],
+                        device_pool.v_buffer[layer_id],
+                    ],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_per_layer_direct_pf_lf(
+                    src_ptrs=[self.k_buffer, self.v_buffer],
+                    dst_ptrs=[
+                        device_pool.k_buffer[layer_id],
+                        device_pool.v_buffer[layer_id],
+                    ],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    layer_id=layer_id,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

@@ -414,16 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-            assert (
-                self.layout == "layer_first"
-            ), f"Direct IO backend only supports layer_first layout."
-            transfer_kv_direct(
-                src_layers=device_pool.k_buffer + device_pool.v_buffer,
-                dst_layers=self.k_data_refs + self.v_data_refs,
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                page_size=self.page_size,
-            )
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                    dst_layers=self.k_data_refs + self.v_data_refs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_all_layer_direct_lf_pf(
+                    src_ptrs=device_pool.k_buffer + device_pool.v_buffer,
+                    dst_ptrs=[self.k_buffer, self.v_buffer],
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

@@ -578,6 +612,14 @@ class MLATokenToKVPoolHost(HostKVCache):
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
             )
+        elif self.layout == "page_first_direct":
+            dims = (
+                self.page_num,
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
         self.token_stride_size = (
@@ -617,16 +659,25 @@ class MLATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-            assert (
-                self.layout == "layer_first"
-            ), f"Direct IO backend only supports layer_first layout."
-            transfer_kv_direct(
-                src_layers=[self.kv_buffer[layer_id]],
-                dst_layers=[device_pool.kv_buffer[layer_id]],
-                src_indices=host_indices,
-                dst_indices=device_indices,
-                page_size=self.page_size,
-            )
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=[self.kv_buffer[layer_id]],
+                    dst_layers=[device_pool.kv_buffer[layer_id]],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_per_layer_direct_pf_lf(
+                    src_ptrs=[self.kv_buffer],
+                    dst_ptrs=[device_pool.kv_buffer[layer_id]],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    layer_id=layer_id,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")

     def backup_from_device_all_layer(
         self, device_pool, host_indices, device_indices, io_backend
@@ -654,16 +705,24 @@ class MLATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-            assert (
-                self.layout == "layer_first"
-            ), f"Direct IO backend only supports layer_first layout."
-            transfer_kv_direct(
-                src_layers=device_pool.kv_buffer,
-                dst_layers=self.data_refs,
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                page_size=self.page_size,
-            )
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=device_pool.kv_buffer,
+                    dst_layers=self.data_refs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_all_layer_direct_lf_pf(
+                    src_ptrs=device_pool.kv_buffer,
+                    dst_ptrs=[self.kv_buffer],
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0
@@ -0,0 +1,164 @@
+import logging
+import os
+import threading
+from abc import ABC, abstractmethod
+from typing import List
+
+import torch
+
+
+class Hf3fsClient(ABC):
+    """Abstract interface for HF3FS clients."""
+
+    @abstractmethod
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize the HF3FS client.
+
+        Args:
+            path: File path for storage
+            size: Total size of storage file
+            bytes_per_page: Bytes per page
+            entries: Number of entries for batch operations
+        """
+        pass
+
+    @abstractmethod
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from storage."""
+        pass
+
+    @abstractmethod
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to storage."""
+        pass
+
+    @abstractmethod
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    @abstractmethod
+    def get_size(self) -> int:
+        """Get total storage size."""
+        pass
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close the client and cleanup resources."""
+        pass
+
+    @abstractmethod
+    def flush(self) -> None:
+        """Flush data to disk."""
+        pass
+
+
+logger = logging.getLogger(__name__)
+
+
+class Hf3fsMockClient(Hf3fsClient):
+    """Mock implementation of Hf3fsClient for CI testing purposes."""
+
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize mock HF3FS client."""
+        self.path = path
+        self.size = size
+        self.bytes_per_page = bytes_per_page
+        self.entries = entries
+
+        # Create directory if it doesn't exist
+        os.makedirs(os.path.dirname(self.path), exist_ok=True)
+
+        # Create and initialize the file
+        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
+        os.ftruncate(self.file, size)
+
+        logger.info(
+            f"Hf3fsMockClient initialized: path={path}, size={size}, "
+            f"bytes_per_page={bytes_per_page}, entries={entries}"
+        )
+
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_read = os.read(self.file, size)
+
+                if len(bytes_read) == size:
+                    # Convert bytes to tensor and copy to target
+                    bytes_tensor = torch.frombuffer(bytes_read, dtype=torch.uint8)
+                    typed_tensor = bytes_tensor.view(tensor.dtype).view(tensor.shape)
+                    tensor.copy_(typed_tensor)
+                    results.append(size)
+                else:
+                    logger.warning(
+                        f"Short read: expected {size}, got {len(bytes_read)}"
+                    )
+                    results.append(len(bytes_read))
+
+            except Exception as e:
+                logger.error(f"Error reading from offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                # Convert tensor to bytes and write directly to file
+                tensor_bytes = tensor.contiguous().view(torch.uint8).flatten()
+                data = tensor_bytes.numpy().tobytes()
+
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_written = os.write(self.file, data)
+
+                if bytes_written == size:
+                    results.append(size)
+                else:
+                    logger.warning(f"Short write: expected {size}, got {bytes_written}")
+                    results.append(bytes_written)
+
+            except Exception as e:
+                logger.error(f"Error writing to offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    def get_size(self) -> int:
+        """Get total storage size."""
+        return self.size
+
+    def close(self) -> None:
+        """Close the mock client and cleanup resources."""
+        try:
+            if hasattr(self, "file") and self.file >= 0:
+                os.close(self.file)
+                self.file = -1  # Mark as closed
+            logger.info(f"MockHf3fsClient closed: {self.path}")
+        except Exception as e:
+            logger.error(f"Error closing MockHf3fsClient: {e}")
+
+    def flush(self) -> None:
+        """Flush data to disk."""
+        try:
+            os.fsync(self.file)
+        except Exception as e:
+            logger.error(f"Error flushing MockHf3fsClient: {e}")
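A quick round-trip smoke test of the mock client introduced above, assuming the module path shown in the imports later in this diff; the temporary file path and sizes are arbitrary example values.

import torch

from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient

bytes_per_page = 4096
client = Hf3fsMockClient(
    "/tmp/hf3fs_mock/data.bin", size=1 << 20,
    bytes_per_page=bytes_per_page, entries=8,
)

src = torch.randint(0, 256, (bytes_per_page,), dtype=torch.uint8)  # one page
dst = torch.empty_like(src)
assert client.batch_write([0], [src]) == [bytes_per_page]
assert client.batch_read([0], [dst]) == [bytes_per_page]
assert torch.equal(src, dst)
client.close()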
@@ -9,6 +9,8 @@ from typing import List
 import torch
 from torch.utils.cpp_extension import load

+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+
 root = Path(__file__).parent.resolve()
 hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])

@@ -51,7 +53,9 @@ def wsynchronized():
     return _decorator


-class Hf3fsClient:
+class Hf3fsUsrBioClient(Hf3fsClient):
+    """HF3FS client implementation using usrbio."""
+
     def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
         if not HF3FS_AVAILABLE:
             raise ImportError(
@@ -5,6 +5,7 @@ import logging
 import os
 import signal
 import threading
+import time
 from abc import ABC, abstractmethod
 from functools import wraps
 from typing import Any, List, Optional, Tuple
@@ -12,7 +13,8 @@ from typing import Any, List, Optional, Tuple
 import torch

 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
-from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+from sglang.srt.metrics.collector import StorageMetrics

 logger = logging.getLogger(__name__)

@@ -112,6 +114,33 @@ def synchronized():
     return _decorator


+def create_hf3fs_client(
+    path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False
+) -> Hf3fsClient:
+    """Factory function to create appropriate HF3FS client.
+
+    Args:
+        path: File path for storage
+        size: Total size of storage file
+        bytes_per_page: Bytes per page
+        entries: Number of entries for batch operations
+        use_mock: Whether to use mock client instead of real usrbio client
+
+    Returns:
+    """
+    if use_mock:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient
+
+        logger.info(f"[Rank Using Hf3fsMockClient for testing")
+        return Hf3fsMockClient(path, size, bytes_per_page, entries)
+    else:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import (
+            Hf3fsUsrBioClient,
+        )
+
+        return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
+
+
 class HiCacheHF3FS(HiCacheStorage):
     """HiCache backend that stores KV cache pages in HF3FS files."""

@@ -129,12 +158,14 @@ class HiCacheHF3FS(HiCacheStorage):
         metadata_client: Hf3fsMetadataInterface,
         is_mla_model: bool = False,
         is_page_first_layout: bool = False,
+        use_mock_client: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
+        self.gb_per_page = bytes_per_page / (1 << 30)
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
@@ -156,8 +187,12 @@ class HiCacheHF3FS(HiCacheStorage):

         self.ac = AtomicCounter(self.numjobs)
         self.clients = [
-            Hf3fsClient(
-                self.file_path, self.file_size, self.bytes_per_page, self.entries
+            create_hf3fs_client(
+                self.file_path,
+                self.file_size,
+                self.bytes_per_page,
+                self.entries,
+                use_mock_client,
             )
             for _ in range(numjobs)
         ]
@@ -174,6 +209,11 @@ class HiCacheHF3FS(HiCacheStorage):
         signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
         signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())

+        self.prefetch_pgs = []
+        self.backup_pgs = []
+        self.prefetch_bandwidth = []
+        self.backup_bandwidth = []
+
     @staticmethod
     def from_env_config(
         bytes_per_page: int,
@@ -194,14 +234,24 @@ class HiCacheHF3FS(HiCacheStorage):
             Hf3fsLocalMetadataClient,
         )

+        use_mock_client = False
         if storage_config is not None:
             rank, is_mla_model, is_page_first_layout = (
                 storage_config.tp_rank,
                 storage_config.is_mla_model,
                 storage_config.is_page_first_layout,
             )
+
+            if storage_config.extra_config is not None:
+                use_mock_client = storage_config.extra_config.get(
+                    "use_mock_hf3fs_client", False
+                )
         else:
-            rank, is_mla_model, is_page_first_layout = 0, False, False
+            rank, is_mla_model, is_page_first_layout = (
+                0,
+                False,
+                False,
+            )

         mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"

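The mock client is switched on through the storage config's extra_config mapping, as read above. Restated in isolation for clarity; storage_config stands in for a HiCacheStorageConfig instance and only the two attributes read here are assumed.

# The flag travels in the config's extra_config mapping, e.g.
# extra_config = {"use_mock_hf3fs_client": True}
extra = getattr(storage_config, "extra_config", None) or {}
use_mock_client = extra.get("use_mock_hf3fs_client", False)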
@@ -220,6 +270,7 @@ class HiCacheHF3FS(HiCacheStorage):
                 dtype=dtype,
                 metadata_client=Hf3fsLocalMetadataClient(),
                 is_page_first_layout=is_page_first_layout,
+                use_mock_client=use_mock_client,
             )

         try:
@@ -269,6 +320,7 @@ class HiCacheHF3FS(HiCacheStorage):
             metadata_client=metadata_client,
             is_mla_model=is_mla_model,
             is_page_first_layout=is_page_first_layout,
+            use_mock_client=use_mock_client,
         )

     def get(
@@ -308,6 +360,8 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(len(batch_indices))
         ]

+        start_time = time.perf_counter()
+
         futures = [
             self.executor.submit(
                 self.clients[self.ac.next()].batch_read,
@@ -318,6 +372,13 @@ class HiCacheHF3FS(HiCacheStorage):
         ]
         read_results = [result for future in futures for result in future.result()]

+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.prefetch_pgs.append(ionum)
+        self.prefetch_bandwidth.append(
+            ionum / (end_time - start_time) * self.gb_per_page
+        )
+
         results = [None] * len(keys)
         for batch_index, file_result, read_result in zip(
             batch_indices, file_results, read_results
@@ -345,6 +406,7 @@ class HiCacheHF3FS(HiCacheStorage):
             [target_sizes] if target_sizes is not None else None,
         )

+    @synchronized()
     def batch_set(
         self,
         keys: List[str],
@@ -374,6 +436,8 @@ class HiCacheHF3FS(HiCacheStorage):
             assert value.is_contiguous()
             file_values.append(value)

+        start_time = time.perf_counter()
+
         futures = [
             self.executor.submit(
                 self.clients[self.ac.next()].batch_write,
@@ -388,6 +452,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for result in future.result()
         ]

+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.backup_pgs.append(ionum)
+        self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
+
         written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
@@ -439,3 +508,16 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             logger.error(f"close HiCacheHF3FS: {e}")
         logger.info("close HiCacheHF3FS")
+
+    @synchronized()
+    def get_stats(self):
+        storage_metrics = StorageMetrics()
+        storage_metrics.prefetch_pgs.extend(self.prefetch_pgs)
+        storage_metrics.backup_pgs.extend(self.backup_pgs)
+        storage_metrics.prefetch_bandwidth.extend(self.prefetch_bandwidth)
+        storage_metrics.backup_bandwidth.extend(self.backup_bandwidth)
+        self.prefetch_pgs.clear()
+        self.backup_pgs.clear()
+        self.prefetch_bandwidth.clear()
+        self.backup_bandwidth.clear()
+        return storage_metrics
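The bandwidth samples collected in batch_get/batch_set and drained by get_stats() above are simple per-call throughput figures. A stand-alone restatement of the arithmetic, with invented numbers for illustration:

bytes_per_page = 256 << 10                  # 256 KiB pages (example value)
gb_per_page = bytes_per_page / (1 << 30)    # matches self.gb_per_page

ionum = 32        # pages moved by one batched read or write
elapsed = 0.05    # seconds between the perf_counter() calls
bandwidth_gb_s = ionum / elapsed * gb_per_page
print(f"{bandwidth_gb_s:.2f} GB/s")         # 32 * 0.25 MiB / 50 ms ≈ 0.16 GB/s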