sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -12,10 +12,11 @@
12
12
  # limitations under the License.
13
13
  # ==============================================================================
14
14
  """A tensor parallel worker."""
15
+ from __future__ import annotations
15
16
 
16
17
  import logging
17
18
  import threading
18
- from typing import Optional, Tuple, Union
19
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
19
20
 
20
21
  import torch
21
22
 
@@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
45
46
  from sglang.srt.server_args import ServerArgs
46
47
  from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
47
48
 
49
+ if TYPE_CHECKING:
50
+ from sglang.srt.managers.cache_controller import LayerDoneCounter
51
+
48
52
  logger = logging.getLogger(__name__)
49
53
 
50
54
 
@@ -78,6 +82,11 @@ class TpModelWorker:
78
82
  if not is_draft_worker
79
83
  else server_args.speculative_draft_model_path
80
84
  ),
85
+ model_revision=(
86
+ server_args.revision
87
+ if not is_draft_worker
88
+ else server_args.speculative_draft_model_revision
89
+ ),
81
90
  is_draft_model=is_draft_worker,
82
91
  )
83
92
 
@@ -137,7 +146,7 @@ class TpModelWorker:
137
146
  assert self.max_running_requests > 0, "max_running_request is zero"
138
147
  self.max_queued_requests = server_args.max_queued_requests
139
148
  assert (
140
- self.max_running_requests > 0
149
+ self.max_queued_requests > 0
141
150
  ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
142
151
  self.max_req_len = min(
143
152
  self.model_config.context_len - 1,
@@ -162,10 +171,10 @@ class TpModelWorker:
162
171
 
163
172
  self.hicache_layer_transfer_counter = None
164
173
 
165
- def register_hicache_layer_transfer_counter(self, counter):
174
+ def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
166
175
  self.hicache_layer_transfer_counter = counter
167
176
 
168
- def set_hicache_consumer(self, consumer_index):
177
+ def set_hicache_consumer(self, consumer_index: int):
169
178
  if self.hicache_layer_transfer_counter is not None:
170
179
  self.hicache_layer_transfer_counter.set_consumer(consumer_index)
171
180
 
@@ -225,6 +234,9 @@ class TpModelWorker:
225
234
  ) -> Tuple[
226
235
  Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
227
236
  ]:
237
+ # update the consumer index of hicache to the running batch
238
+ self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
239
+
228
240
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
229
241
 
230
242
  pp_proxy_tensors = None
@@ -12,13 +12,14 @@
12
12
  # limitations under the License.
13
13
  # ==============================================================================
14
14
  """A tensor parallel worker."""
15
+ from __future__ import annotations
15
16
 
16
17
  import dataclasses
17
18
  import logging
18
19
  import signal
19
20
  import threading
20
21
  from queue import Queue
21
- from typing import Optional, Tuple
22
+ from typing import TYPE_CHECKING, List, Optional, Tuple
22
23
 
23
24
  import psutil
24
25
  import torch
@@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs
38
39
  from sglang.srt.utils import DynamicGradMode, get_compiler_backend
39
40
  from sglang.utils import get_exception_traceback
40
41
 
42
+ if TYPE_CHECKING:
43
+ from sglang.srt.managers.cache_controller import LayerDoneCounter
44
+
41
45
  logger = logging.getLogger(__name__)
42
46
 
43
47
 
@@ -79,7 +83,7 @@ class TpModelWorkerClient:
79
83
  )
80
84
 
81
85
  # Launch threads
82
- self.input_queue = Queue()
86
+ self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
83
87
  self.output_queue = Queue()
84
88
  self.forward_stream = torch.get_device_module(self.device).Stream()
