sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -5,19 +5,16 @@ import logging
 import os
 import signal
 import threading
+import time
 from abc import ABC, abstractmethod
 from functools import wraps
 from typing import Any, List, Optional, Tuple
 
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    is_dp_attention_enabled,
-)
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
-from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+from sglang.srt.metrics.collector import StorageMetrics
 
 logger = logging.getLogger(__name__)
 
@@ -117,7 +114,36 @@ def synchronized():
     return _decorator
 
 
+def create_hf3fs_client(
+    path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False
+) -> Hf3fsClient:
+    """Factory function to create the appropriate HF3FS client.
+
+    Args:
+        path: File path for storage
+        size: Total size of storage file
+        bytes_per_page: Bytes per page
+        entries: Number of entries for batch operations
+        use_mock: Whether to use the mock client instead of the real usrbio client
+
+    Returns:
+        Hf3fsClient: A mock client when `use_mock` is True, otherwise a usrbio client.
+    """
+    if use_mock:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient
+
+        logger.info("Using Hf3fsMockClient for testing")
+        return Hf3fsMockClient(path, size, bytes_per_page, entries)
+    else:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import (
+            Hf3fsUsrBioClient,
+        )
+
+        return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
+
+
 class HiCacheHF3FS(HiCacheStorage):
+    """HiCache backend that stores KV cache pages in HF3FS files."""
+
     default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
 
     def __init__(
@@ -130,18 +156,27 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
+        is_page_first_layout: bool = False,
+        use_mock_client: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
+        self.gb_per_page = bytes_per_page / (1 << 30)
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
+        self.is_page_first_layout = is_page_first_layout
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0
 
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -152,8 +187,12 @@
 
         self.ac = AtomicCounter(self.numjobs)
         self.clients = [
-            Hf3fsClient(
-                self.file_path, self.file_size, self.bytes_per_page, self.entries
+            create_hf3fs_client(
+                self.file_path,
+                self.file_size,
+                self.bytes_per_page,
+                self.entries,
+                use_mock_client,
             )
             for _ in range(numjobs)
         ]
@@ -170,24 +209,57 @@
         signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
         signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
 
+        self.prefetch_pgs = []
+        self.backup_pgs = []
+        self.prefetch_bandwidth = []
+        self.backup_bandwidth = []
+
     @staticmethod
     def from_env_config(
-        bytes_per_page: int, dtype: torch.dtype, rank: int = None
+        bytes_per_page: int,
+        dtype: torch.dtype,
+        storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
+        """Create a HiCacheHF3FS instance from environment configuration.
+
+        Environment:
+            - Uses the env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
+            - Falls back to a local single-machine config when the env var is not set.
+
+        Raises:
+            ValueError: If an MLA model is requested without a global metadata server, or required keys are missing.
+        """
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
-        if rank is None:
-            rank = (
-                get_attention_tp_rank()
-                if is_dp_attention_enabled()
-                else get_tensor_model_parallel_rank()
+        use_mock_client = False
+        if storage_config is not None:
+            rank, is_mla_model, is_page_first_layout = (
+                storage_config.tp_rank,
+                storage_config.is_mla_model,
+                storage_config.is_page_first_layout,
             )
 
+            if storage_config.extra_config is not None:
+                use_mock_client = storage_config.extra_config.get(
+                    "use_mock_hf3fs_client", False
+                )
+        else:
+            rank, is_mla_model, is_page_first_layout = (
+                0,
+                False,
+                False,
+            )
+
+        mla_unsupported_msg = "MLA model is not supported without a global metadata server; please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             return HiCacheHF3FS(
                 rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
@@ -197,6 +269,8 @@
                 entries=8,
                 dtype=dtype,
                 metadata_client=Hf3fsLocalMetadataClient(),
+                is_page_first_layout=is_page_first_layout,
+                use_mock_client=use_mock_client,
             )
 
         try:
@@ -217,26 +291,36 @@
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
         # Choose metadata client based on configuration
-        if "metadata_server_url" in config and config["metadata_server_url"]:
+        if config.get("metadata_server_url"):
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
         else:
+            # Enable MLA optimization only when using the global metadata client
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             # Use local metadata client for single-machine deployment
             metadata_client = Hf3fsLocalMetadataClient()
 
+        rank_for_path = 0 if is_mla_model else rank
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for MLA model
+            file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
+            is_page_first_layout=is_page_first_layout,
+            use_mock_client=use_mock_client,
         )
 
     def get(
@@ -276,6 +360,8 @@
             for _ in range(len(batch_indices))
         ]
 
+        start_time = time.perf_counter()
+
         futures = [
             self.executor.submit(
                 self.clients[self.ac.next()].batch_read,
@@ -286,6 +372,13 @@
         ]
         read_results = [result for future in futures for result in future.result()]
 
+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.prefetch_pgs.append(ionum)
+        self.prefetch_bandwidth.append(
+            ionum / (end_time - start_time) * self.gb_per_page
+        )
+
         results = [None] * len(keys)
         for batch_index, file_result, read_result in zip(
             batch_indices, file_results, read_results
@@ -313,6 +406,7 @@
             [target_sizes] if target_sizes is not None else None,
         )
 
+    @synchronized()
     def batch_set(
         self,
         keys: List[str],
@@ -320,6 +414,10 @@
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In MLA backend, only one rank needs to backup the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -338,6 +436,8 @@
             assert value.is_contiguous()
             file_values.append(value)
 
+        start_time = time.perf_counter()
+
         futures = [
             self.executor.submit(
                 self.clients[self.ac.next()].batch_write,
@@ -352,6 +452,11 @@
             for result in future.result()
         ]
 
+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.backup_pgs.append(ionum)
+        self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
+
         written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
@@ -371,18 +476,29 @@
 
         return all(results)
 
-    @synchronized()
     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
 
-    @synchronized()
     def exists(self, key: str) -> bool:
         result = self.metadata_client.exists(self.rank, [key])
         return result[0] if result else False
 
-    @synchronized()
-    def clear(self) -> None:
-        self.metadata_client.clear(self.rank)
+    def batch_exists(self, keys: List[str]) -> int:
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
+
+    def clear(self) -> bool:
+        try:
+            self.metadata_client.clear(self.rank)
+            logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheHF3FS: {e}")
+            return False
 
     def close(self) -> None:
         try:
@@ -392,3 +508,16 @@
         except Exception as e:
             logger.error(f"close HiCacheHF3FS: {e}")
         logger.info("close HiCacheHF3FS")
+
+    @synchronized()
+    def get_stats(self):
+        storage_metrics = StorageMetrics()
+        storage_metrics.prefetch_pgs.extend(self.prefetch_pgs)
+        storage_metrics.backup_pgs.extend(self.backup_pgs)
+        storage_metrics.prefetch_bandwidth.extend(self.prefetch_bandwidth)
+        storage_metrics.backup_bandwidth.extend(self.backup_bandwidth)
+        self.prefetch_pgs.clear()
+        self.backup_pgs.clear()
+        self.prefetch_bandwidth.clear()
+        self.backup_bandwidth.clear()
+        return storage_metrics
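
Two things in the hunks above are worth unpacking. The new bandwidth metrics record, per batch, ionum / elapsed * gb_per_page, i.e. pages per second times GB per page, which is GB/s (for example, 512 one-MiB pages read in 0.1 s record 512/0.1 × 1/1024 ≈ 5 GB/s). And from_env_config now drives deployment from a JSON file located via SGLANG_HICACHE_HF3FS_CONFIG_PATH. Below is a minimal sketch of such a config, assuming the required keys are the ones the constructor call above reads (file_path_prefix, file_size, numjobs, entries); metadata_server_url is optional and enables the global metadata client that MLA models require. The values are illustrative only.

# Hypothetical HF3FS config consumed via SGLANG_HICACHE_HF3FS_CONFIG_PATH.
# Key names are taken from the diff above; exact requirements may differ.
import json
import os
import tempfile

config = {
    "file_path_prefix": "/data/hicache",  # rank r writes to /data/hicache.<r>.bin
    "file_size": 1 << 40,                 # bytes reserved per rank in HF3FS
    "numjobs": 16,                        # parallel Hf3fsClient instances
    "entries": 8,                         # IOs batched per client call
    # "metadata_server_url": "http://meta-host:18000",  # multi-node / MLA setups
}

path = os.path.join(tempfile.gettempdir(), "hf3fs_config.json")
with open(path, "w") as f:
    json.dump(config, f)
os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"] = path
# HiCacheHF3FS.from_env_config(bytes_per_page, dtype, storage_config) would now
# pick up this file; per the diff, passing extra_config={"use_mock_hf3fs_client": True}
# in the HiCacheStorageConfig routes IO to Hf3fsMockClient instead of usrbio.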
sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py
@@ -0,0 +1,280 @@
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+
+try:
+    from lmcache.integration.sglang.sglang_adapter import (
+        LMCacheLayerwiseConnector,
+        LoadMetadata,
+        StoreMetadata,
+    )
+except ImportError as e:
+    raise RuntimeError(
+        "LMCache is not installed. Please install it by running `pip install lmcache`"
+    ) from e
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.managers.schedule_batch import Req
+
+logger = logging.getLogger(__name__)
+
+
+class LayerTransferCounter:
+    """Minimal adapter that lets the memory pool notify LMCache per-layer.
+
+    The KV pool calls `wait_until(layer_id)` after finishing a layer, which we
+    translate into a `load_kv_layerwise(layer_id)` call on the LMCache connector
+    within the provided CUDA stream.
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        load_stream: torch.cuda.Stream,
+        lmc_connector: LMCacheLayerwiseConnector,
+        printable: bool = False,
+    ):
+        self.num_layers = num_layers
+        self.load_stream = load_stream
+        self.lmc_connector = lmc_connector
+
+    def wait_until(self, layer_id: int):
+        # Ensure ordering of the async loads wrt compute stream(s).
+        self.load_stream.synchronize()
+        with self.load_stream:
+            self.lmc_connector.load_kv_layerwise(layer_id)
+
+
+class LMCRadixCache(RadixCache):
+    """RadixCache + LMCache IO.
+
+    This subclass adds:
+    - LMCache connector setup (device/host buffers, TP rank/size)
+    - Two CUDA streams for async load/store
+    - Layer-wise transfer executor wiring to the KV cache
+    - Overridden `match_prefix` to fetch missing prefix chunks from LMCache
+    - Extended cache_finalization paths to store back into LMCache
+    - Eviction barrier that respects any in-flight host->device stores
+    """
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+        page_size: int,
+        disable: bool = False,
+        enable_kv_cache_events: bool = False,
+        model_config: Optional["ModelConfig"] = None,
+        tp_size: int = 1,
+        rank: int = 0,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
+    ):
+        super().__init__(
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
+            page_size=page_size,
+            disable=disable,
+            enable_kv_cache_events=enable_kv_cache_events,
+        )
+
+        kvcache = self.token_to_kv_pool_allocator.get_kvcache()
+        self.lmcache_connector = LMCacheLayerwiseConnector(
+            sgl_config=model_config,
+            tp_size=tp_size,
+            rank=rank,
+            # NOTE: The original implementation accessed private buffers via
+            # `_kvcache.k_buffer` / `.v_buffer`. We prefer public accessors when
+            # available; fall back to private fields if needed.
+            k_pool=getattr(
+                kvcache,
+                "k_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
+            ),
+            v_pool=getattr(
+                kvcache,
+                "v_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"),
+            ),
+            tp_group=tp_group,
+        )
+
+        self.load_stream = torch.cuda.Stream()
+        self.store_stream = torch.cuda.Stream()
+
+        self.layer_done_executor = LayerTransferCounter(
+            num_layers=(
+                model_config.num_hidden_layers if model_config is not None else 0
+            ),
+            load_stream=self.load_stream,
+            lmc_connector=self.lmcache_connector,
+        )
+        kvcache.register_layer_transfer_counter(self.layer_done_executor)
+
+        self._in_flight_nodes: list[TreeNode] = []
+        self._node_lock = threading.Lock()
+
+    def reset(self):  # type: ignore[override]
+        super().reset()
+        if hasattr(self, "_in_flight_nodes"):
+            with self._node_lock:
+                self._in_flight_nodes.clear()
+
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:  # type: ignore[override]
+        """Match cached prefix; if there's a tail miss, prefetch from LMCache.
+
+        Reuses the base matching logic to obtain (value, last_node). If there
+        remains a *page-aligned* uncached suffix and there is room (or after
+        eviction), we allocate token slots and trigger an async LMCache load
+        into those slots, then materialize a new child node for the retrieved
+        chunk.
+        """
+        if self.disable or not key:
+            return super().match_prefix(key, **kwargs)
+
+        if self.page_size != 1:
+            aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:aligned_len]
+
+        base_res = super().match_prefix(key, **kwargs)
+        value: torch.Tensor = base_res.device_indices
+        last_node: TreeNode = base_res.last_device_node
+
+        if value.numel() == len(key):
+            return base_res
+
+        uncached_len = len(key) - value.numel()
+        if uncached_len == 0:
+            return base_res
+
+        chunk_size = self.lmcache_connector.chunk_size()
+        prefix_pad = value.numel() % chunk_size
+
+        if self.token_to_kv_pool_allocator.available_size() < uncached_len:
+            self.evict(uncached_len)
+
+        token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
+        if token_slots is None:
+            return base_res
+
+        slot_mapping = torch.cat(
+            [
+                torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device),
+                token_slots.detach().clone().to(torch.int64).to(self.device),
+            ]
+        )
+
+        with torch.cuda.stream(self.load_stream):
+            num_retrieved = self.lmcache_connector.start_load_kv(
+                LoadMetadata(
+                    token_ids=key,  # full page-aligned key
+                    slot_mapping=slot_mapping,
+                    offset=value.numel() - prefix_pad,  # LMCache offset convention
+                )
+            )
+        logger.debug("num_retrieved_tokens: %s", num_retrieved)
+
+        if num_retrieved > 0:
+            self.token_to_kv_pool_allocator.free(
+                token_slots[(num_retrieved - prefix_pad) :]
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(token_slots)
+
+        if num_retrieved > 0:
+            fetched = num_retrieved - prefix_pad
+            new_node = TreeNode()
+            start = value.numel()
+            end = start + fetched
+            new_node.key = key[start:end]
+            new_node.value = token_slots[:fetched]
+            new_node.parent = last_node
+            last_node.children[self.get_child_key_fn(new_node.key)] = new_node
+            last_node = new_node
+
+            value = torch.cat([value, token_slots[:fetched]])
+            self.evictable_size_ += fetched
+
+            self._record_store_event(new_node.parent)
+            self._record_store_event(new_node)
+
+            return MatchResult(
+                device_indices=value,
+                last_device_node=last_node,
+                last_host_node=last_node,
+            )
+
+        return base_res
+
+    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
+        """On request completion, insert device KV into radix and store to LMCache."""
+
+        super().cache_finished_req(req)
+
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        _, new_last_node, _, _ = self.match_prefix(token_ids)
+        assert new_last_node is not None
+
+        self.inc_lock_ref(new_last_node)
+        store_md = StoreMetadata(
+            last_node=new_last_node,
+            token_ids=token_ids,
+            kv_indices=kv_indices,
+            offset=0,
+        )
+        with torch.cuda.stream(self.store_stream):
+            self.lmcache_connector.store_kv(store_md)
+        with self._node_lock:
+            self._in_flight_nodes.append(new_last_node)
+
+    def evict(self, num_tokens: int) -> None:  # type: ignore[override]
+        """Before base eviction, wait for any outstanding stores and release locks."""
+        if self.disable:
+            return
+
+        self.store_stream.synchronize()
+        with self._node_lock:
+            for node in self._in_flight_nodes:
+                self.dec_lock_ref(node)
+            self._in_flight_nodes.clear()
+
+        super().evict(num_tokens)
+
+    def pretty_print(self):  # type: ignore[override]
+        super().pretty_print()
+        try:
+            logger.debug(
+                "evictable=%d protected=%d", self.evictable_size_, self.protected_size_
+            )
+        except Exception:  # pragma: no cover
+            pass
+
+
+if __name__ == "__main__":
+    cache = LMCRadixCache(
+        req_to_token_pool=None,
+        token_to_kv_pool_allocator=None,
+        page_size=1,
+        disable=False,
+        enable_kv_cache_events=False,
+        model_config=None,
+        tp_size=1,
+        rank=0,
+        tp_group=None,
+    )
+    cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
+    cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
+    cache.pretty_print()
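
The offset arithmetic in match_prefix above deserves a note: LMCache appears to store KV in fixed-size chunks (chunk_size()), so a load must start at a chunk boundary. prefix_pad is the tail of the already-cached prefix that spills past the last boundary; it is re-fetched and then subtracted from the retrieved count. Below is a standalone sketch of that accounting, with a hypothetical helper name (plan_prefetch) and plain ints standing in for the CUDA tensors and connector calls.

def plan_prefetch(key_len: int, cached_len: int, chunk_size: int) -> dict:
    """Mirror the slot/offset bookkeeping in LMCRadixCache.match_prefix (sketch)."""
    uncached_len = key_len - cached_len    # token slots to allocate
    prefix_pad = cached_len % chunk_size   # cached tail past the last chunk boundary
    load_offset = cached_len - prefix_pad  # chunk-aligned start for the LMCache load
    return {"alloc": uncached_len, "offset": load_offset, "pad": prefix_pad}

# Example: a 1000-token key with 700 tokens already cached and 256-token chunks.
plan = plan_prefetch(1000, 700, 256)
assert plan == {"alloc": 300, "offset": 512, "pad": 188}
# If the connector then retrieves n tokens starting at `offset`, only n - pad of
# them are new; match_prefix frees the unused slots back to the allocator and
# sizes the new radix node to exactly that fetched count.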