sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import math
18
18
  import threading
19
19
  import time
20
20
  from queue import Empty, Full, PriorityQueue, Queue
21
- from typing import TYPE_CHECKING, List, Optional
21
+ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
22
22
 
23
23
  import torch
24
24
 
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
33
33
  get_tensor_model_parallel_world_size,
34
34
  )
35
35
  from sglang.srt.layers.dp_attention import (
36
+ get_attention_dp_rank,
36
37
  get_attention_tp_rank,
37
38
  get_attention_tp_size,
38
39
  is_dp_attention_enabled,
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
42
43
  logger = logging.getLogger(__name__)
43
44
 
44
45
 
46
+ class LayerLoadingEvent:
47
+ def __init__(self, num_layers: int):
48
+ self._num_layers = num_layers
49
+ self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
50
+ self.start_event = torch.cuda.Event() # start event on controller stream
51
+
52
+ def complete(self, layer_index: int):
53
+ assert 0 <= layer_index < self._num_layers
54
+ self.load_events[layer_index].record()
55
+
56
+ def wait(self, layer_index: int):
57
+ torch.cuda.current_stream().wait_event(self.load_events[layer_index])
58
+
59
+ @property
60
+ def finish_event(self):
61
+ return self.load_events[-1]
62
+
63
+
45
64
  class LayerDoneCounter:
46
- def __init__(self, num_layers):
65
+ def __init__(self, num_layers: int):
47
66
  self.num_layers = num_layers
48
67
  # extra producer and consumer counters for overlap mode
49
68
  self.num_counters = 3
50
- self.counters = [num_layers] * self.num_counters
51
- self.conditions = [threading.Condition() for _ in range(self.num_counters)]
52
- self.producer_index = 0
53
- self.consumer_index = 0
54
-
55
- def next_producer(self):
56
- return (self.producer_index + 1) % self.num_counters
69
+ self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
70
+ self.producer_index = -1
71
+ self.consumer_index = -1
57
72
 
58
73
  def update_producer(self):
59
- self.producer_index = self.next_producer()
74
+ self.producer_index = (self.producer_index + 1) % self.num_counters
75
+ assert self.events[
76
+ self.producer_index
77
+ ].finish_event.query(), (
78
+ "Producer finish event should be ready before being reused."
79
+ )
60
80
  return self.producer_index
61
81
 
62
- def set_consumer(self, index):
82
+ def set_consumer(self, index: int):
63
83
  self.consumer_index = index
64
84
 
65
- def increment(self):
66
- with self.conditions[self.producer_index]:
67
- self.counters[self.producer_index] += 1
68
- self.conditions[self.producer_index].notify_all()
69
-
70
- def wait_until(self, threshold):
71
- with self.conditions[self.consumer_index]:
72
- while self.counters[self.consumer_index] <= threshold:
73
- self.conditions[self.consumer_index].wait()
85
+ def wait_until(self, threshold: int):
86
+ if self.consumer_index < 0:
87
+ return
88
+ self.events[self.consumer_index].wait(threshold)
74
89
 
75
90
  def reset(self):
76
- with self.conditions[self.producer_index]:
77
- self.counters[self.producer_index] = 0
91
+ self.producer_index = -1
92
+ self.consumer_index = -1
78
93
 
79
94
 
80
95
  class CacheOperation:
@@ -98,36 +113,30 @@ class CacheOperation:
98
113
  # default priority is the order of creation
99
114
  self.priority = priority if priority is not None else self.id
100
115
 
101
- def merge(self, other: "CacheOperation") -> None:
102
- # multiple operations can be merged into a single operation for batch processing
103
- self.host_indices = torch.cat([self.host_indices, other.host_indices])
104
- self.device_indices = torch.cat([self.device_indices, other.device_indices])
105
- self.priority = min(self.priority, other.priority)
106
- self.node_ids.extend(other.node_ids)
107
-
108
- def split(self, factor) -> List["CacheOperation"]:
109
- # split an operation into smaller operations to reduce the size of intermediate buffers
110
- if factor <= 1:
111
- return [self]
112
-
113
- chunk_size = math.ceil(len(self.host_indices) / factor)
114
- split_ops = []
115
- for i in range(0, len(self.host_indices), chunk_size):
116
- split_ops.append(
117
- CacheOperation(
118
- host_indices=self.host_indices[i : i + chunk_size],
119
- device_indices=self.device_indices[i : i + chunk_size],
120
- node_id=0,
121
- )
122
- )
123
- # Inherit the node_ids on the final chunk
124
- if split_ops:
125
- split_ops[-1].node_ids = self.node_ids
116
+ @staticmethod
117
+ def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
118
+ assert len(ops) > 0
119
+ if len(ops) == 1:
120
+ return ops[0]
121
+
122
+ host_indices = torch.cat([op.host_indices for op in ops])
123
+ device_indices = torch.cat([op.device_indices for op in ops])
124
+ node_ids = []
125
+ priority = min(op.priority for op in ops)
126
+ for op in ops:
127
+ node_ids.extend(op.node_ids)
128
+ merged_op = CacheOperation(host_indices, device_indices, -1, priority)
129
+ merged_op.node_ids = node_ids
130
+ return merged_op
131
+
132
+ def __lt__(self, other: CacheOperation):
133
+ return self.priority < other.priority
126
134
 
127
- return split_ops
128
135
 
129
- def __lt__(self, other: "CacheOperation"):
130
- return self.priority < other.priority
136
+ class HiCacheAck(NamedTuple):
137
+ start_event: torch.cuda.Event
138
+ finish_event: torch.cuda.Event
139
+ node_ids: List[int]
131
140
 
132
141
 
133
142
  class TransferBuffer:
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
206
215
  ):
