sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py
@@ -31,27 +31,13 @@ if not (_is_npu or _is_xpu):
 logger = logging.getLogger(__name__)
 
 
-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
-
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
 
-    return _decorator
+    return wrapper
 
 
 class HostKVCache(abc.ABC):
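The `synchronized` helper is demoted from a decorator factory (`@synchronized()`, with an optional `debug_only` mode gated on `self.debug`) to a plain decorator that always takes the instance lock. A minimal standalone sketch of the new pattern, assuming only that the instance owns an RLock named `lock` (the `Pool` class here is a hypothetical stand-in for `HostKVCache`, which carries far more state):

```python
import threading
from functools import wraps


def synchronized(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:  # serialize against every other @synchronized method
            return func(self, *args, **kwargs)

    return wrapper


class Pool:
    def __init__(self):
        self.lock = threading.RLock()
        self.free_slots = list(range(8))

    @synchronized  # note: no call parentheses anymore
    def alloc(self, n):
        taken, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return taken
```

Call sites change accordingly from `@synchronized()` to `@synchronized`, as the hunks below show.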
@@ -110,7 +96,6 @@ class HostKVCache(abc.ABC):
 
         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
         self.clear()
 
     @abc.abstractmethod
@@ -140,7 +125,7 @@ class HostKVCache(abc.ABC):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         """
         Get a flat data page from the host memory pool.
         """
@@ -161,7 +146,7 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
-    @synchronized()
+    @synchronized
     def clear(self):
         # Initialize memory states and tracking structures.
         self.mem_state = torch.zeros(
@@ -172,7 +157,7 @@ class HostKVCache(abc.ABC):
     def available_size(self):
         return len(self.free_slots)
 
-    @synchronized()
+    @synchronized
     def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
@@ -183,92 +168,13 @@ class HostKVCache(abc.ABC):
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
         return select_index
 
-    @synchronized()
+    @synchronized
     def free(self, indices: torch.Tensor) -> int:
         self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
         return len(indices)
 
-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_prefetch(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
 
 class MHATokenToKVPoolHost(HostKVCache):
     device_pool: MHATokenToKVPool
@@ -461,13 +367,19 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page
 
     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
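`get_flat_data_page` becomes `get_data_page` with a `flat` flag, and learns the new `page_first_direct` layout, which keeps whole pages on a dedicated page axis, so a token index must first be divided by `page_size`. A toy consistency check of the three indexing schemes (the buffer shapes here are assumptions inferred from the indexing above, not the exact sglang constructors):

```python
import torch

# Assumed toy dimensions; sglang's real pools are built elsewhere.
page_size, layer_num, size, head_num, head_dim = 4, 2, 16, 3, 8

# Inferred buffer shapes for the three layouts (K/V stacked on dim 0):
layer_first = torch.randn(2, layer_num, size, head_num, head_dim)
page_first = torch.randn(2, size, layer_num, head_num, head_dim)
page_first_direct = torch.randn(
    2, size // page_size, layer_num, page_size, head_num, head_dim
)

index = 8  # token index at a page boundary
a = layer_first[:, :, index : index + page_size, :, :].flatten()
b = page_first[:, index : index + page_size, :, :, :].flatten()
c = page_first_direct[:, index // page_size : index // page_size + 1].flatten()

# All three slices cover one page: 2 * layer_num * page_size * head_num * head_dim values.
assert a.numel() == b.numel() == c.numel() == 2 * layer_num * page_size * head_num * head_dim

# Round trip mirroring the page_first_direct branch of set_from_flat_data_page below:
page_first_direct[:, index // page_size : index // page_size + 1] = c.reshape(
    2, 1, layer_num, page_size, head_num, head_dim
)
```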
@@ -494,12 +406,22 @@ class MHATokenToKVPoolHost(HostKVCache):
                     2, self.page_size, self.layer_num, self.head_num, self.head_dim
                 )
             )
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = (
+                data_page.reshape(
+                    2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim
+                )
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """
+        Metadata for zero-copy transfer.
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
         v_offset = (
@@ -509,48 +431,52 @@ class MHATokenToKVPoolHost(HostKVCache):
             * self.head_dim
             * self.dtype.itemsize
         )
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
-                * self.head_num
-                * self.head_dim
-                * self.dtype.itemsize
-            )
-            v_ptr = k_ptr + v_offset
-            ptr_list.append(k_ptr)
-            ptr_list.append(v_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{local_rank}_k")
-            key_list.append(f"{key_}_{local_rank}_v")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * self.head_num
-            * self.head_dim
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        key_list = []
-        buf_list = []
-
-        for i in range(len(keys)):
-            key = keys[i]
-            key_list.append(f"{key}-k")
-            key_list.append(f"{key}-v")
-            if indices is not None:
-                index = indices[i * self.page_size]
-                buf_list.append(self.k_buffer[index : index + self.page_size])
-                buf_list.append(self.v_buffer[index : index + self.page_size])
-
-        return key_list, buf_list, 2
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                    )
+                    v_ptr = k_ptr + v_offset
+                    ptr_list.append(k_ptr)
+                    ptr_list.append(v_ptr)
+            element_size = (
+                self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+            element_size = (
+                self.layer_num
+                * self.dtype.itemsize
+                * self.page_size
+                * self.head_num
+                * self.head_dim
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list
 
 
 class MLATokenToKVPoolHost(HostKVCache):
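`get_buffer_meta` (which returned backend-specific cache keys alongside pointers) is replaced by the layout-aware `get_page_buffer_meta`, which returns only raw `(pointer, byte count)` pairs for zero-copy transfer. In the `page_first`-style layouts, one page's K data for all layers is contiguous, so each page needs a single K pointer plus a single V pointer; in `layer_first`, the page is split across layers, hence the inner `layer_id` loop with a `self.size`-sized layer stride. A worked sketch of that pointer arithmetic with toy numbers (all values are assumptions, and `v_offset` is taken to span the entire K half of the buffer, as the truncated context above suggests):

```python
# Assumed toy configuration: fp16 (itemsize=2), 2 layers, 3 heads, head_dim=8,
# page_size=4, pool of 16 tokens.
itemsize, layer_num, head_num, head_dim = 2, 2, 3, 8
page_size, size = 4, 16
base, token = 0x1000, 8  # pretend kv_buffer.data_ptr() and a page-start token index

# page_first / page_first_direct: one contiguous K region per page covers all layers.
k_ptr = base + token * layer_num * head_num * head_dim * itemsize
v_offset = size * layer_num * head_num * head_dim * itemsize  # skip the whole K half
v_ptr = k_ptr + v_offset
page_bytes = layer_num * itemsize * page_size * head_num * head_dim

# layer_first: the same page is split into one region per layer.
layer_stride = size * head_num * head_dim * itemsize
k_ptrs = [
    base + token * head_num * head_dim * itemsize + layer_id * layer_stride
    for layer_id in range(layer_num)
]
per_layer_bytes = itemsize * page_size * head_num * head_dim
```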
@@ -726,13 +652,19 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page
 
     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -762,43 +694,63 @@ class MLATokenToKVPoolHost(HostKVCache):
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
             )
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            self.kv_buffer[real_index : real_index + 1, :, :, :, :] = data_page.reshape(
+                1,
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """
+        Metadata for zero-copy transfer.
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
-                * (self.kv_lora_rank + self.qk_rope_head_dim)
-                * self.dtype.itemsize
-            )
-            ptr_list.append(k_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_k")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * (self.kv_lora_rank + self.qk_rope_head_dim)
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        buf_list = []
-
-        if indices is not None:
-            for i in range(len(keys)):
-                index = indices[i * self.page_size]
-                buf_list.append(self.kv_buffer[index : index + self.page_size])
-
-        return keys, buf_list, 1
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                    )
+                    ptr_list.append(k_ptr)
+            element_size = (
+                self.dtype.itemsize
+                * self.page_size
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+            element_size = (
+                self.layer_num
+                * self.dtype.itemsize
+                * self.page_size
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list