sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -3,16 +3,17 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional

 import psutil
 import torch

-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu

 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
         transfer_kv_all_layer_lf_pf,
@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)

     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
@@ -464,10 +465,11 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         v_offset = (
             self.layer_num
             * self.size
@@ -488,8 +490,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
-            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
+            key_list.append(f"{key_}_{local_rank}_k")
+            key_list.append(f"{key_}_{local_rank}_v")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
@@ -500,20 +502,23 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))

         key_list = []
         buf_list = []

-        for key, i in zip(keys, range(0, len(indices), self.page_size)):
+        for i in range(len(keys)):
+            key = keys[i]
             key_list.append(f"{key}-k")
-            buf_list.append(self.k_buffer[i : i + self.page_size])
             key_list.append(f"{key}-v")
-            buf_list.append(self.v_buffer[i : i + self.page_size])
+            if indices is not None:
+                index = indices[i * self.page_size]
+                buf_list.append(self.k_buffer[index : index + self.page_size])
+                buf_list.append(self.v_buffer[index : index + self.page_size])

-        return key_list, buf_list
+        return key_list, buf_list, 2


 class MLATokenToKVPoolHost(HostKVCache):
@@ -703,10 +708,11 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
+        indices = indices.tolist()
         for index in range(0, len(indices), self.page_size):
             k_ptr = (
                 kv_buffer_data_ptr
@@ -727,13 +733,15 @@ class MLATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list

-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))

         buf_list = []

-        for i in range(0, len(indices), self.page_size):
-            buf_list.append(self.kv_buffer[i : i + self.page_size])
+        if indices is not None:
+            for i in range(len(keys)):
+                index = indices[i * self.page_size]
+                buf_list.append(self.kv_buffer[index : index + self.page_size])

-        return keys, buf_list
+        return keys, buf_list, 1
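
Both host pool classes now return a third value from get_buffer_with_hash, the number of buffers produced per key (2 for the MHA pool's separate k/v buffers, 1 for the MLA pool's combined buffer), and indices is optional so callers can request the hash keys alone. A hypothetical caller, sketched here purely to illustrate the new contract (host_pool, keys, and indices are stand-ins, not names from the sglang codebase):

    # Illustrative only: `host_pool` is an MHATokenToKVPoolHost or MLATokenToKVPoolHost.
    key_list, buf_list, bufs_per_key = host_pool.get_buffer_with_hash(keys, indices)

    # Each input key yields `bufs_per_key` hash keys ("-k"/"-v" for MHA, the key itself for MLA) ...
    assert len(key_list) == len(keys) * bufs_per_key
    # ... and, when indices are supplied, the same number of page-sized buffers.
    if indices is not None:
        assert len(buf_list) == len(keys) * bufs_per_key
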
@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0
@@ -62,7 +60,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each pages
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False

         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -152,6 +149,7 @@ class RadixCache(BasePrefixCache):
         self.root_node = TreeNode()
         self.root_node.key = []
         self.root_node.value = []
+        self.root_node.host_value = []
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
         self.protected_size_ = 0
@@ -194,7 +192,7 @@ class RadixCache(BasePrefixCache):
             last_host_node=last_node,
         )

-    def insert(self, key: List, value=None):
+    def insert(self, key: List, value=None, chunked=False):
         if self.disable:
             return 0

@@ -239,7 +237,7 @@ class RadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -260,7 +258,9 @@ class RadixCache(BasePrefixCache):
         page_aligned_token_ids = token_ids[:page_aligned_len]

         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
+        new_prefix_len = self.insert(
+            page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
+        )
         self.token_to_kv_pool_allocator.free(
             kv_indices[len(req.prefix_indices) : new_prefix_len]
         )
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
         self.dec_lock_ref(req.last_node)
         self.req_to_token_pool.free(req.req_pool_idx)

-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         assert req.req_pool_idx is not None
         token_ids = req.fill_ids
@@ -0,0 +1,164 @@
+import logging
+import os
+import threading
+from abc import ABC, abstractmethod
+from typing import List
+
+import torch
+
+
+class Hf3fsClient(ABC):
+    """Abstract interface for HF3FS clients."""
+
+    @abstractmethod
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize the HF3FS client.
+
+        Args:
+            path: File path for storage
+            size: Total size of storage file
+            bytes_per_page: Bytes per page
+            entries: Number of entries for batch operations
+        """
+        pass
+
+    @abstractmethod
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from storage."""
+        pass
+
+    @abstractmethod
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to storage."""
+        pass
+
+    @abstractmethod
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    @abstractmethod
+    def get_size(self) -> int:
+        """Get total storage size."""
+        pass
+
+    @abstractmethod
+    def close(self) -> None:
+        """Close the client and cleanup resources."""
+        pass
+
+    @abstractmethod
+    def flush(self) -> None:
+        """Flush data to disk."""
+        pass
+
+
+logger = logging.getLogger(__name__)
+
+
+class Hf3fsMockClient(Hf3fsClient):
+    """Mock implementation of Hf3fsClient for CI testing purposes."""
+
+    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
+        """Initialize mock HF3FS client."""
+        self.path = path
+        self.size = size
+        self.bytes_per_page = bytes_per_page
+        self.entries = entries
+
+        # Create directory if it doesn't exist
+        os.makedirs(os.path.dirname(self.path), exist_ok=True)
+
+        # Create and initialize the file
+        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
+        os.ftruncate(self.file, size)
+
+        logger.info(
+            f"Hf3fsMockClient initialized: path={path}, size={size}, "
+            f"bytes_per_page={bytes_per_page}, entries={entries}"
+        )
+
+    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch read from mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_read = os.read(self.file, size)
+
+                if len(bytes_read) == size:
+                    # Convert bytes to tensor and copy to target
+                    bytes_tensor = torch.frombuffer(bytes_read, dtype=torch.uint8)
+                    typed_tensor = bytes_tensor.view(tensor.dtype).view(tensor.shape)
+                    tensor.copy_(typed_tensor)
+                    results.append(size)
+                else:
+                    logger.warning(
+                        f"Short read: expected {size}, got {len(bytes_read)}"
+                    )
+                    results.append(len(bytes_read))
+
+            except Exception as e:
+                logger.error(f"Error reading from offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
+        """Batch write to mock storage."""
+        self.check(offsets, tensors)
+
+        results = []
+
+        for offset, tensor in zip(offsets, tensors):
+            size = tensor.numel() * tensor.itemsize
+
+            try:
+                # Convert tensor to bytes and write directly to file
+                tensor_bytes = tensor.contiguous().view(torch.uint8).flatten()
+                data = tensor_bytes.numpy().tobytes()
+
+                os.lseek(self.file, offset, os.SEEK_SET)
+                bytes_written = os.write(self.file, data)
+
+                if bytes_written == size:
+                    results.append(size)
+                else:
+                    logger.warning(f"Short write: expected {size}, got {bytes_written}")
+                    results.append(bytes_written)
+
+            except Exception as e:
+                logger.error(f"Error writing to offset {offset}: {e}")
+                results.append(0)
+
+        return results
+
+    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
+        """Validate batch operation parameters."""
+        pass
+
+    def get_size(self) -> int:
+        """Get total storage size."""
+        return self.size
+
+    def close(self) -> None:
+        """Close the mock client and cleanup resources."""
+        try:
+            if hasattr(self, "file") and self.file >= 0:
+                os.close(self.file)
+                self.file = -1  # Mark as closed
+                logger.info(f"MockHf3fsClient closed: {self.path}")
+        except Exception as e:
+            logger.error(f"Error closing MockHf3fsClient: {e}")

+    def flush(self) -> None:
+        """Flush data to disk."""
+        try:
+            os.fsync(self.file)
+        except Exception as e:
+            logger.error(f"Error flushing MockHf3fsClient: {e}")
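
For reference, a minimal sketch of how the mock client added above might be exercised; the path, page size, and tensor shape below are placeholders chosen for this example, not values taken from the sglang test suite.

    import torch

    from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient

    bytes_per_page = 4096
    client = Hf3fsMockClient(
        path="/tmp/hf3fs_mock/data.bin",  # placeholder path for this sketch
        size=16 * bytes_per_page,
        bytes_per_page=bytes_per_page,
        entries=8,
    )

    src = torch.arange(1024, dtype=torch.float32)  # 1024 * 4 bytes = one page
    dst = torch.empty_like(src)

    # batch_write/batch_read return the number of bytes handled per entry;
    # batch_read fills the target tensors in place.
    assert client.batch_write([0], [src]) == [src.numel() * src.itemsize]
    assert client.batch_read([0], [dst]) == [dst.numel() * dst.itemsize]
    assert torch.equal(src, dst)

    client.flush()
    client.close()
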
@@ -9,6 +9,8 @@ from typing import List
 import torch
 from torch.utils.cpp_extension import load

+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+
 root = Path(__file__).parent.resolve()
 hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])

@@ -51,7 +53,9 @@ def wsynchronized():
     return _decorator


-class Hf3fsClient:
+class Hf3fsUsrBioClient(Hf3fsClient):
+    """HF3FS client implementation using usrbio."""
+
     def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
         if not HF3FS_AVAILABLE:
             raise ImportError(
@@ -4,10 +4,12 @@ import json
 import logging
 import threading
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, OrderedDict, Tuple

+import orjson
 import requests
-from fastapi import FastAPI, HTTPException, Request, status
+from fastapi import FastAPI, HTTPException, Request, Response
+from fastapi.responses import ORJSONResponse
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry

@@ -24,10 +26,10 @@ class RankMetadata:
     """Holds all metadata for a single rank."""

     def __init__(self, num_pages: int):
-        self.lock = threading.RLock()
+        self.lock = threading.Lock()
         self.num_pages = num_pages
         self.free_pages: List[int] = list(range(num_pages))
-        self.key_to_index: Dict[str, int] = {}
+        self.key_to_index: OrderedDict[str, int] = OrderedDict()
         # Todo: Support multi files for HF3FS

     def exists_keys(self, keys: List[str]) -> List[bool]:
@@ -46,16 +48,18 @@ class RankMetadata:
             for i, (key, prefix_key) in enumerate(keys):
                 if key in self.key_to_index:
                     results[i] = (True, self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
                 else:
                     new_keys_to_process.append((i, key, prefix_key))

             # Todo: Implementing data eviction logic after HiCache supports prefix information pass-through
             for i, key, prefix_key in new_keys_to_process:
                 if len(self.free_pages) > 0:
-                    page_idx = self.free_pages.pop()
-                    results[i] = (False, page_idx)
+                    page_index = self.free_pages.pop()
                 else:
-                    results[i] = (False, -1)
+                    page_index = self.key_to_index.popitem(last=False)[1]
+
+                results[i] = (False, page_index)

             return results

@@ -68,6 +72,7 @@ class RankMetadata:
         with self.lock:
             for key, page_index in written_keys_to_confirm:
                 self.key_to_index[key] = page_index
+                self.key_to_index.move_to_end(key)

             for page_index in pages_to_release:
                 if page_index not in self.free_pages:
@@ -94,7 +99,14 @@ class RankMetadata:
     def get_page_indices(self, keys: List[str]) -> List[Optional[int]]:
         """Get page indices for keys."""
         with self.lock:
-            return [self.key_to_index.get(key) for key in keys]
+            results = []
+            for key in keys:
+                if key in self.key_to_index:
+                    results.append(self.key_to_index[key])
+                    self.key_to_index.move_to_end(key)
+                else:
+                    results.append(None)
+            return results


 class GlobalMetadataState:
@@ -182,7 +194,8 @@ class Hf3fsMetadataServer:

     def __init__(self, persistence_path: Optional[str] = None, save_interval: int = 60):
         self.state = GlobalMetadataState(persistence_path, save_interval)
-        self.app = FastAPI()
+        self.app = FastAPI(default_response_class=ORJSONResponse)
+
         self._setup_routes()

     def _setup_routes(self):
@@ -199,17 +212,25 @@ class Hf3fsMetadataServer:

     def get_rank_metadata(self, rank: int) -> RankMetadata:
         """Get rank metadata with proper error handling."""
-        with self.state.global_lock:
-            if rank not in self.state.ranks:
-                raise HTTPException(
-                    status_code=404,
-                    detail=f"Rank {rank} not initialized. Please call /{{rank}}/initialize first.",
-                )
-            return self.state.ranks[rank]
+        if rank not in self.state.ranks:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Rank {rank} not initialized. Please call /{rank}/initialize first.",
+            )
+        return self.state.ranks[rank]
+
+    async def _read_json(self, request: Request) -> dict:
+        """Parse request JSON using orjson if available."""
+        body = await request.body()
+        return orjson.loads(body)
+
+    def _json_response(self, content: dict):
+        """Return ORJSONResponse when available to bypass jsonable_encoder."""
+        return ORJSONResponse(content)

     async def initialize(self, rank: int, request: Request):
         """Initialize a rank with specified number of pages."""
-        data = await request.json()
+        data = await self._read_json(request)
         num_pages = data["num_pages"]
         with self.state.global_lock:
             if rank in self.state.ranks:
@@ -223,57 +244,55 @@ class Hf3fsMetadataServer:
             else:
                 logging.info(f"Initializing new Rank {rank} with {num_pages} pages.")
                 self.state.ranks[rank] = RankMetadata(num_pages)
-        return {"message": f"Rank {rank} is ready."}
+        return Response(status_code=204)

     async def exists(self, rank: int, request: Request):
         """Check if keys exist in metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
         keys = data["keys"]
         metadata = self.get_rank_metadata(rank)
         results = metadata.exists_keys(keys)
-        return {"exists": results}
+        return self._json_response({"exists": results})

     async def reserve_and_allocate_page_indices(self, rank: int, request: Request):
         """Reserve and allocate page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.reserve_and_allocate_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

     async def confirm_write(self, rank: int, request: Request):
         """Confirm write operations and release pages."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         success_written_keys = data.get("written_keys_to_confirm", [])
         released_pages = data.get("pages_to_release", [])

         metadata.confirm_write(success_written_keys, released_pages)

-        return {
-            "message": f"Rank {rank}: Write confirmed for {len(success_written_keys)} keys. {len(released_pages)} pages released."
-        }
+        return Response(status_code=204)

     async def delete_keys(self, rank: int, request: Request):
         """Delete keys from metadata."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         count = metadata.delete_keys(data["keys"])
-        return {"message": f"Rank {rank}: {count} keys deleted."}
+        return Response(status_code=204)

     async def clear(self, rank: int):
         """Clear all metadata for a rank."""
         metadata = self.get_rank_metadata(rank)
         metadata.clear_all()
-        return {"message": f"Rank {rank}: Metadata cleared."}
+        return Response(status_code=204)

     async def get_page_indices(self, rank: int, request: Request):
         """Get page indices for keys."""
-        data = await request.json()
+        data = await self._read_json(request)
         metadata = self.get_rank_metadata(rank)
         keys = data["keys"]
         results = metadata.get_page_indices(keys)
-        return {"indices": results}
+        return self._json_response({"indices": results})

     def run(self, host: str = "0.0.0.0", port: int = 18000):
         """Run the metadata server."""
@@ -309,14 +328,22 @@ class Hf3fsGlobalMetadataClient(Hf3fsMetadataInterface):
             status_forcelist=[500, 502, 503, 504],
             allowed_methods=["GET", "POST"],
         )
-        adapter = HTTPAdapter(max_retries=retry_strategy)
+        adapter = HTTPAdapter(
+            max_retries=retry_strategy, pool_connections=256, pool_maxsize=256
+        )
         self._session.mount("http://", adapter)

     def _post(self, endpoint: str, json_data: dict) -> dict:
         try:
-            response = self._session.post(f"{self.base_url}/{endpoint}", json=json_data)
+            url = f"{self.base_url}/{endpoint}"
+            headers = {"Content-Type": "application/json"}
+            payload = orjson.dumps(json_data)  # type: ignore[union-attr]
+            response = self._session.post(url, data=payload, headers=headers)
             response.raise_for_status()
-            return response.json()
+
+            if response.status_code == 204 or not response.content:
+                return {}
+            return orjson.loads(response.content)  # type: ignore[union-attr]
         except requests.exceptions.RequestException as e:
             logging.error(f"Failed to POST to {endpoint} after retries: {e}")
             raise RuntimeError(f"Failed to connect to metadata server: {e}") from e
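
The RankMetadata changes above replace the plain dict with an OrderedDict so that, once free_pages is exhausted, allocation evicts the least-recently-used key via popitem(last=False), while hits and confirmed writes refresh recency with move_to_end. A self-contained sketch of that pattern, written only to illustrate the idea (it is not the sglang class itself):

    from collections import OrderedDict
    from typing import List, Tuple


    class LruPageTable:
        """Toy LRU page table mirroring the OrderedDict pattern used above."""

        def __init__(self, num_pages: int):
            self.free_pages: List[int] = list(range(num_pages))
            self.key_to_index: "OrderedDict[str, int]" = OrderedDict()

        def reserve(self, key: str) -> Tuple[bool, int]:
            if key in self.key_to_index:
                # Hit: refresh recency and reuse the existing page.
                self.key_to_index.move_to_end(key)
                return True, self.key_to_index[key]
            if self.free_pages:
                page_index = self.free_pages.pop()
            else:
                # No free pages left: evict the least-recently-used entry.
                page_index = self.key_to_index.popitem(last=False)[1]
            return False, page_index

        def confirm(self, key: str, page_index: int) -> None:
            self.key_to_index[key] = page_index
            self.key_to_index.move_to_end(key)


    table = LruPageTable(num_pages=2)
    for key in ("a", "b"):
        _, idx = table.reserve(key)
        table.confirm(key, idx)

    table.reserve("a")             # touch "a"; "b" is now least recently used
    hit, idx = table.reserve("c")  # no free pages: "b" is evicted, its page reused
    table.confirm("c", idx)
    assert hit is False and "b" not in table.key_to_index
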