207
216
  self.request_id = request_id
208
217
 
209
- self._done_flag = False
210
218
  self._lock = threading.Lock()
211
-
219
+ self._terminated_flag = False
212
220
  self.start_time = time.monotonic()
213
221
 
214
222
  super().__init__(host_indices, token_ids, last_hash)
215
223
 
216
224
  def increment(self, num_tokens: int):
217
225
  with self._lock:
218
- if self._done_flag:
226
+ if self._terminated_flag:
219
227
  return False
220
228
  self.completed_tokens += num_tokens
221
229
  return True
222
230
 
223
- def mark_done(self):
231
+ def mark_terminate(self):
224
232
  with self._lock:
225
- self._done_flag = True
233
+ self._terminated_flag = True
226
234
 
227
- def is_done(self) -> bool:
228
- return self._done_flag
235
+ def is_terminated(self) -> bool:
236
+ return self._terminated_flag
229
237
 
230
238
 
231
239
  class HiCacheController:
@@ -236,7 +244,7 @@ class HiCacheController:
236
244
  mem_pool_host: HostKVCache,
237
245
  page_size: int,
238
246
  tp_group: torch.distributed.ProcessGroup,
239
- load_cache_event: threading.Event = None,
247
+ load_cache_event: threading.Event,
240
248
  write_policy: str = "write_through_selective",
241
249
  io_backend: str = "",
242
250
  storage_backend: Optional[str] = None,
@@ -324,8 +332,25 @@ class HiCacheController:
324
332
  group_ranks, backend="gloo"
325
333
  )
326
334
 
327
- self.load_cache_event = load_cache_event
328
- self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
335
+ # Select the get and set functions
336
+ self.page_get_func = self._generic_page_get
337
+ self.page_set_func = self._generic_page_set
338
+ self.batch_exists_func = self.storage_backend.batch_exists
339
+ self.is_3fs_zerocopy = (
340
+ self.storage_backend_type == "hf3fs"
341
+ and self.mem_pool_host.layout == "page_first"
342
+ )
343
+ if self.storage_backend_type == "mooncake":
344
+ self.page_get_func = self._mooncake_page_get
345
+ self.page_set_func = self._mooncake_page_set
346
+ elif self.is_3fs_zerocopy:
347
+ self.page_get_func = self._3fs_zero_copy_page_get
348
+ self.page_set_func = self._3fs_zero_copy_page_set
349
+ self.batch_exists_func = self._3fs_zero_copy_batch_exists
350
+
351
+ self.device = self.mem_pool_device.device
352
+ self.layer_num = self.mem_pool_device.layer_num
353
+ self.layer_done_counter = LayerDoneCounter(self.layer_num)
329
354
  self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
330
355
 
331
356
  if write_policy not in [
@@ -335,11 +360,11 @@ class HiCacheController:
335
360
  ]:
336
361
  raise ValueError(f"Invalid write policy: {write_policy}")
337
362
 
338
- self.write_queue = PriorityQueue()
339
- self.load_queue = PriorityQueue()
340
-
341
- self.ack_write_queue = Queue()
342
- self.ack_load_queue = Queue()
363
+ # self.write_queue = PriorityQueue[CacheOperation]()
364
+ self.load_queue: List[CacheOperation] = []
365
+ self.write_queue: List[CacheOperation] = []
366
+ self.ack_load_queue: List[HiCacheAck] = []
367
+ self.ack_write_queue: List[HiCacheAck] = []
343
368
 
