sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
@@ -18,7 +18,7 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
 
 import torch
 
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
     get_attention_tp_rank,
     get_attention_tp_size,
     is_dp_attention_enabled,
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 logger = logging.getLogger(__name__)
 
 
+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.counters = [num_layers] * self.num_counters
-        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
-        self.producer_index = 0
-        self.consumer_index = 0
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1
 
     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[
+            self.producer_index
+        ].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index
 
-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index
 
-    def increment(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] += 1
-            self.conditions[self.producer_index].notify_all()
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)
 
     def reset(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] = 0
+        self.producer_index = -1
+        self.consumer_index = -1
 
 
 class CacheOperation:
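
Note: the hunk above replaces the lock-and-condition-variable LayerDoneCounter with per-layer CUDA events. A minimal standalone sketch of that pattern (illustrative names only, not code from this diff): a loading stream records one event per layer, and the compute stream waits only on the layer it is about to touch, so cross-stream ordering is enforced on-device with no host-side locking.

    import torch

    def demo_per_layer_events(num_layers: int = 4) -> None:
        if not torch.cuda.is_available():
            return  # the CUDA events below need a device
        load_stream = torch.cuda.Stream()
        events = [torch.cuda.Event() for _ in range(num_layers)]
        layers = [torch.zeros(1024, device="cuda") for _ in range(num_layers)]

        # Producer (load stream): record an event as each layer finishes,
        # analogous to LayerLoadingEvent.complete(layer_index).
        with torch.cuda.stream(load_stream):
            for i in range(num_layers):
                layers[i] += 1  # stand-in for load_to_device_per_layer(...)
                events[i].record()

        # Consumer (compute stream): wait only for the layer being used,
        # mirroring LayerDoneCounter.wait_until(layer_index).
        for i in range(num_layers):
            torch.cuda.current_stream().wait_event(events[i])
            layers[i] *= 2  # stand-in for attention over layer i
        torch.cuda.synchronize()
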
@@ -98,36 +113,30 @@ class CacheOperation:
     # default priority is the order of creation
     self.priority = priority if priority is not None else self.id
 
-    def merge(self, other: "CacheOperation") -> None:
-        # multiple operations can be merged into a single operation for batch processing
-        self.host_indices = torch.cat([self.host_indices, other.host_indices])
-        self.device_indices = torch.cat([self.device_indices, other.device_indices])
-        self.priority = min(self.priority, other.priority)
-        self.node_ids.extend(other.node_ids)
-
-    def split(self, factor) -> List["CacheOperation"]:
-        # split an operation into smaller operations to reduce the size of intermediate buffers
-        if factor <= 1:
-            return [self]
-
-        chunk_size = math.ceil(len(self.host_indices) / factor)
-        split_ops = []
-        for i in range(0, len(self.host_indices), chunk_size):
-            split_ops.append(
-                CacheOperation(
-                    host_indices=self.host_indices[i : i + chunk_size],
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
+        return self.priority < other.priority
 
-        return split_ops
 
-    def __lt__(self, other: "CacheOperation"):
-        return self.priority < other.priority
+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]
 
 
 class TransferBuffer:
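
Note: with HiCacheAck, completion is signaled by CUDA events instead of node ids pushed after a blocking synchronize. A sketch of how such acks might be drained, assuming the consumer polls finish_event non-blockingly (Ack and drain_finished_acks are hypothetical names, not part of this diff):

    import torch
    from typing import List, NamedTuple

    class Ack(NamedTuple):  # same shape as HiCacheAck
        start_event: torch.cuda.Event
        finish_event: torch.cuda.Event
        node_ids: List[int]

    def drain_finished_acks(acks: List[Ack]) -> List[int]:
        """Pop acks whose GPU-side transfer completed; return node ids to release."""
        done: List[int] = []
        while acks and acks[0].finish_event.query():  # non-blocking completion check
            done.extend(acks.pop(0).node_ids)
        return done

Because acks are appended in submission order and each stream executes in order, checking only the head of the list is enough.
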
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
     ):
         self.request_id = request_id
 
-        self._done_flag = False
         self._lock = threading.Lock()
-
+        self._terminated_flag = False
         self.start_time = time.monotonic()
 
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
         with self._lock:
-            if self._done_flag:
+            if self._terminated_flag:
                 return False
             self.completed_tokens += num_tokens
             return True
 
-    def mark_done(self):
+    def mark_terminate(self):
         with self._lock:
-            self._done_flag = True
+            self._terminated_flag = True
 
-    def is_done(self) -> bool:
-        return self._done_flag
+    def is_terminated(self) -> bool:
+        return self._terminated_flag
 
 
 class HiCacheController:
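
Note: the done-to-terminate rename makes the handshake explicit: once a prefetch is terminated, increment refuses further progress, so the controller and the prefetch worker agree on the final completed_tokens. The pattern in isolation (a sketch, not the diff's class):

    import threading

    class TerminatableProgress:  # illustrative stand-in for PrefetchOperation
        def __init__(self) -> None:
            self._lock = threading.Lock()
            self._terminated = False
            self.completed_tokens = 0

        def increment(self, num_tokens: int) -> bool:
            with self._lock:
                if self._terminated:
                    return False  # controller already cut this operation off
                self.completed_tokens += num_tokens
                return True

        def mark_terminate(self) -> None:
            with self._lock:
                self._terminated = True
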
@@ -236,7 +244,7 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event = None,
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,
@@ -340,8 +348,9 @@
             self.page_set_func = self._3fs_zero_copy_page_set
             self.batch_exists_func = self._3fs_zero_copy_batch_exists
 
-        self.load_cache_event = load_cache_event
-        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
 
         if write_policy not in [
@@ -351,11 +360,11 @@
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")
 
-        self.write_queue = PriorityQueue()
-        self.load_queue = PriorityQueue()
-
-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []
 
         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
@@ -366,16 +375,6 @@
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -402,9 +401,11 @@
         if is_dp_attention_enabled():
             self.tp_rank = get_attention_tp_rank()
             self.tp_size = get_attention_tp_size()
+            self.dp_rank = get_attention_dp_rank()
         else:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
+            self.dp_rank = 0
 
         # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
         is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
@@ -430,15 +431,13 @@
 
     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()
 
-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
@@ -447,15 +446,7 @@
             self.prefetch_revoke_queue.queue.clear()
             self.ack_backup_queue.queue.clear()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()
 
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -471,7 +462,7 @@
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -480,17 +471,46 @@
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
-        torch.cuda.current_stream().synchronize()
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices
 
+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the write stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.write_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
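
Note on the record_stream calls introduced in start_writing above: PyTorch's caching allocator tracks a tensor's lifetime only on the stream it was allocated on, so a tensor consumed by a side stream can be freed and its memory reused while an async copy is still in flight. Tensor.record_stream marks the cross-stream use. A self-contained sketch of the hazard and the fix (buffer names are illustrative, not from this diff):

    import torch

    def backup_on_side_stream(device_buf: torch.Tensor) -> torch.Tensor:
        """Copy a CUDA tensor to pinned host memory on a dedicated stream."""
        assert device_buf.is_cuda
        side = torch.cuda.Stream()
        start = torch.cuda.Event()
        start.record()  # order the side stream after work already queued
        host_buf = torch.empty(
            device_buf.shape, dtype=device_buf.dtype, pin_memory=True
        )
        with torch.cuda.stream(side):
            start.wait(side)
            host_buf.copy_(device_buf, non_blocking=True)
            # Without this, the caller dropping its last reference to
            # device_buf could let the allocator reuse the memory mid-copy.
            device_buf.record_stream(side)
        return host_buf
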
@@ -499,76 +519,42 @@ class HiCacheController:
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
-        # to ensure the device indices are ready before accessed by another CUDA stream
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices
 
-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            return host_indices.to(self.mem_pool_device.device), device_indices
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
-            device_indices = device_indices.cpu()
-            host_indices, idx = host_indices.sort()
-            return host_indices, device_indices.index_select(0, idx)
+            if self.mem_pool_host.layout == "layer_first":
+                device_indices = device_indices.cpu()
+                host_indices, idx = host_indices.sort()
+                return host_indices, device_indices.index_select(0, idx)
+            elif self.mem_pool_host.layout == "page_first_direct":
+                return host_indices, device_indices.cpu()
         else:
             raise ValueError(f"Unsupported io backend")
 
-    def write_thread_func_direct(self):
-        """
-        Directly write through KV caches to host memory without buffering.
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1
 
-    def load_thread_func_layer_by_layer(self):
-        """
-        Load KV caches from host memory to device memory layer by layer.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()
 
-            # start layer-wise KV cache transfer from CPU to GPU
-            self.layer_done_counter.reset()
-            host_indices, device_indices = self.move_indices(
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -576,13 +562,24 @@
                     i,
                     self.io_backend,
                 )
-            self.load_stream.synchronize()
-            self.layer_done_counter.increment()
-
-            self.mem_pool_host.complete_io(batch_operation.host_indices)
-            for node_id in batch_operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the load stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.load_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.load_stream)
+
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id
 
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
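
Note: start_loading now returns the producer index instead of a background thread being woken through load_cache_event. How the pieces appear to be meant to compose on the consumer side (an assumption based on this diff alone; run_attention_for is a hypothetical compute hook, and the real wiring lives in the memory pool and model runner):

    from typing import Callable

    def overlapped_forward(controller, run_attention_for: Callable[[int], None]) -> None:
        producer_id = controller.start_loading()  # -1 when nothing is queued
        if producer_id >= 0:
            controller.layer_done_counter.set_consumer(producer_id)
        for layer_id in range(controller.layer_num):
            # Waits on the per-layer CUDA event when a load is in flight;
            # with consumer_index still -1, wait_until returns immediately.
            controller.layer_done_counter.wait_until(layer_id)
            run_attention_for(layer_id)
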
@@ -625,7 +622,7 @@
         return operation
 
     def terminate_prefetch(self, operation):
-        operation.mark_done()
+        operation.mark_terminate()
         return operation.completed_tokens, operation.hash_value
 
     def append_host_mem_release(self, host_indices: torch.Tensor):
@@ -706,6 +703,7 @@
                     operation.completed_tokens
                     != prev_completed_tokens + len(batch_hashes) * self.page_size
                 ):
+                    operation.mark_terminate()
                     break  # Some operations fail or operation terminated by controller
                 # release pre-allocated memory
                 self.append_host_mem_release(
@@ -885,7 +883,7 @@
 
                 if not self.backup_skip:
                     self._page_backup(operation)
-                self.ack_backup_queue.put(operation.id)
+                self.ack_backup_queue.put(operation)
 
             except Empty:
                 continue