sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,13 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from sglang.srt.configs.mamba_utils import Mamba2CacheParams
+ from sglang.srt.layers.attention.nsa import index_buf_accessor
+ from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

  """
@@ -27,7 +34,7 @@ KVCache actually holds the physical kv cache.
  import abc
  import logging
  from contextlib import nullcontext
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

  import numpy as np
  import torch
@@ -38,6 +45,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

+ if TYPE_CHECKING:
+     from sglang.srt.managers.cache_controller import LayerDoneCounter
+
  logger = logging.getLogger(__name__)

  GB = 1024 * 1024 * 1024
@@ -47,6 +57,10 @@ if _is_npu:
      import torch_npu


+ def get_tensor_size_bytes(t: torch.Tensor):
+     return np.prod(t.shape) * t.dtype.itemsize
+
+
  class ReqToTokenPool:
      """A memory pool that maps a request to its token locations."""

@@ -97,6 +111,225 @@ class ReqToTokenPool:
          self.free_slots = list(range(self.size))


+ class MambaPool:
+     @dataclass(frozen=True, kw_only=True)
+     class State:
+         conv: torch.Tensor
+         temporal: torch.Tensor
+
+         def at_layer_idx(self, layer: int):
+             return type(self)(**{k: v[layer] for k, v in vars(self).items()})
+
+         def mem_usage_bytes(self):
+             return sum(get_tensor_size_bytes(t) for t in vars(self).values())
+
+     @dataclass(frozen=True, kw_only=True)
+     class SpeculativeState(State):
+         intermediate_ssm: torch.Tensor
+         intermediate_conv_window: torch.Tensor
+
+     def __init__(
+         self,
+         *,
+         size: int,
+         cache_params: "Mamba2CacheParams",
+         device: str,
+         speculative_num_draft_tokens: Optional[int] = None,
+     ):
+         conv_state_shape = cache_params.shape.conv
+         temporal_state_shape = cache_params.shape.temporal
+         conv_dtype = cache_params.dtype.conv
+         ssm_dtype = cache_params.dtype.temporal
+         num_mamba_layers = len(cache_params.layers)
+
+         # assume conv_state = (dim, state_len)
+         assert conv_state_shape[0] > conv_state_shape[1]
+         conv_state = torch.zeros(
+             size=(num_mamba_layers, size + 1) + conv_state_shape,
+             dtype=conv_dtype,
+             device=device,
+         )
+         temporal_state = torch.zeros(
+             size=(num_mamba_layers, size + 1) + temporal_state_shape,
+             dtype=ssm_dtype,
+             device=device,
+         )
+         if speculative_num_draft_tokens is not None:
+             # Cache intermediate SSM states per draft token during target verify
+             # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
+             intermediate_ssm_state_cache = torch.zeros(
+                 size=(
+                     num_mamba_layers,
+                     size + 1,
+                     speculative_num_draft_tokens,
+                     temporal_state_shape[0],
+                     temporal_state_shape[1],
+                     temporal_state_shape[2],
+                 ),
+                 dtype=ssm_dtype,
+                 device="cuda",
+             )
+             # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+             # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+             intermediate_conv_window_cache = torch.zeros(
+                 size=(
+                     num_mamba_layers,
+                     size + 1,
+                     speculative_num_draft_tokens,
+                     conv_state_shape[0],
+                     conv_state_shape[1],
+                 ),
+                 dtype=conv_dtype,
+                 device="cuda",
+             )
+             self.mamba_cache = self.SpeculativeState(
+                 conv=conv_state,
+                 temporal=temporal_state,
+                 intermediate_ssm=intermediate_ssm_state_cache,
+                 intermediate_conv_window=intermediate_conv_window_cache,
+             )
+             logger.info(
+                 f"Mamba Cache is allocated. "
+                 f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                 f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+                 f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
+                 f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
+             )
+         else:
+             self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
+             logger.info(
+                 f"Mamba Cache is allocated. "
+                 f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
+                 f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
+             )
+         self.size = size
+         self.free_slots = list(range(size))
+         self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
+
+     def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
+         assert isinstance(self.mamba_cache, self.SpeculativeState)
+         return self.mamba_cache
+
+     def mamba2_layer_cache(self, layer_id: int):
+         return self.mamba_cache.at_layer_idx(layer_id)
+
+     def available_size(self):
+         return len(self.free_slots)
+
+     def alloc(self, need_size: int) -> Optional[List[int]]:
+         if need_size > len(self.free_slots):
+             return None
+
+         select_index = self.free_slots[:need_size]
+         self.free_slots = self.free_slots[need_size:]
+
+         return select_index
+
+     def free(self, free_index: Union[int, List[int]]):
+         if isinstance(free_index, (int,)):
+             self.free_slots.append(free_index)
+         else:
+             self.free_slots.extend(free_index)
+         self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
+             :, free_index
+         ] = 0
+
+     def clear(self):
+         self.free_slots = list(range(self.size))
+
+
+ class HybridReqToTokenPool(ReqToTokenPool):
+     """A memory pool that maps a request to its token locations."""
+
+     def __init__(
+         self,
+         *,
+         size: int,
+         max_context_len: int,
+         device: str,
+         enable_memory_saver: bool,
+         cache_params: "Mamba2CacheParams",
+         speculative_num_draft_tokens: int = None,
+     ):
+         super().__init__(
+             size=size,
+             max_context_len=max_context_len,
+             device=device,
+             enable_memory_saver=enable_memory_saver,
+         )
+
+         self.mamba_pool = MambaPool(
+             size=size,
+             cache_params=cache_params,
+             device=device,
+             speculative_num_draft_tokens=speculative_num_draft_tokens,
+         )
+         self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
+
+         self.device = device
+         self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
+             size, dtype=torch.int32, device=self.device
+         )
+
+         self.rid_to_mamba_index_mapping: Dict[str, int] = {}
+         self.mamba_index_to_rid_mapping: Dict[int, str] = {}
+
+     # For chunk prefill req, we do not need to allocate mamba cache,
+     # We could use allocated mamba cache instead.
+     def alloc(
+         self, need_size: int, reqs: Optional[List["Req"]] = None
+     ) -> Optional[List[int]]:
+         select_index = super().alloc(need_size)
+         if select_index == None:
+             return None
+
+         mamba_index = []
+         for req in reqs:
+             rid = req.rid
+             if rid in self.rid_to_mamba_index_mapping:
+                 mid = self.rid_to_mamba_index_mapping[rid]
+             elif (mid := self.mamba_pool.alloc(1)) is not None:
+                 mid = mid[0]
+                 self.rid_to_mamba_index_mapping[rid] = mid
+                 self.mamba_index_to_rid_mapping[mid] = rid
+             mamba_index.append(mid)
+         assert len(select_index) == len(
+             mamba_index
+         ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
+         self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
+             mamba_index, dtype=torch.int32, device=self.device
+         )
+         return select_index
+
+     def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
+         return self.req_index_to_mamba_index_mapping[req_indices]
+
+     def mamba2_layer_cache(self, layer_id: int):
+         assert layer_id in self.mamba_map
+         return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
+
+     def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
+         return self.mamba_pool.get_speculative_mamba2_params_all_layers()
+
+     # For chunk prefill, we can not free mamba cache, we need use it in the future
+     def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
+         super().free(free_index)
+         if free_mamba_cache:
+             mamba_index = self.req_index_to_mamba_index_mapping[free_index]
+             mamba_index_list = mamba_index.tolist()
+             if isinstance(mamba_index_list, int):
+                 mamba_index_list = [mamba_index_list]
+             self.mamba_pool.free(mamba_index_list)
+             for mid in mamba_index_list:
+                 rid = self.mamba_index_to_rid_mapping[mid]
+                 self.mamba_index_to_rid_mapping.pop(mid)
+                 self.rid_to_mamba_index_mapping.pop(rid)
+
+     def clear(self):
+         super().clear()
+         self.mamba_pool.clear()
+
+
  class KVCache(abc.ABC):
      @abc.abstractmethod
      def __init__(
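
Note (illustrative, not part of the diff): the new MambaPool above hands out per-request Mamba state slots from a plain free list and zeroes the conv/temporal state on free, while HybridReqToTokenPool reuses an already-assigned slot when the same request id returns (chunked prefill). A standalone sketch of the free-list semantics, using a hypothetical toy class:

    # Toy pool mirroring MambaPool.alloc/free semantics (hypothetical, for illustration only).
    class ToyMambaSlotPool:
        def __init__(self, size: int):
            self.size = size
            self.free_slots = list(range(size))

        def alloc(self, need_size: int):
            if need_size > len(self.free_slots):
                return None  # HybridReqToTokenPool.alloc then fails its length assert
            out, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
            return out

        def free(self, free_index):
            self.free_slots.extend(free_index if isinstance(free_index, list) else [free_index])

    pool = ToyMambaSlotPool(4)
    assert pool.alloc(3) == [0, 1, 2]
    assert pool.alloc(2) is None       # only one slot left, request cannot be served
    pool.free([0, 1])
    assert pool.alloc(2) == [3, 0]     # freed slots return to the tail of the free list
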
@@ -130,6 +363,29 @@ class KVCache(abc.ABC):
          # used for chunked cpu-offloading
          self.cpu_offloading_chunk_size = 8192

+         # default state for optional layer-wise transfer control
+         self.layer_transfer_counter = None
+
+     def _finalize_allocation_log(self, num_tokens: int):
+         """Common logging and mem_usage computation for KV cache allocation.
+         Supports both tuple (K, V) size returns and single KV size returns.
+         """
+         kv_size_bytes = self.get_kv_size_bytes()
+         if isinstance(kv_size_bytes, tuple):
+             k_size, v_size = kv_size_bytes
+             k_size_GB = k_size / GB
+             v_size_GB = v_size / GB
+             logger.info(
+                 f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+             )
+             self.mem_usage = k_size_GB + v_size_GB
+         else:
+             kv_size_GB = kv_size_bytes / GB
+             logger.info(
+                 f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+             )
+             self.mem_usage = kv_size_GB
+
      @abc.abstractmethod
      def get_key_buffer(self, layer_id: int) -> torch.Tensor:
          raise NotImplementedError()
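
Note (illustrative, not part of the diff): the new `_finalize_allocation_log` accepts either return convention of `get_kv_size_bytes()`. A minimal sketch of the resulting `mem_usage`, with made-up byte counts:

    GB = 1024 * 1024 * 1024
    # MHA-style pools return a (k_bytes, v_bytes) tuple ...
    k_bytes, v_bytes = 8 * GB, 8 * GB
    mem_usage_mha = k_bytes / GB + v_bytes / GB   # 16.0, logged as separate K and V sizes
    # ... while MLA-style pools return a single combined number.
    kv_bytes = 12 * GB
    mem_usage_mla = kv_bytes / GB                 # 12.0, logged as one KV size
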
@@ -152,7 +408,7 @@ class KVCache(abc.ABC):
      ) -> None:
          raise NotImplementedError()

-     def register_layer_transfer_counter(self, layer_transfer_counter):
+     def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
          self.layer_transfer_counter = layer_transfer_counter

      def get_cpu_copy(self, indices):
@@ -176,6 +432,7 @@ class MHATokenToKVPool(KVCache):
          enable_memory_saver: bool,
          start_layer: Optional[int] = None,
          end_layer: Optional[int] = None,
+         enable_kv_cache_copy: bool = False,
      ):
          super().__init__(
              size,
@@ -205,15 +462,58 @@ class MHATokenToKVPool(KVCache):

          self._create_buffers()

-         self.layer_transfer_counter = None
          self.device_module = torch.get_device_module(self.device)
          self.alt_stream = self.device_module.Stream() if _is_cuda else None

-         k_size, v_size = self.get_kv_size_bytes()
-         logger.info(
-             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
+         if enable_kv_cache_copy:
+             self._init_kv_copy_and_warmup()
+         else:
+             self._kv_copy_config = None
+
+         self._finalize_allocation_log(size)
+
+     def _init_kv_copy_and_warmup(self):
+         # Heuristics for KV copy tiling
+         _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+         _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+         _KV_COPY_TILE_SIZE_LARGE = 512
+         _KV_COPY_TILE_SIZE_MEDIUM = 256
+         _KV_COPY_TILE_SIZE_SMALL = 128
+         _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+         _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+         stride_bytes = int(self.data_strides[0].item())
+         if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+             bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+         elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+             bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+         else:
+             bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+         self._kv_copy_config = {
+             "bytes_per_tile": bytes_per_tile,
+             "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+             "num_warps": (
+                 _KV_COPY_NUM_WARPS_SMALL_TILE
+                 if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                 else _KV_COPY_NUM_WARPS_LARGE_TILE
+             ),
+         }
+
+         dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+         grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+
+         copy_all_layer_kv_cache_tiled[grid](
+             self.data_ptrs,
+             self.data_strides,
+             dummy_loc,
+             dummy_loc,
+             1,
+             1,
+             BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+             num_warps=self._kv_copy_config["num_warps"],
+             num_stages=2,
          )
-         self.mem_usage = (k_size + v_size) / GB

      def _create_buffers(self):
          with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
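
Note (illustrative, not part of the diff): the warm-up above derives the copy-kernel launch shape from the per-layer row stride. A worked example with an assumed stride (the 36-head figure is hypothetical):

    # Suppose one K buffer row is 36 heads * 128 head_dim * 2 bytes (bf16) = 9216 bytes.
    stride_bytes = 9216
    bytes_per_tile = 512 if stride_bytes >= 8192 else 256 if stride_bytes >= 4096 else 128
    byte_tiles = (stride_bytes + bytes_per_tile - 1) // bytes_per_tile   # ceil(9216 / 512) = 18
    num_warps = 4 if bytes_per_tile <= 256 else 8                        # 8 warps for 512-byte tiles
    # The launch grid is (number of K/V buffers, byte_tiles); each program copies one tile per row.
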
@@ -269,10 +569,10 @@ class MHATokenToKVPool(KVCache):
          assert hasattr(self, "v_buffer")
          k_size_bytes = 0
          for k_cache in self.k_buffer:
-             k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+             k_size_bytes += get_tensor_size_bytes(k_cache)
          v_size_bytes = 0
          for v_cache in self.v_buffer:
-             v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+             v_size_bytes += get_tensor_size_bytes(v_cache)
          return k_size_bytes, v_size_bytes

      # for disagg
@@ -352,7 +652,6 @@ class MHATokenToKVPool(KVCache):
          # same applies to get_value_buffer and get_kv_buffer
          if self.layer_transfer_counter is not None:
              self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-
          return self._get_key_buffer(layer_id)

      def _get_value_buffer(self, layer_id: int):
@@ -410,60 +709,156 @@ class MHATokenToKVPool(KVCache):
          self.v_buffer[layer_id - self.start_layer][loc] = cache_v

      def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-         copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+         N = tgt_loc.numel()
+         if N == 0:
+             return
+
+         assert (
+             self._kv_copy_config is not None
+         ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+         cfg = self._kv_copy_config
+         N_upper = next_power_of_2(N)
+         grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+         copy_all_layer_kv_cache_tiled[grid](
              self.data_ptrs,
              self.data_strides,
              tgt_loc,
              src_loc,
-             len(tgt_loc),
-             next_power_of_2(len(tgt_loc)),
+             N,
+             N_upper,
+             BYTES_PER_TILE=cfg["bytes_per_tile"],
+             num_warps=cfg["num_warps"],
+             num_stages=2,
          )


- class SWAKVPool(KVCache):
-     """KV cache with separate pools for full and SWA attention layers."""
+ class HybridLinearKVPool(KVCache):
+     """KV cache with separate pools for full and linear attention layers."""

      def __init__(
          self,
          size: int,
-         size_swa: int,
          dtype: torch.dtype,
+         page_size: int,
          head_num: int,
          head_dim: int,
-         swa_attention_layer_ids: List[int],
          full_attention_layer_ids: List[int],
          enable_kvcache_transpose: bool,
          device: str,
      ):
          self.size = size
-         self.size_swa = size_swa
          self.dtype = dtype
          self.device = device
-         self.swa_layer_nums = len(swa_attention_layer_ids)
          self.full_layer_nums = len(full_attention_layer_ids)
-         self.page_size = 1
+         self.page_size = page_size
          # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
          assert not enable_kvcache_transpose
-         TokenToKVPoolClass = MHATokenToKVPool
-         self.swa_kv_pool = TokenToKVPoolClass(
-             size=size_swa,
+         if _is_npu:
+             TokenToKVPoolClass = AscendTokenToKVPool
+         else:
+             TokenToKVPoolClass = MHATokenToKVPool
+         self.full_kv_pool = TokenToKVPoolClass(
+             size=size,
              page_size=self.page_size,
              dtype=dtype,
              head_num=head_num,
              head_dim=head_dim,
-             layer_num=self.swa_layer_nums,
+             layer_num=self.full_layer_nums,
              device=device,
              enable_memory_saver=False,
          )
-         self.full_kv_pool = TokenToKVPoolClass(
+         self.full_attention_layer_id_mapping = {
+             id: i for i, id in enumerate(full_attention_layer_ids)
+         }
+         k_size, v_size = self.get_kv_size_bytes()
+         self.mem_usage = (k_size + v_size) / GB
+
+     def get_kv_size_bytes(self):
+         return self.full_kv_pool.get_kv_size_bytes()
+
+     def get_contiguous_buf_infos(self):
+         return self.full_kv_pool.get_contiguous_buf_infos()
+
+     def _transfer_full_attention_id(self, layer_id: int):
+         if layer_id not in self.full_attention_layer_id_mapping:
+             raise ValueError(
+                 f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
+             )
+         return self.full_attention_layer_id_mapping[layer_id]
+
+     def get_key_buffer(self, layer_id: int):
+         layer_id = self._transfer_full_attention_id(layer_id)
+         return self.full_kv_pool.get_key_buffer(layer_id)
+
+     def get_value_buffer(self, layer_id: int):
+         layer_id = self._transfer_full_attention_id(layer_id)
+         return self.full_kv_pool.get_value_buffer(layer_id)
+
+     def get_kv_buffer(self, layer_id: int):
+         layer_id = self._transfer_full_attention_id(layer_id)
+         return self.full_kv_pool.get_kv_buffer(layer_id)
+
+     def set_kv_buffer(
+         self,
+         layer: RadixAttention,
+         loc: torch.Tensor,
+         cache_k: torch.Tensor,
+         cache_v: torch.Tensor,
+         k_scale: float = 1.0,
+         v_scale: float = 1.0,
+     ):
+         layer_id = self._transfer_full_attention_id(layer.layer_id)
+         self.full_kv_pool.set_kv_buffer(
+             None,
+             loc,
+             cache_k,
+             cache_v,
+             k_scale,
+             v_scale,
+             layer_id_override=layer_id,
+         )
+
+     def get_v_head_dim(self):
+         return self.full_kv_pool.get_value_buffer(0).shape[-1]
+
+
+ class SWAKVPool(KVCache):
+     """KV cache with separate pools for full and SWA attention layers."""
+
+     def __init__(
+         self,
+         size: int,
+         size_swa: int,
+         dtype: torch.dtype,
+         swa_attention_layer_ids: List[int],
+         full_attention_layer_ids: List[int],
+         enable_kvcache_transpose: bool,
+         token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+         **kwargs,
+     ):
+         self.size = size
+         self.size_swa = size_swa
+         self.dtype = dtype
+         self.swa_layer_nums = len(swa_attention_layer_ids)
+         self.full_layer_nums = len(full_attention_layer_ids)
+         kwargs["page_size"] = 1
+         kwargs["enable_memory_saver"] = False
+         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+         assert not enable_kvcache_transpose
+
+         self.swa_kv_pool = token_to_kv_pool_class(
+             size=size_swa,
+             dtype=dtype,
+             layer_num=self.swa_layer_nums,
+             **kwargs,
+         )
+         self.full_kv_pool = token_to_kv_pool_class(
              size=size,
-             page_size=self.page_size,
              dtype=dtype,
-             head_num=head_num,
-             head_dim=head_dim,
              layer_num=self.full_layer_nums,
-             device=device,
-             enable_memory_saver=False,
+             **kwargs,
          )
          self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
          for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
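
Note (illustrative, not part of the diff): HybridLinearKVPool only allocates token KV for the full-attention layers and remaps global layer ids into that smaller pool. A sketch with made-up layer ids:

    # Hypothetical hybrid model: only layers 3, 7, 11 use full attention; the rest are linear/Mamba.
    full_attention_layer_ids = [3, 7, 11]
    mapping = {layer_id: i for i, layer_id in enumerate(full_attention_layer_ids)}
    assert mapping[7] == 1     # global layer 7 reads/writes slot 1 of full_kv_pool
    assert 4 not in mapping    # _transfer_full_attention_id raises ValueError for layer 4
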
@@ -613,8 +1008,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
          cache_v: torch.Tensor,
          k_scale: Optional[float] = None,
          v_scale: Optional[float] = None,
+         layer_id_override: Optional[int] = None,
      ):
-         layer_id = layer.layer_id
+         if layer_id_override is not None:
+             layer_id = layer_id_override
+         else:
+             layer_id = layer.layer_id
          if cache_k.dtype != self.dtype:
              if k_scale is not None:
                  cache_k.div_(k_scale)
@@ -719,6 +1118,8 @@ class MLATokenToKVPool(KVCache):
          enable_memory_saver: bool,
          start_layer: Optional[int] = None,
          end_layer: Optional[int] = None,
+         use_nsa: bool = False,
+         override_kv_cache_dim: Optional[int] = None,
      ):
          super().__init__(
              size,
@@ -733,6 +1134,14 @@ class MLATokenToKVPool(KVCache):

          self.kv_lora_rank = kv_lora_rank
          self.qk_rope_head_dim = qk_rope_head_dim
+         self.use_nsa = use_nsa
+         self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
+         # TODO do not hardcode
+         self.kv_cache_dim = (
+             656
+             if self.use_nsa and self.nsa_kv_cache_store_fp8
+             else (kv_lora_rank + qk_rope_head_dim)
+         )

          # for disagg with nvlink
          self.enable_custom_mem_pool = get_bool_env_var(
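
Note (an assumption, not stated in the diff): the 656-byte token entry for the NSA fp8 path is hardcoded and marked TODO above. One plausible decomposition, assuming DeepSeek-style MLA dimensions (kv_lora_rank=512, qk_rope_head_dim=64) and the 128-wide quant blocks used by NSATokenToKVPool further below:

    kv_lora_rank, qk_rope_head_dim, quant_block = 512, 64, 128   # assumed model dims
    nope_fp8_bytes = kv_lora_rank                                # latent KV stored as fp8, 1 byte each
    scale_bytes = kv_lora_rank // quant_block * 4                # 4 float32 scales = 16 bytes
    rope_16bit_bytes = qk_rope_head_dim * 2                      # rope part kept at 2 bytes per value
    assert nope_fp8_bytes + scale_bytes + rope_16bit_bytes == 656
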
@@ -756,7 +1165,7 @@ class MLATokenToKVPool(KVCache):
              # The padded slot 0 is used for writing dummy outputs from padded tokens.
              self.kv_buffer = [
                  torch.zeros(
-                     (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                     (size + page_size, 1, self.kv_cache_dim),
                      dtype=self.store_dtype,
                      device=device,
                  )
@@ -768,19 +1177,15 @@ class MLATokenToKVPool(KVCache):
              dtype=torch.uint64,
              device=self.device,
          )
-         self.layer_transfer_counter = None
-
-         kv_size = self.get_kv_size_bytes()
-         logger.info(
-             f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-         )
-         self.mem_usage = kv_size / GB
+         if not use_nsa:
+             # NSA will allocate indexer KV cache later and then log the total size
+             self._finalize_allocation_log(size)

      def get_kv_size_bytes(self):
          assert hasattr(self, "kv_buffer")
          kv_size_bytes = 0
          for kv_cache in self.kv_buffer:
-             kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+             kv_size_bytes += get_tensor_size_bytes(kv_cache)
          return kv_size_bytes

      # for disagg
@@ -825,6 +1230,7 @@ class MLATokenToKVPool(KVCache):
          cache_v: torch.Tensor,
      ):
          layer_id = layer.layer_id
+         assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
          if cache_k.dtype != self.dtype:
              cache_k = cache_k.to(self.dtype)
          if self.store_dtype != self.dtype:
@@ -842,16 +1248,28 @@ class MLATokenToKVPool(KVCache):
          cache_k_rope: torch.Tensor,
      ):
          layer_id = layer.layer_id
-         if cache_k_nope.dtype != self.dtype:
-             cache_k_nope = cache_k_nope.to(self.dtype)
-             cache_k_rope = cache_k_rope.to(self.dtype)
-         if self.store_dtype != self.dtype:
-             cache_k_nope = cache_k_nope.view(self.store_dtype)
-             cache_k_rope = cache_k_rope.view(self.store_dtype)

-         set_mla_kv_buffer_triton(
-             self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
-         )
+         if self.use_nsa and self.nsa_kv_cache_store_fp8:
+             # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
+             # TODO no need to cat
+             cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
+             cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
+             cache_k = cache_k.view(self.store_dtype)
+             self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
+         else:
+             if cache_k_nope.dtype != self.dtype:
+                 cache_k_nope = cache_k_nope.to(self.dtype)
+                 cache_k_rope = cache_k_rope.to(self.dtype)
+             if self.store_dtype != self.dtype:
+                 cache_k_nope = cache_k_nope.view(self.store_dtype)
+                 cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+             set_mla_kv_buffer_triton(
+                 self.kv_buffer[layer_id - self.start_layer],
+                 loc,
+                 cache_k_nope,
+                 cache_k_rope,
+             )

      def get_cpu_copy(self, indices):
          torch.cuda.synchronize()
@@ -881,6 +1299,111 @@ class MLATokenToKVPool(KVCache):
          torch.cuda.synchronize()


+ class NSATokenToKVPool(MLATokenToKVPool):
+     quant_block_size = 128
+     index_k_with_scale_buffer_dtype = torch.uint8
+
+     def __init__(
+         self,
+         size: int,
+         page_size: int,
+         kv_lora_rank: int,
+         dtype: torch.dtype,
+         qk_rope_head_dim: int,
+         layer_num: int,
+         device: str,
+         index_head_dim: int,
+         enable_memory_saver: bool,
+         start_layer: Optional[int] = None,
+         end_layer: Optional[int] = None,
+     ):
+         super().__init__(
+             size,
+             page_size,
+             dtype,
+             kv_lora_rank,
+             qk_rope_head_dim,
+             layer_num,
+             device,
+             enable_memory_saver,
+             start_layer,
+             end_layer,
+             use_nsa=True,
+         )
+         # self.index_k_dtype = torch.float8_e4m3fn
+         # self.index_k_scale_dtype = torch.float32
+         self.index_head_dim = index_head_dim
+         # num head == 1 and head dim == 128 for index_k in NSA
+         assert index_head_dim == 128
+
+         assert self.page_size == 64
+         self.index_k_with_scale_buffer = [
+             torch.zeros(
+                 # Layout:
+                 # ref: test_attention.py :: kv_cache_cast_to_fp8
+                 # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
+                 # data: for page i,
+                 # * buf[i, :page_size * head_dim] for fp8 data
+                 # * buf[i, page_size * head_dim:].view(float32) for scale
+                 (
+                     (size + page_size + 1) // self.page_size,
+                     self.page_size
+                     * (index_head_dim + index_head_dim // self.quant_block_size * 4),
+                 ),
+                 dtype=self.index_k_with_scale_buffer_dtype,
+                 device=device,
+             )
+             for _ in range(layer_num)
+         ]
+         self._finalize_allocation_log(size)
+
+     def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
+         if self.layer_transfer_counter is not None:
+             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+         return self.index_k_with_scale_buffer[layer_id - self.start_layer]
+
+     def get_index_k_continuous(
+         self,
+         layer_id: int,
+         seq_len: int,
+         page_indices: torch.Tensor,
+     ):
+         buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+         return index_buf_accessor.GetK.execute(
+             self, buf, seq_len=seq_len, page_indices=page_indices
+         )
+
+     def get_index_k_scale_continuous(
+         self,
+         layer_id: int,
+         seq_len: int,
+         page_indices: torch.Tensor,
+     ):
+         buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+         return index_buf_accessor.GetS.execute(
+             self, buf, seq_len=seq_len, page_indices=page_indices
+         )
+
+     # TODO rename later (currently use diff name to avoid confusion)
+     def set_index_k_and_scale_buffer(
+         self,
+         layer_id: int,
+         loc: torch.Tensor,
+         index_k: torch.Tensor,
+         index_k_scale: torch.Tensor,
+     ) -> None:
+         buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+         index_buf_accessor.SetKAndS.execute(
+             pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
+         )
+
+     def get_kv_size_bytes(self):
+         kv_size_bytes = super().get_kv_size_bytes()
+         for index_k_cache in self.index_k_with_scale_buffer:
+             kv_size_bytes += get_tensor_size_bytes(index_k_cache)
+         return kv_size_bytes
+
+
  class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
      def __init__(
          self,
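
Note (illustrative, not part of the diff): following the layout comments in NSATokenToKVPool, each page row of index_k_with_scale_buffer packs the fp8 index keys first and their float32 scales at the tail. Plugging in the sizes asserted above (page_size=64, index_head_dim=128, quant_block_size=128):

    page_size, head_dim, quant_block = 64, 128, 128
    row_bytes = page_size * (head_dim + head_dim // quant_block * 4)   # 64 * 132 = 8448 bytes per page
    fp8_bytes = page_size * head_dim                                   # first 8192 bytes: fp8 index_k values
    scale_bytes = row_bytes - fp8_bytes                                # last 256 bytes: 64 fp32 scales, one per token
    assert (fp8_bytes, scale_bytes) == (8192, 256)
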
@@ -889,6 +1412,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
          dtype: torch.dtype,
          kv_lora_rank: int,
          qk_rope_head_dim: int,
+         index_head_dim: Optional[int],
          layer_num: int,
          device: str,
          enable_memory_saver: bool,
@@ -908,6 +1432,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

          self.kv_lora_rank = kv_lora_rank
          self.qk_rope_head_dim = qk_rope_head_dim
+         self.index_head_dim = index_head_dim

          self.custom_mem_pool = None

@@ -935,23 +1460,33 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                  dtype=self.store_dtype,
                  device=self.device,
              )
+             if self.index_head_dim is not None:
+                 self.index_k_buffer = torch.zeros(
+                     (
+                         layer_num,
+                         self.size // self.page_size + 1,
+                         self.page_size,
+                         1,
+                         self.index_head_dim,
+                     ),
+                     dtype=self.store_dtype,
+                     device=self.device,
+                 )

-         self.layer_transfer_counter = None
-
-         kv_size = self.get_kv_size_bytes()
-         logger.info(
-             f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-         )
-         self.mem_usage = kv_size / GB
+         self._finalize_allocation_log(size)

      def get_kv_size_bytes(self):
          assert hasattr(self, "k_buffer")
          assert hasattr(self, "v_buffer")
          kv_size_bytes = 0
          for k_cache in self.k_buffer:
-             kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+             kv_size_bytes += get_tensor_size_bytes(k_cache)
          for v_cache in self.v_buffer:
-             kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+             kv_size_bytes += get_tensor_size_bytes(v_cache)
+         if self.index_head_dim is not None:
+             assert hasattr(self, "index_k_buffer")
+             for index_k_cache in self.index_k_buffer:
+                 kv_size_bytes += get_tensor_size_bytes(index_k_cache)
          return kv_size_bytes

      def get_kv_buffer(self, layer_id: int):
@@ -978,6 +1513,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
              return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
          return self.v_buffer[layer_id - self.start_layer]

+     def get_index_k_buffer(self, layer_id: int):
+         if self.layer_transfer_counter is not None:
+             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+         if self.store_dtype != self.dtype:
+             return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
+         return self.index_k_buffer[layer_id - self.start_layer]
+
      # for disagg
      def get_contiguous_buf_infos(self):
          # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
@@ -990,6 +1533,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
          kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
              self.v_buffer[i][0].nbytes for i in range(self.layer_num)
          ]
+         if self.index_head_dim is not None:
+             kv_data_ptrs += [
+                 self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
+             ]
+             kv_data_lens += [
+                 self.index_k_buffer[i].nbytes for i in range(self.layer_num)
+             ]
+             kv_item_lens += [
+                 self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
+             ]
          return kv_data_ptrs, kv_data_lens, kv_item_lens

      def set_kv_buffer(
@@ -1026,6 +1579,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
              cache_v.view(-1, 1, self.qk_rope_head_dim),
          )

+     def set_index_k_buffer(
+         self,
+         layer_id: int,
+         loc: torch.Tensor,
+         index_k: torch.Tensor,
+     ):
+         if index_k.dtype != self.dtype:
+             index_k = index_k.to(self.dtype)
+
+         if self.store_dtype != self.dtype:
+             index_k = index_k.view(self.store_dtype)
+
+         torch_npu.npu_scatter_nd_update_(
+             self.index_k_buffer[layer_id - self.start_layer].view(
+                 -1, 1, self.index_head_dim
+             ),
+             loc.view(-1, 1),
+             index_k.view(-1, 1, self.index_head_dim),
+         )
+

  class DoubleSparseTokenToKVPool(KVCache):
      def __init__(
@@ -1107,38 +1680,36 @@ class DoubleSparseTokenToKVPool(KVCache):


  @triton.jit
- def copy_all_layer_kv_cache(
+ def copy_all_layer_kv_cache_tiled(
      data_ptrs,
      strides,
      tgt_loc_ptr,
      src_loc_ptr,
      num_locs,
      num_locs_upper: tl.constexpr,
+     BYTES_PER_TILE: tl.constexpr,
  ):
-     BLOCK_SIZE: tl.constexpr = 128
-
+     """2D tiled kernel. Safe for in-place copy."""
      bid = tl.program_id(0)
+     tid = tl.program_id(1)
+
      stride = tl.load(strides + bid)
+     base_ptr = tl.load(data_ptrs + bid)
+     base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))

-     data_ptr = tl.load(data_ptrs + bid)
-     data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+     byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+     mask_byte = byte_off < stride
+     tl.multiple_of(byte_off, 16)

-     num_locs_offset = tl.arange(0, num_locs_upper)
-     tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
-     src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+     loc_idx = tl.arange(0, num_locs_upper)
+     mask_loc = loc_idx < num_locs

-     # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
-     # because this copy is an inplace operation.
+     src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+     tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)

-     num_loop = tl.cdiv(stride, BLOCK_SIZE)
-     for i in range(num_loop):
-         copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-         mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
-         value = tl.load(
-             data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
-         )
-         tl.store(
-             data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-             value,
-             mask=mask,
-         )
+     src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+     tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+     mask = mask_loc[:, None] & mask_byte[None, :]
+     vals = tl.load(src_ptr, mask=mask)
+     tl.store(tgt_ptr, vals, mask=mask)
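
Note (illustrative, not part of the diff): a sketch of how the tiled kernel is driven by move_kv_cache above. The grid's first axis walks the registered K/V buffers and the second axis walks byte tiles within a row, so each program instance copies one tile of every selected row (the concrete numbers below are assumed):

    # Illustrative launch-shape math, mirroring the Python-side call in move_kv_cache.
    num_buffers = 2 * 32           # e.g. separate K and V buffers for 32 layers
    stride_bytes = 9216            # assumed per-row byte stride (see _init_kv_copy_and_warmup)
    BYTES_PER_TILE = 512
    byte_tiles = (stride_bytes + BYTES_PER_TILE - 1) // BYTES_PER_TILE
    grid = (num_buffers, byte_tiles)          # (64, 18) program instances
    # Program (bid, tid) copies bytes [tid*512, tid*512 + 512) of every src row into the tgt rows
    # of buffer bid; rows beyond num_locs and bytes beyond stride are masked off.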