344
369
  self.stop_event = threading.Event()
345
370
  self.write_buffer = TransferBuffer(self.stop_event)
@@ -350,16 +375,6 @@ class HiCacheController:
350
375
  self.write_stream = torch.cuda.Stream()
351
376
  self.load_stream = torch.cuda.Stream()
352
377
 
353
- self.write_thread = threading.Thread(
354
- target=self.write_thread_func_direct, daemon=True
355
- )
356
- self.load_thread = threading.Thread(
357
- target=self.load_thread_func_layer_by_layer, daemon=True
358
- )
359
-
360
- self.write_thread.start()
361
- self.load_thread.start()
362
-
363
378
  if self.enable_storage:
364
379
  self.prefetch_thread = threading.Thread(
365
380
  target=self.prefetch_thread_func, daemon=True
@@ -386,9 +401,11 @@ class HiCacheController:
386
401
  if is_dp_attention_enabled():
387
402
  self.tp_rank = get_attention_tp_rank()
388
403
  self.tp_size = get_attention_tp_size()
404
+ self.dp_rank = get_attention_dp_rank()
389
405
  else:
390
406
  self.tp_rank = get_tensor_model_parallel_rank()
391
407
  self.tp_size = get_tensor_model_parallel_world_size()
408
+ self.dp_rank = 0
392
409
 
393
410
  # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
394
411
  is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
@@ -407,21 +424,20 @@ class HiCacheController:
407
424
  tp_rank=self.tp_rank,
408
425
  tp_size=self.tp_size,
409
426
  is_mla_model=is_mla_backend,
427
+ is_page_first_layout=self.mem_pool_host.layout == "page_first",
410
428
  model_name=model_name,
411
429
  extra_config=extra_config,
412
430
  )
413
431
 
414
432
  def reset(self):
415
433
  self.stop_event.set()
416
- self.write_thread.join()
417
- self.load_thread.join()
418
434
 
419
- self.write_queue.queue.clear()
420
- self.load_queue.queue.clear()
435
+ self.write_queue.clear()
436
+ self.load_queue.clear()
421
437
  self.write_buffer.clear()
422
438
  self.load_buffer.clear()
423
- self.ack_write_queue.queue.clear()
424
- self.ack_load_queue.queue.clear()
439
+ self.ack_write_queue.clear()
440
+ self.ack_load_queue.clear()
425
441
  if self.enable_storage:
426
442
  self.prefetch_thread.join()
427
443
  self.backup_thread.join()
@@ -430,15 +446,7 @@ class HiCacheController:
430
446
  self.prefetch_revoke_queue.queue.clear()
431
447
  self.ack_backup_queue.queue.clear()
432
448
 
433
- self.write_thread = threading.Thread(
434
- target=self.write_thread_func_direct, daemon=True
435
- )
436
- self.load_thread = threading.Thread(
437
- target=self.load_thread_func_layer_by_layer, daemon=True
438
- )
439
449
  self.stop_event.clear()
440
- self.write_thread.start()
441
- self.load_thread.start()
442
450
 
443
451
  if self.enable_storage:
444
452
  self.prefetch_thread = threading.Thread(
@@ -454,7 +462,7 @@ class HiCacheController:
454
462
  self,
455
463
  device_indices: torch.Tensor,
456
464
  priority: Optional[int] = None,
457
- node_id: int = 0,
465
+ node_id: int = -1,
458
466
  ) -> Optional[torch.Tensor]:
