sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -5,6 +5,7 @@ import logging
 import os
 import signal
 import threading
+import time
 from abc import ABC, abstractmethod
 from functools import wraps
 from typing import Any, List, Optional, Tuple
@@ -12,7 +13,8 @@ from typing import Any, List, Optional, Tuple
 import torch
 
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
-from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
+from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
+from sglang.srt.metrics.collector import StorageMetrics
 
 logger = logging.getLogger(__name__)
 
@@ -112,7 +114,36 @@ def synchronized():
     return _decorator
 
 
+def create_hf3fs_client(
+    path: str, size: int, bytes_per_page: int, entries: int, use_mock: bool = False
+) -> Hf3fsClient:
+    """Factory function to create the appropriate HF3FS client.
+
+    Args:
+        path: File path for storage.
+        size: Total size of the storage file.
+        bytes_per_page: Bytes per page.
+        entries: Number of entries for batch operations.
+        use_mock: Whether to use the mock client instead of the real usrbio client.
+
+    Returns:
+        An `Hf3fsClient` instance (mock or usrbio).
+    """
+    if use_mock:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsMockClient
+
+        logger.info("Using Hf3fsMockClient for testing")
+        return Hf3fsMockClient(path, size, bytes_per_page, entries)
+    else:
+        from sglang.srt.mem_cache.storage.hf3fs.hf3fs_usrbio_client import (
+            Hf3fsUsrBioClient,
+        )
+
+        return Hf3fsUsrBioClient(path, size, bytes_per_page, entries)
+
+
 class HiCacheHF3FS(HiCacheStorage):
+    """HiCache backend that stores KV cache pages in HF3FS files."""
+
     default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
 
     def __init__(
@@ -125,18 +156,27 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
+        is_page_first_layout: bool = False,
+        use_mock_client: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
+        self.gb_per_page = bytes_per_page / (1 << 30)
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
+        self.is_page_first_layout = is_page_first_layout
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0
 
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -147,8 +187,12 @@ class HiCacheHF3FS(HiCacheStorage):
 
         self.ac = AtomicCounter(self.numjobs)
         self.clients = [
-            Hf3fsClient(
-                self.file_path, self.file_size, self.bytes_per_page, self.entries
+            create_hf3fs_client(
+                self.file_path,
+                self.file_size,
+                self.bytes_per_page,
+                self.entries,
+                use_mock_client,
             )
             for _ in range(numjobs)
         ]
@@ -165,21 +209,57 @@ class HiCacheHF3FS(HiCacheStorage):
         signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
         signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())
 
+        self.prefetch_pgs = []
+        self.backup_pgs = []
+        self.prefetch_bandwidth = []
+        self.backup_bandwidth = []
+
     @staticmethod
     def from_env_config(
         bytes_per_page: int,
         dtype: torch.dtype,
         storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
+        """Create a HiCacheHF3FS instance from environment configuration.
+
+        Environment:
+            - Uses the env var named by `HiCacheHF3FS.default_env_var` to locate a JSON config.
+            - Falls back to a local single-machine config when the env var is not set.
+
+        Raises:
+            ValueError: If an MLA model is requested without a global metadata server, or if required config keys are missing.
+        """
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
-        rank = storage_config.tp_rank if storage_config is not None else 0
+        use_mock_client = False
+        if storage_config is not None:
+            rank, is_mla_model, is_page_first_layout = (
+                storage_config.tp_rank,
+                storage_config.is_mla_model,
+                storage_config.is_page_first_layout,
+            )
+
+            if storage_config.extra_config is not None:
+                use_mock_client = storage_config.extra_config.get(
+                    "use_mock_hf3fs_client", False
+                )
+        else:
+            rank, is_mla_model, is_page_first_layout = (
+                0,
+                False,
+                False,
+            )
+
+        mla_unsupported_msg = "MLA model is not supported without a global metadata server; please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
 
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             return HiCacheHF3FS(
                 rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
@@ -189,6 +269,8 @@ class HiCacheHF3FS(HiCacheStorage):
                 entries=8,
                 dtype=dtype,
                 metadata_client=Hf3fsLocalMetadataClient(),
+                is_page_first_layout=is_page_first_layout,
+                use_mock_client=use_mock_client,
             )
 
         try:
@@ -209,26 +291,36 @@ class HiCacheHF3FS(HiCacheStorage):
            raise ValueError(f"Missing required keys in config: {missing_keys}")
 
         # Choose metadata client based on configuration
-        if "metadata_server_url" in config and config["metadata_server_url"]:
+        if config.get("metadata_server_url"):
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
         else:
+            # Enable MLA optimization only when using the global metadata client
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             # Use local metadata client for single-machine deployment
             metadata_client = Hf3fsLocalMetadataClient()
 
+        rank_for_path = 0 if is_mla_model else rank
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for the MLA model
+            file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
+            is_page_first_layout=is_page_first_layout,
+            use_mock_client=use_mock_client,
         )
 
     def get(
  def get(
@@ -268,6 +360,8 @@ class HiCacheHF3FS(HiCacheStorage):
268
360
  for _ in range(len(batch_indices))
269
361
  ]
270
362
 
363
+ start_time = time.perf_counter()
364
+
271
365
  futures = [
272
366
  self.executor.submit(
273
367
  self.clients[self.ac.next()].batch_read,
@@ -278,6 +372,13 @@ class HiCacheHF3FS(HiCacheStorage):
         ]
         read_results = [result for future in futures for result in future.result()]
 
+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.prefetch_pgs.append(ionum)
+        self.prefetch_bandwidth.append(
+            ionum / (end_time - start_time) * self.gb_per_page
+        )
+
         results = [None] * len(keys)
         for batch_index, file_result, read_result in zip(
             batch_indices, file_results, read_results
@@ -305,6 +406,7 @@ class HiCacheHF3FS(HiCacheStorage):
             [target_sizes] if target_sizes is not None else None,
         )
 
+    @synchronized()
     def batch_set(
         self,
         keys: List[str],
@@ -312,6 +414,10 @@ class HiCacheHF3FS(HiCacheStorage):
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In the MLA backend, only one rank needs to back up the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
  indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -330,6 +436,8 @@ class HiCacheHF3FS(HiCacheStorage):
330
436
  assert value.is_contiguous()
331
437
  file_values.append(value)
332
438
 
439
+ start_time = time.perf_counter()
440
+
333
441
  futures = [
334
442
  self.executor.submit(
335
443
  self.clients[self.ac.next()].batch_write,
@@ -344,6 +452,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for result in future.result()
         ]
 
+        end_time = time.perf_counter()
+        ionum = len(batch_indices)
+        self.backup_pgs.append(ionum)
+        self.backup_bandwidth.append(ionum / (end_time - start_time) * self.gb_per_page)
+
         written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
@@ -363,18 +476,29 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return all(results)
 
-    @synchronized()
     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
 
-    @synchronized()
     def exists(self, key: str) -> bool:
         result = self.metadata_client.exists(self.rank, [key])
         return result[0] if result else False
 
-    @synchronized()
-    def clear(self) -> None:
-        self.metadata_client.clear(self.rank)
+    def batch_exists(self, keys: List[str]) -> int:
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
+
+    def clear(self) -> bool:
+        try:
+            self.metadata_client.clear(self.rank)
+            logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheHF3FS: {e}")
+            return False
 
     def close(self) -> None:
         try:
@@ -384,3 +508,16 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             logger.error(f"close HiCacheHF3FS: {e}")
         logger.info("close HiCacheHF3FS")
+
+    @synchronized()
+    def get_stats(self):
+        storage_metrics = StorageMetrics()
+        storage_metrics.prefetch_pgs.extend(self.prefetch_pgs)
+        storage_metrics.backup_pgs.extend(self.backup_pgs)
+        storage_metrics.prefetch_bandwidth.extend(self.prefetch_bandwidth)
+        storage_metrics.backup_bandwidth.extend(self.backup_bandwidth)
+        self.prefetch_pgs.clear()
+        self.backup_pgs.clear()
+        self.prefetch_bandwidth.clear()
+        self.backup_bandwidth.clear()
+        return storage_metrics
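
A note on the storage_hf3fs.py changes above: the new `use_mock_client` path and the `get_stats()` metrics hook can be exercised without a real 3FS mount. The following is a minimal sketch only; the constructor arguments and method names come from this diff, but the page size, file path, and key names are assumptions, and it presumes the mock client can write to an ordinary local file:

# Minimal sketch of the new mock-client and metrics paths (not part of the package).
import torch

from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
    Hf3fsLocalMetadataClient,
)
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

bytes_per_page = 4096  # illustrative page size

cache = HiCacheHF3FS(
    rank=0,
    file_path="/tmp/hicache.0.bin",  # assumed local path for the mock client
    file_size=16 * bytes_per_page,
    numjobs=2,
    bytes_per_page=bytes_per_page,
    entries=8,
    dtype=torch.bfloat16,
    metadata_client=Hf3fsLocalMetadataClient(),
    use_mock_client=True,  # new flag: routes through Hf3fsMockClient
)

# One page holds `cache.numel` elements (bytes_per_page // dtype.itemsize).
page = torch.zeros(cache.numel, dtype=torch.bfloat16)
cache.batch_set(["page-0"], [page])

# New in this release: length of the leading run of keys that all exist.
print(cache.batch_exists(["page-0", "missing"]))  # expect 1

# New in this release: drains per-batch page counts and GB/s into StorageMetrics.
stats = cache.get_stats()
print(stats.backup_pgs, stats.backup_bandwidth)
cache.close()
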
sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py (new file)
@@ -0,0 +1,280 @@
+from __future__ import annotations
+
+import logging
+import threading
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
+
+try:
+    from lmcache.integration.sglang.sglang_adapter import (
+        LMCacheLayerwiseConnector,
+        LoadMetadata,
+        StoreMetadata,
+    )
+except ImportError as e:
+    raise RuntimeError(
+        "LMCache is not installed. Please install it by running `pip install lmcache`"
+    ) from e
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.managers.schedule_batch import Req
+
+logger = logging.getLogger(__name__)
+
+
+class LayerTransferCounter:
+    """Minimal adapter that lets the memory pool notify LMCache per-layer.
+
+    The KV pool calls `wait_until(layer_id)` after finishing a layer, which we
+    translate into a `load_kv_layerwise(layer_id)` call on the LMCache connector
+    within the provided CUDA stream.
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        load_stream: torch.cuda.Stream,
+        lmc_connector: LMCacheLayerwiseConnector,
+        printable: bool = False,
+    ):
+        self.num_layers = num_layers
+        self.load_stream = load_stream
+        self.lmc_connector = lmc_connector
+
+    def wait_until(self, layer_id: int):
+        # Ensure ordering of the async loads wrt compute stream(s).
+        self.load_stream.synchronize()
+        with self.load_stream:
+            self.lmc_connector.load_kv_layerwise(layer_id)
+
+
+class LMCRadixCache(RadixCache):
+    """RadixCache + LMCache IO.
+
+    This subclass adds:
+    - LMCache connector setup (device/host buffers, TP rank/size)
+    - Two CUDA streams for async load/store
+    - Layer-wise transfer executor wiring to the KV cache
+    - Overridden `match_prefix` to fetch missing prefix chunks from LMCache
+    - Extended cache_finalization paths to store back into LMCache
+    - Eviction barrier that respects any in-flight host->device stores
+    """
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+        page_size: int,
+        disable: bool = False,
+        enable_kv_cache_events: bool = False,
+        model_config: Optional["ModelConfig"] = None,
+        tp_size: int = 1,
+        rank: int = 0,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
+    ):
+        super().__init__(
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
+            page_size=page_size,
+            disable=disable,
+            enable_kv_cache_events=enable_kv_cache_events,
+        )
+
+        kvcache = self.token_to_kv_pool_allocator.get_kvcache()
+        self.lmcache_connector = LMCacheLayerwiseConnector(
+            sgl_config=model_config,
+            tp_size=tp_size,
+            rank=rank,
+            # NOTE: The original implementation accessed private buffers via
+            # `_kvcache.k_buffer` / `.v_buffer`. We prefer public accessors when
+            # available; fall back to private fields if needed.
+            k_pool=getattr(
+                kvcache,
+                "k_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
+            ),
+            v_pool=getattr(
+                kvcache,
+                "v_buffer",
+                getattr(self.token_to_kv_pool_allocator._kvcache, "v_buffer"),
+            ),
+            tp_group=tp_group,
+        )
+
+        self.load_stream = torch.cuda.Stream()
+        self.store_stream = torch.cuda.Stream()
+
+        self.layer_done_executor = LayerTransferCounter(
+            num_layers=(
+                model_config.num_hidden_layers if model_config is not None else 0
+            ),
+            load_stream=self.load_stream,
+            lmc_connector=self.lmcache_connector,
+        )
+        kvcache.register_layer_transfer_counter(self.layer_done_executor)
+
+        self._in_flight_nodes: list[TreeNode] = []
+        self._node_lock = threading.Lock()
+
+    def reset(self):  # type: ignore[override]
+        super().reset()
+        if hasattr(self, "_in_flight_nodes"):
+            with self._node_lock:
+                self._in_flight_nodes.clear()
+
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:  # type: ignore[override]
+        """Match cached prefix; if there's a tail miss, prefetch from LMCache.
+
+        Reuses the base matching logic to obtain (value, last_node). If there
+        remains a *page-aligned* uncached suffix and there is room (or after
+        eviction), we allocate token slots and trigger an async LMCache load
+        into those slots, then materialize a new child node for the retrieved
+        chunk.
+        """
+        if self.disable or not key:
+            return super().match_prefix(key, **kwargs)
+
+        if self.page_size != 1:
+            aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:aligned_len]
+
+        base_res = super().match_prefix(key, **kwargs)
+        value: torch.Tensor = base_res.device_indices
+        last_node: TreeNode = base_res.last_device_node
+
+        if value.numel() == len(key):
+            return base_res
+
+        uncached_len = len(key) - value.numel()
+        if uncached_len == 0:
+            return base_res
+
+        chunk_size = self.lmcache_connector.chunk_size()
+        prefix_pad = value.numel() % chunk_size
+
+        if self.token_to_kv_pool_allocator.available_size() < uncached_len:
+            self.evict(uncached_len)
+
+        token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
+        if token_slots is None:
+            return base_res
+
+        slot_mapping = torch.cat(
+            [
+                torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device),
+                token_slots.detach().clone().to(torch.int64).to(self.device),
+            ]
+        )
+
+        with torch.cuda.stream(self.load_stream):
+            num_retrieved = self.lmcache_connector.start_load_kv(
+                LoadMetadata(
+                    token_ids=key,  # full page-aligned key
+                    slot_mapping=slot_mapping,
+                    offset=value.numel() - prefix_pad,  # LMCache offset convention
+                )
+            )
+        logger.debug("num_retrieved_tokens: %s", num_retrieved)
+
+        if num_retrieved > 0:
+            self.token_to_kv_pool_allocator.free(
+                token_slots[(num_retrieved - prefix_pad) :]
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(token_slots)
+
+        if num_retrieved > 0:
+            fetched = num_retrieved - prefix_pad
+            new_node = TreeNode()
+            start = value.numel()
+            end = start + fetched
+            new_node.key = key[start:end]
+            new_node.value = token_slots[:fetched]
+            new_node.parent = last_node
+            last_node.children[self.get_child_key_fn(new_node.key)] = new_node
+            last_node = new_node
+
+            value = torch.cat([value, token_slots[:fetched]])
+            self.evictable_size_ += fetched
+
+            self._record_store_event(new_node.parent)
+            self._record_store_event(new_node)
+
+            return MatchResult(
+                device_indices=value,
+                last_device_node=last_node,
+                last_host_node=last_node,
+            )
+
+        return base_res
+
+    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
+        """On request completion, insert device KV into radix and store to LMCache."""
+
+        super().cache_finished_req(req)
+
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        _, new_last_node, _, _ = self.match_prefix(token_ids)
+        assert new_last_node is not None
+
+        self.inc_lock_ref(new_last_node)
+        store_md = StoreMetadata(
+            last_node=new_last_node,
+            token_ids=token_ids,
+            kv_indices=kv_indices,
+            offset=0,
+        )
+        with torch.cuda.stream(self.store_stream):
+            self.lmcache_connector.store_kv(store_md)
+        with self._node_lock:
+            self._in_flight_nodes.append(new_last_node)
+
+    def evict(self, num_tokens: int) -> None:  # type: ignore[override]
+        """Before base eviction, wait for any outstanding stores and release locks."""
+        if self.disable:
+            return
+
+        self.store_stream.synchronize()
+        with self._node_lock:
+            for node in self._in_flight_nodes:
+                self.dec_lock_ref(node)
+            self._in_flight_nodes.clear()
+
+        super().evict(num_tokens)
+
+    def pretty_print(self):  # type: ignore[override]
+        super().pretty_print()
+        try:
+            logger.debug(
+                "evictable=%d protected=%d", self.evictable_size_, self.protected_size_
+            )
+        except Exception:  # pragma: no cover
+            pass
+
+
+if __name__ == "__main__":
+    cache = LMCRadixCache(
+        req_to_token_pool=None,
+        token_to_kv_pool_allocator=None,
+        page_size=1,
+        disable=False,
+        enable_kv_cache_events=False,
+        model_config=None,
+        tp_size=1,
+        rank=0,
+        tp_group=None,
+    )
+    cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
+    cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
+    cache.pretty_print()
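
The offset bookkeeping in `match_prefix` above is easy to misread, since LMCache transfers whole chunks: the load starts at the chunk boundary at or below the cached prefix, and the tokens between that boundary and the cached length (`prefix_pad`) are re-fetched but not kept. Here is the arithmetic in isolation; all numbers are illustrative assumptions, only the conventions come from the code above:

# Toy recomputation of the chunk-alignment math in LMCRadixCache.match_prefix.
chunk_size = 256                      # what lmcache_connector.chunk_size() might return
cached = 300                          # device-cached prefix, value.numel()
key_len = 1000                        # page-aligned request length, len(key)

prefix_pad = cached % chunk_size      # 44 tokens past the last chunk boundary
offset = cached - prefix_pad          # 256: the LMCache load starts here
uncached = key_len - cached           # 700 token slots allocated for the load

num_retrieved = 2 * chunk_size        # suppose LMCache returns two full chunks (512)
fetched = num_retrieved - prefix_pad  # 468 genuinely new tokens become the child node
assert fetched <= uncached            # slots beyond `fetched` are freed again

print(offset, fetched, uncached - fetched)  # 256 468 232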