85
89
  self.forward_thread = threading.Thread(
@@ -93,13 +97,9 @@ class TpModelWorkerClient:
93
97
 
94
98
  self.hicache_layer_transfer_counter = None
95
99
 
96
- def register_hicache_layer_transfer_counter(self, counter):
100
+ def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
97
101
  self.hicache_layer_transfer_counter = counter
98
102
 
99
- def set_hicache_consumer(self, consumer_index):
100
- if self.hicache_layer_transfer_counter is not None:
101
- self.hicache_layer_transfer_counter.set_consumer(consumer_index)
102
-
103
103
  def get_worker_info(self):
104
104
  return self.worker.get_worker_info()
105
105
 
@@ -147,7 +147,7 @@ class TpModelWorkerClient:
147
147
  @DynamicGradMode()
148
148
  def forward_thread_func_(self):
149
149
  batch_pt = 0
150
- batch_lists = [None] * 2
150
+ batch_lists: List = [None] * 2
151
151
 
152
152
  while True:
153
153
  model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
@@ -169,8 +169,6 @@ class TpModelWorkerClient:
169
169
  input_ids = model_worker_batch.input_ids
170
170
  resolve_future_token_ids(input_ids, self.future_token_ids_map)
171
171
 
172
- # update the consumer index of hicache to the running batch
173
- self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
174
172
  # Run forward
175
173
  logits_output, next_token_ids, can_run_cuda_graph = (
176
174
  self.worker.forward_batch_generation(
@@ -283,7 +283,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
283
283
  self.swa_attn_allocator.clear()
284
284
  self.full_attn_allocator.clear()
285
285
  self.full_to_swa_index_mapping.fill_(0)
286
- self.is_in_free_group = False
286
+ self.is_not_in_free_group = True
287
287
  self.free_group = []
288
288
 
289
289
 
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
47
47
  self.req_to_token_pool.free(req.req_pool_idx)
48
48
  self.token_to_kv_pool_allocator.free(kv_indices)
49
49
 
50
- def cache_unfinished_req(self, req: Req):
50
+ def cache_unfinished_req(self, req: Req, chunked=False):
51
51
  kv_indices = self.req_to_token_pool.req_to_token[
52
52
  req.req_pool_idx, : len(req.fill_ids)
53
53
  ]
@@ -2,6 +2,7 @@ import hashlib
2
2
  import logging
3
3
  import os
4
4
  from abc import ABC, abstractmethod
5
+ from dataclasses import dataclass
5
6
  from typing import Any, List, Optional
6
7
 
7
8
  import torch
@@ -9,17 +10,6 @@ import torch
9
10
  logger = logging.getLogger(__name__)
10
11
 
11
12
 
12
- from sglang.srt.distributed import (
13
- get_tensor_model_parallel_rank,
14
- get_tensor_model_parallel_world_size,
15
- )
16
- from sglang.srt.layers.dp_attention import (
17
- get_attention_tp_rank,
18
- get_attention_tp_size,
19
- is_dp_attention_enabled,
20
- )
21
-
22
-
23
13
  def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
24
14
  hasher = hashlib.sha256()
25
15
 
@@ -32,6 +22,16 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
32
22
  return hasher.hexdigest()
33
23
 
34
24
 
25
+ @dataclass
26
+ class HiCacheStorageConfig:
27
+ tp_rank: int
28
+ tp_size: int
29
+ is_mla_model: bool
30
+ is_page_first_layout: bool
31
+ model_name: Optional[str]
32
+ extra_config: Optional[dict] = None
33
+
34
+
35
35
  class HiCacheStorage(ABC):
36
36
  """
37
37
  HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
@@ -60,7 +60,7 @@ class HiCacheStorage(ABC):
60
60
  keys: List[str],
61
61
  target_locations: Optional[Any] = None,
62
62
  target_sizes: Optional[Any] = None,
63
- ) -> List[torch.Tensor | None]:
63
+ ) -> List[torch.Tensor | None] | int:
64
64
  """
65
65
  Retrieve values for multiple keys.
66
66
  Returns a list of tensors or None for each key.
@@ -96,32 +96,53 @@ class HiCacheStorage(ABC):
96
96
  pass
97
97
 
98
98
  @abstractmethod
99
- def exists(self, key: str) -> bool | dict:
99
+ def exists(self, key: str) -> bool:
100
100
  """
101
101
  Check if the key exists in the storage.
102
102
  Returns True if the key exists, False otherwise.
103
103
  """
104
104
  pass
105
105
 
106
+ def batch_exists(self, keys: List[str]) -> int:
107
+ """
108
+ Check if the keys exist in the storage.
109
+ return the number of consecutive existing keys from the start.
110
+ Can be overridden by subclasses for more efficient implementation.
111
+ """
112
+ for i in range(len(keys)):
113
+ if not self.exists(keys[i]):
114
+ return i
115
+ return len(keys)
116
+
117
+ def get_stats(self):
118
+ return None
119
+
106
120
 
107
121
  class HiCacheFile(HiCacheStorage):
108
122
 
109
- def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
123
+ def __init__(
124
+ self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache"
125
+ ):
110
126
  self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
111
- if is_dp_attention_enabled():
112
- tp_rank = get_attention_tp_rank()
113
- tp_size = get_attention_tp_size()
127
+
128
+ tp_rank, tp_size, model_name, is_mla_model = (
129
+ storage_config.tp_rank,
130
+ storage_config.tp_size,
131
+ storage_config.model_name,
132
+ storage_config.is_mla_model,
133
+ )
134
+ model_name = "-".join(model_name.split("/")) if model_name else ""
135
+ if is_mla_model:
136
+ self.config_suffix = f"_{model_name}"
114
137
  else:
115
- tp_rank = get_tensor_model_parallel_rank()
116
- tp_size = get_tensor_model_parallel_world_size()
138
+ self.config_suffix = f"_{model_name}_{tp_rank}_{tp_size}"
117
139
 
118
- self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
119
140
  if not os.path.exists(self.file_path) and tp_rank == 0:
120
141
  os.makedirs(self.file_path)
121
142
  logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
122
143
 
123
144
  def _get_suffixed_key(self, key: str) -> str:
124
- return key + self.tp_suffix
145
+ return key + self.config_suffix
125
146
 
126
147
  def get(
127
148
  self,
@@ -132,13 +153,11 @@ class HiCacheFile(HiCacheStorage):
132
153
  key = self._get_suffixed_key(key)
133
154
  tensor_path = os.path.join(self.file_path, f"{key}.bin")
134
155
  try:
135
- # Load directly into target_location's memory buffer
136
- with open(tensor_path, "rb") as f:
137
- target_location.set_(
138
- torch.frombuffer(f.read(), dtype=target_location.dtype)
139
- .reshape(target_location.shape)
140
- .untyped_storage()
141
- )
156
+ expected = target_location.numel() * target_location.element_size()
157
+ with open(tensor_path, "rb", buffering=0) as f:
158
+ buf = memoryview(target_location.view(torch.uint8).contiguous().numpy())
159
+ if f.readinto(buf) != expected:
160
+ raise IOError(f"Short read for {key}")
142
161
  return target_location
143
162
  except FileNotFoundError:
144
163
  logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
@@ -164,11 +183,12 @@ class HiCacheFile(HiCacheStorage):
164
183
  target_location: Optional[Any] = None,
165
184
  target_sizes: Optional[Any] = None,
166
185
  ) -> bool:
167
- key = self._get_suffixed_key(key)
168
- tensor_path = os.path.join(self.file_path, f"{key}.bin")
169
186
  if self.exists(key):
170
187
  logger.debug(f"Key {key} already exists. Skipped.")
171
188
  return True
189
+
190
+ key = self._get_suffixed_key(key)
191
+ tensor_path = os.path.join(self.file_path, f"{key}.bin")
172
192
  try:
173
193
  value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
174
194
  return True
@@ -193,21 +213,14 @@ class HiCacheFile(HiCacheStorage):
193
213
  tensor_path = os.path.join(self.file_path, f"{key}.bin")
194
214
  return os.path.exists(tensor_path)
195
215
 
196
- def delete(self, key: str) -> None:
197
- key = self._get_suffixed_key(key)
198
- tensor_path = os.path.join(self.file_path, f"{key}.bin")
199
- try:
200
- os.remove(tensor_path)
201
- except FileNotFoundError:
202
- logger.warning(f"Key {key} does not exist. Cannot delete.")
203
- return
204
-
205
- def clear(self) -> None:
216
+ def clear(self) -> bool:
206
217
  try:
207
218
  for filename in os.listdir(self.file_path):
208
219
  file_path = os.path.join(self.file_path, filename)
209
220
  if os.path.isfile(file_path):
210
221
  os.remove(file_path)
211
222
  logger.info("Cleared all entries in HiCacheFile storage.")
223
+ return True
212
224
  except Exception as e:
213
225
  logger.error(f"Failed to clear HiCacheFile storage: {e}")
226
+ return False