459
467
  """
460
468
  Back up KV caches from device memory to host memory.
@@ -463,17 +471,46 @@ class HiCacheController:
463
471
  if host_indices is None:
464
472
  return None
465
473
  self.mem_pool_host.protect_write(host_indices)
466
- torch.cuda.current_stream().synchronize()
467
- self.write_queue.put(
474
+ self.write_queue.append(
468
475
  CacheOperation(host_indices, device_indices, node_id, priority)
469
476
  )
477
+ self.start_writing()
470
478
  return host_indices
471
479
 
480
+ def start_writing(self) -> None:
481
+ if len(self.write_queue) == 0:
482
+ return
483
+
484
+ op = CacheOperation.merge_ops(self.write_queue)
485
+ host_indices, device_indices = self.move_indices(op)
486
+ self.write_queue.clear()
487
+
488
+ start_event = torch.cuda.Event()
489
+ finish_event = torch.cuda.Event()
490
+
491
+ start_event.record()
492
+ with torch.cuda.stream(self.write_stream):
493
+ start_event.wait(self.write_stream)
494
+ self.mem_pool_host.backup_from_device_all_layer(
495
+ self.mem_pool_device, host_indices, device_indices, self.io_backend
496
+ )
497
+ self.mem_pool_host.complete_io(op.host_indices)
498
+ finish_event.record()
499
+ # NOTE: We must save the host indices and device indices here,
500
+ # this is because we need to guarantee that these tensors are
501
+ # still alive when the write stream is executing.
502
+ if host_indices.is_cuda:
503
+ host_indices.record_stream(self.write_stream)
504
+ if device_indices.is_cuda:
505
+ device_indices.record_stream(self.write_stream)
506
+
507
+ self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
508
+
472
509
  def load(
473
510
  self,
474
511
  host_indices: torch.Tensor,
475
512
  priority: Optional[int] = None,
476
- node_id: int = 0,
513
+ node_id: int = -1,
477
514
  ) -> Optional[torch.Tensor]:
478
515
  """
479
516
  Load KV caches from host memory to device memory.
@@ -482,76 +519,42 @@ class HiCacheController:
482
519
  if device_indices is None:
483
520
  return None
484
521
  self.mem_pool_host.protect_load(host_indices)
485
- # to ensure the device indices are ready before accessed by another CUDA stream
486
- torch.cuda.current_stream().synchronize()
487
- self.load_queue.put(
522
+ self.load_queue.append(
488
523
  CacheOperation(host_indices, device_indices, node_id, priority)
489
524
  )
490
525
  return device_indices
491
526
 
492
- def move_indices(self, host_indices, device_indices):
527
+ def move_indices(self, op: CacheOperation):
528
+ host_indices, device_indices = op.host_indices, op.device_indices
493
529
  # move indices to GPU if using kernels, to host if using direct indexing
494
530
  if self.io_backend == "kernel":
495
- return host_indices.to(self.mem_pool_device.device), device_indices
531
+ if not host_indices.is_cuda:
532
+ host_indices = host_indices.to(self.device, non_blocking=True)
533
+ return host_indices, device_indices
496
534
  elif self.io_backend == "direct":
497
- device_indices = device_indices.cpu()
498
- host_indices, idx = host_indices.sort()
499
- return host_indices, device_indices.index_select(0, idx)
535
+ if self.mem_pool_host.layout == "layer_first":
536
+ device_indices = device_indices.cpu()
537
+ host_indices, idx = host_indices.sort()
538
+ return host_indices, device_indices.index_select(0, idx)
539
+ elif self.mem_pool_host.layout == "page_first_direct":
540
+ return host_indices, device_indices.cpu()
500
541
  else:
501
542
  raise ValueError(f"Unsupported io backend")
502
543
 
503
- def write_thread_func_direct(self):
504
- """
505
- Directly write through KV caches to host memory without buffering.
506
- """
507
- torch.cuda.set_stream(self.write_stream)
508
- while not self.stop_event.is_set():
509
- try:
510
- operation = self.write_queue.get(block=True, timeout=1)
511
- host_indices, device_indices = self.move_indices(
512
- operation.host_indices, operation.device_indices
513
- )
514
- self.mem_pool_host.backup_from_device_all_layer(
515
- self.mem_pool_device, host_indices, device_indices, self.io_backend
516
- )
517
- self.write_stream.synchronize()
518
- self.mem_pool_host.complete_io(operation.host_indices)
519
- for node_id in operation.node_ids:
520
- if node_id != 0:
521
- self.ack_write_queue.put(node_id)
522
- except Empty:
523
- continue
524
- except Exception as e:
525
- logger.error(e)
544
+ def start_loading(self) -> int:
545
+ if len(self.load_queue) == 0:
546
+ return -1
526
547
 
527
- def load_thread_func_layer_by_layer(self):
528
- """
529
- Load KV caches from host memory to device memory layer by layer.
530
- """
531
- torch.cuda.set_stream(self.load_stream)
532
- while not self.stop_event.is_set():
533
- self.load_cache_event.wait(timeout=1)
534
- if not self.load_cache_event.is_set():
535
- continue
536
- self.load_cache_event.clear()
537
- self.layer_done_counter.update_producer()
538
-
539
- batch_operation = None
540
- while self.load_queue.qsize() > 0:
541
- op = self.load_queue.get(block=True)
542
- if batch_operation is None:
543
- batch_operation = op
544
- else:
545
- batch_operation.merge(op)
546
- if batch_operation is None:
547
- continue
548
+ producer_id = self.layer_done_counter.update_producer()
549
+ op = CacheOperation.merge_ops(self.load_queue)
550
+ host_indices, device_indices = self.move_indices(op)
551
+ self.load_queue.clear()
552
+ producer_event = self.layer_done_counter.events[producer_id]
553
+ producer_event.start_event.record()
548
554
 
549
- # start layer-wise KV cache transfer from CPU to GPU
550
- self.layer_done_counter.reset()
551
- host_indices, device_indices = self.move_indices(
552
- batch_operation.host_indices, batch_operation.device_indices
553
- )
554
- for i in range(self.mem_pool_host.layer_num):
555
+ with torch.cuda.stream(self.load_stream):
556
+ producer_event.start_event.wait(self.load_stream)
557
+ for i in range(self.layer_num):
555
558
  self.mem_pool_host.load_to_device_per_layer(
556
559
  self.mem_pool_device,
557
560
  host_indices,
@@ -559,13 +562,24 @@ class HiCacheController:
559
562
  i,
560
563
  self.io_backend,
561
564
  )
562
- self.load_stream.synchronize()
563
- self.layer_done_counter.increment()
564
-
565
- self.mem_pool_host.complete_io(batch_operation.host_indices)
566
- for node_id in batch_operation.node_ids:
567
- if node_id != 0:
568
- self.ack_load_queue.put(node_id)
565
+ producer_event.complete(i)
566
+ self.mem_pool_host.complete_io(op.host_indices)
567
+ # NOTE: We must save the host indices and device indices here,
568
+ # this is because we need to guarantee that these tensors are
569
+ # still alive when the load stream is executing.
570
+ if host_indices.is_cuda:
571
+ host_indices.record_stream(self.load_stream)
572
+ if device_indices.is_cuda:
573
+ device_indices.record_stream(self.load_stream)
574
+
575
+ self.ack_load_queue.append(
576
+ HiCacheAck(
577
+ start_event=producer_event.start_event,
578
+ finish_event=producer_event.finish_event,
579
+ node_ids=op.node_ids,
580
+ )
581
+ )
582
+ return producer_id
569
583
 
570
584
  def evict_device(
571
585
  self, device_indices: torch.Tensor, host_indices: torch.Tensor
@@ -608,7 +622,7 @@ class HiCacheController:
608
622
  return operation
609
623
 
610
624
  def terminate_prefetch(self, operation):
611
- operation.mark_done()
625
+ operation.mark_terminate()
612
626
  return operation.completed_tokens, operation.hash_value
613
627
 
614
628
  def append_host_mem_release(self, host_indices: torch.Tensor):
@@ -616,13 +630,19 @@ class HiCacheController:
616
630
  for chunk in chunks:
617
631
  self.host_mem_release_queue.put(chunk)
618
632
 
633
+ def _3fs_zero_copy_batch_exists(self, batch_hashes):
634
+ _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
635
+ hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
636
+ return hit_page_num
637
+
619
638
  def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
620
- hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
639
+ hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
621
640
  hash_values, host_indices
622
641
  )
623
642
  page_data = self.storage_backend.batch_get(hashes, dsts)
624
643
  if page_data:
625
- operation.increment(self.page_size * len(hashes))
644
+ inc = self.page_size * len(hashes) // factor
645
+ operation.increment(inc)
626
646
  else:
627
647
  logger.warning(
628
648
  f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
@@ -636,7 +656,7 @@ class HiCacheController:
636
656
  )
637
657
  get_result = self.storage_backend.batch_get(
638
658
  key_strs,
639
- target_location=buffer_ptrs,
659
+ target_locations=buffer_ptrs,
640
660
  target_sizes=buffer_sizes,
641
661
  )
642
662
  if get_result != len(hash_values):
@@ -647,9 +667,9 @@ class HiCacheController:
647
667
  operation.increment(get_result * self.page_size)
648
668
 
649
669
  def _generic_page_get(self, operation, hash_values, host_indices):
650
- dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
651
- hash_values
652
- )
670
+ dummy_page_dst = [
671
+ self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
672
+ ]
653
673
  page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
654
674
  if page_data is None:
655
675
  return
@@ -659,26 +679,16 @@ class HiCacheController:
659
679
  f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
660
680
  )
661
681
  break
662
- if operation.increment(self.page_size):
663
- self.mem_pool_host.set_from_flat_data_page(
664
- host_indices[i * self.page_size],
665
- page_data[i],
666
- )
667
- else:
668
- break
682
+ # Must set the data before increasing the completed tokens.
683
+ # Otherwise this page may be read before being set.
684
+ self.mem_pool_host.set_from_flat_data_page(
685
+ host_indices[i * self.page_size],
686
+ page_data[i],
687
+ )
688
+ if not operation.increment(self.page_size):
689
+ break # Operation terminated by controller
669
690
 
670
691
  def _page_transfer(self, operation):
671
- # Select the get function and batch size
672
- if self.storage_backend_type == "mooncake":
673
- get_func = self._mooncake_page_get
674
- elif (
675
- self.storage_backend_type == "hf3fs"
676
- and self.mem_pool_host.layout == "page_first"
677
- ):
678
- get_func = self._3fs_zero_copy_page_get
679
- else:
680
- get_func = self._generic_page_get
681
-
682
692
  # Transfer batch by batch
683
693
  for i in range(0, len(operation.hash_value), self.storage_batch_size):
684
694
  batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -687,12 +697,13 @@ class HiCacheController:
687
697
  ]
688
698
  prev_completed_tokens = operation.completed_tokens
689
699
  # Get one batch token, and update the completed_tokens if succeed
690
- get_func(operation, batch_hashes, batch_host_indices)
700
+ self.page_get_func(operation, batch_hashes, batch_host_indices)
691
701
  # Check termination
692
702
  if (
693
703
  operation.completed_tokens
694
704
  != prev_completed_tokens + len(batch_hashes) * self.page_size
695
705
  ):
706
+ operation.mark_terminate()
696
707
  break # Some operations fail or operation terminated by controller
697
708
  # release pre-allocated memory
698
709
  self.append_host_mem_release(
@@ -744,7 +755,7 @@ class HiCacheController:
744
755
  batch_tokens[i : i + self.page_size], last_hash
745
756
  )
746
757
  batch_hashes.append(last_hash)
747
- hit_page_num = self.storage_backend.batch_exists(batch_hashes)
758
+ hit_page_num = self.batch_exists_func(batch_hashes)
748
759
  hash_value.extend(batch_hashes[:hit_page_num])
749
760
  storage_query_count += hit_page_num * self.page_size
750
761
  if hit_page_num < len(batch_hashes):
@@ -830,30 +841,20 @@ class HiCacheController:
830
841
  )
831
842
  success = self.storage_backend.batch_set(
832
843
  key_strs,
833
- target_location=buffer_ptrs,
844
+ target_locations=buffer_ptrs,
834
845
  target_sizes=buffer_sizes,
835
846
  )
836
847
  return success
837
848
 
838
849
  # zero copy
839
850
  def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
840
- hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
851
+ hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
841
852
  hash_values, host_indices
842
853
  )
843
854
  return self.storage_backend.batch_set(hashes, dsts)
844
855
 
845
856
  # Backup batch by batch
846
857
  def _page_backup(self, operation):
847
- # Select the set function and batch size
848
- if self.storage_backend_type == "mooncake":
849
- backup_set_func = self._mooncake_page_set
850
- elif (
851
- self.storage_backend_type == "hf3fs"
852
- and self.mem_pool_host.layout == "page_first"
853
- ):
854
- backup_set_func = self._3fs_zero_copy_page_set
855
- else:
856
- backup_set_func = self._generic_page_set
857
858
  # Backup batch by batch
858
859
  for i in range(0, len(operation.hash_value), self.storage_batch_size):
859
860
  batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -862,7 +863,7 @@ class HiCacheController:
862
863
  ]
863
864
  # Set one batch token, and record if success.
864
865
  # todo: allow partial success
865
- success = backup_set_func(batch_hashes, batch_host_indices)
866
+ success = self.page_set_func(batch_hashes, batch_host_indices)
866
867
  if not success:
867
868
  logger.warning(
868
869
  f"Write page to storage: {len(batch_hashes)} pages failed."
@@ -882,7 +883,7 @@ class HiCacheController:
882
883
 
883
884
  if not self.backup_skip:
884
885
  self._page_backup(operation)
885
- self.ack_backup_queue.put(operation.id)
886
+ self.ack_backup_queue.put(operation)
886
887
 
887
888
  except Empty:
888
889
  continue