sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend):
                 o = result
         else:
             if (
-                not global_server_args_dict["disable_chunked_prefix_cache"]
-                and forward_batch.attn_attend_prefix_cache is not None
+                forward_batch.attn_attend_prefix_cache is not None
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
                 # Do multi-head attention with chunked prefix cache
-
                 if forward_batch.attn_attend_prefix_cache:
+                    assert not global_server_args_dict["disable_chunked_prefix_cache"]
                     # MHA for chunked prefix kv cache when running model with MLA
                     assert forward_batch.prefix_chunk_idx is not None
                     assert forward_batch.prefix_chunk_cu_seq_lens is not None
@@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend):
                     chunk_idx = forward_batch.prefix_chunk_idx
                     assert chunk_idx >= 0

-                    output, lse, *rest = flash_attn_varlen_func(
+                    assert forward_batch.mha_return_lse
+                    output = flash_attn_varlen_func(
                         q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                         k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                         v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
@@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend):
                     )
                 else:
                     # MHA for extend part of sequence without attending prefix kv cache
-                    output, lse, *rest = flash_attn_varlen_func(
+                    output = flash_attn_varlen_func(
                         q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                         k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                         v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
@@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
                         max_seqlen_k=metadata.max_seq_len_q,
                         softmax_scale=layer.scaling,
                         causal=True,
-                        return_softmax_lse=True,
+                        return_softmax_lse=forward_batch.mha_return_lse,
                     )
-                return output, lse
+                if forward_batch.mha_return_lse:
+                    output, lse, *rest = output
+                    lse = torch.transpose(lse, 0, 1).contiguous()
+                    return output, lse
+                return output
             else:
                 # Do absorbed multi-latent attention
                 kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
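For context on the `mha_return_lse` change above: the MHA extend path now asks `flash_attn_varlen_func` for the log-sum-exp only when the caller needs it, and transposes the LSE from a per-head to a per-token layout before returning. A minimal, self-contained sketch of that calling convention, using a plain-PyTorch stand-in for the FlashAttention kernel (the `_attn` helper and toy shapes below are illustrative, not sglang code):

```python
import torch


def _attn(q, k, v, return_softmax_lse=False):
    # Plain-PyTorch stand-in for flash_attn_varlen_func: returns (out, lse, ...) when
    # return_softmax_lse=True and just out otherwise.
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    out = torch.softmax(scores, dim=-1) @ v
    if return_softmax_lse:
        lse = torch.logsumexp(scores, dim=-1)  # [heads, tokens], like the kernel's layout
        return out, lse, None
    return out


def mha_extend(q, k, v, return_lse: bool):
    result = _attn(q, k, v, return_softmax_lse=return_lse)
    if return_lse:
        out, lse, *_ = result
        # Downstream merge code wants [tokens, heads], hence the transpose in the patch.
        return out, lse.transpose(0, 1).contiguous()
    return result


q = k = v = torch.randn(8, 16, 64)  # toy [heads, tokens, head_dim] layout
out_only = mha_extend(q, k, v, return_lse=False)
out, lse = mha_extend(q, k, v, return_lse=True)
assert out_only.shape == (8, 16, 64) and lse.shape == (16, 8)
```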
@@ -1163,6 +1167,8 @@ class FlashAttentionBackend(AttentionBackend):
        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
+
        # This is being used by normal decode and draft decode when topk == 1
        self.decode_cuda_graph_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1174,13 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
            ),
            "page_table": torch.zeros(
                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
-                dtype=torch.int32,
-                device=self.device,
-            ),
-            "page_table_draft_decode": torch.zeros(
-                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                dtype=torch.int32,
                device=self.device,
            ),
@@ -1188,7 +1188,6 @@ class FlashAttentionBackend(AttentionBackend):
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }
-
        # Only allocate local attention buffers if local attention is enabled
        # This prevents OOM errors when local attention is not being used
        if self.attention_chunk_size is not None:
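The new `max_num_pages` hoists the repeated ceiling division into one place before the CUDA-graph buffers are sized. A small sketch of the sizing arithmetic, with toy numbers rather than sglang defaults:

```python
import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


# Toy values, not sglang defaults; the real ones come from the server configuration.
max_bs, max_context_len, page_size = 4, 8192, 64
max_num_pages = ceil_div(max_context_len, page_size)  # 128 pages per sequence

# Fixed-size buffer reused across CUDA graph replays; a CPU tensor is enough for the sketch.
page_table = torch.zeros(max_bs, max_num_pages, dtype=torch.int32)
assert page_table.shape == (4, 128)
```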
@@ -1274,6 +1273,14 @@ class FlashAttentionBackend(AttentionBackend):
            self.speculative_num_draft_tokens is not None
            and self.speculative_num_draft_tokens > 0
        ):
+            # "page_table_draft_decode" will be set only when spec decoding enabled to save memory
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
            self.target_verify_metadata = {
                "cache_seqlens": torch.zeros(
                    max_bs, dtype=torch.int32, device=self.device
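The `page_table_draft_decode` buffer is now allocated only inside the speculative-decoding branch, so servers that never run a draft model skip one `max_bs x max_num_pages` int32 tensor. A hypothetical sketch of that allocate-on-demand pattern (the class and argument names below are illustrative, not the sglang API):

```python
import torch


class DecodeGraphBuffers:
    """Illustrative only: allocate the draft-decode page table
    only when speculative decoding is actually configured."""

    def __init__(self, max_bs, max_num_pages, num_draft_tokens=None, device="cpu"):
        self.metadata = {
            "page_table": torch.zeros(
                max_bs, max_num_pages, dtype=torch.int32, device=device
            ),
        }
        if num_draft_tokens is not None and num_draft_tokens > 0:
            # Only pay for this buffer when a draft model will use it.
            self.metadata["page_table_draft_decode"] = torch.zeros(
                max_bs, max_num_pages, dtype=torch.int32, device=device
            )


plain = DecodeGraphBuffers(max_bs=4, max_num_pages=128)
spec = DecodeGraphBuffers(max_bs=4, max_num_pages=128, num_draft_tokens=4)
assert "page_table_draft_decode" not in plain.metadata
assert "page_table_draft_decode" in spec.metadata
```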
@@ -1290,7 +1297,7 @@ class FlashAttentionBackend(AttentionBackend):
                ),
                "page_table": torch.zeros(
                    max_bs,
-                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    max_num_pages,
                    dtype=torch.int32,
                    device=self.device,
                ),
@@ -1313,7 +1320,7 @@ class FlashAttentionBackend(AttentionBackend):
                ),
                "page_table": torch.zeros(
                    max_bs,
-                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    max_num_pages,
                    dtype=torch.int32,
                    device=self.device,
                ),
@@ -122,6 +122,7 @@ class FlashInferAttnBackend(AttentionBackend):
        # Allocate buffers
        global global_workspace_buffer
        if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
            global_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
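The comment above marks the lazily created, module-global workspace that all wrapper instances share, allocated with `torch.empty` so the buffer is not zero-filled. A rough sketch of that caching pattern, assuming a hypothetical `get_workspace` helper:

```python
import torch

# Module-level cache so every backend instance shares one workspace allocation.
_global_workspace_buffer = None


def get_workspace(size_bytes: int, device: str = "cpu") -> torch.Tensor:
    # torch.empty skips zero-filling the (potentially large) buffer; the kernels
    # that consume it are expected to write before they read.
    global _global_workspace_buffer
    if _global_workspace_buffer is None:
        _global_workspace_buffer = torch.empty(size_bytes, dtype=torch.uint8, device=device)
    return _global_workspace_buffer


buf_a = get_workspace(1 << 20)
buf_b = get_workspace(1 << 20)
assert buf_a.data_ptr() == buf_b.data_ptr()  # same underlying allocation
```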
@@ -870,6 +871,8 @@ class FlashInferIndicesUpdaterPrefill:
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        if use_ragged:
+            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
+            # and forward_batch.extend_seq_lens_cpu
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
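The TODO notes that `paged_kernel_lens.sum().item()` forces a device-to-host synchronization, which CPU-side length lists such as `forward_batch.extend_prefix_lens_cpu` could avoid. A toy illustration of the two paths (the tensor below is a CPU stand-in for a CUDA tensor):

```python
import torch

# CPU stand-ins for what would be a CUDA tensor / host-side list in the server.
prefix_lens = torch.tensor([3, 7, 5], dtype=torch.int32)  # device tensor in practice
prefix_lens_cpu = [3, 7, 5]                               # e.g. forward_batch.extend_prefix_lens_cpu

# Device path: .item() has to wait for the result, a host<->device sync on CUDA.
total_from_device = prefix_lens.sum().item()

# Host path: plain Python arithmetic, no synchronization.
total_from_host = sum(prefix_lens_cpu)

assert total_from_device == total_from_host == 15
```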
@@ -1260,11 +1263,12 @@ def should_use_tensor_core(
    # Calculate GQA group size
    gqa_group_size = num_attention_heads // num_kv_heads

-    # Determine based on dtype and GQA group size
+    # For Flashinfer, a GQA group size of at least 4 is needed to efficiently
+    # use Tensor Cores, as it fuses the head group with the token dimension in MMA.
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
-        return gqa_group_size > 4
+        return gqa_group_size >= 4
    else:
        return False
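Condensing the updated heuristic: an FP8 KV cache always takes the tensor-core decode path, FP16/BF16 takes it once the GQA group size reaches 4 (previously it required strictly more than 4), and other dtypes never do. A self-contained restatement for reference (the function name below is ours, not sglang's):

```python
import torch


def use_tensor_core_decode(
    kv_cache_dtype: torch.dtype, num_attention_heads: int, num_kv_heads: int
) -> bool:
    gqa_group_size = num_attention_heads // num_kv_heads
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True  # FP8 KV cache always uses the tensor-core path
    if kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
        return gqa_group_size >= 4  # was "> 4" before 0.5.1
    return False


# GQA with 32 query heads over 8 KV heads (group size 4) now qualifies.
assert use_tensor_core_decode(torch.bfloat16, 32, 8) is True
assert use_tensor_core_decode(torch.bfloat16, 32, 16) is False  # group size 2
```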
 
@@ -1369,7 +1373,14 @@ def fast_decode_plan(

    if self.use_tensor_cores:
        # ALSO convert last_page_len to CPU
-        last_page_len_host = last_page_len.cpu()
+        if page_size == 1:
+            # When page size is 1, last_page_len is always 1.
+            # Directly construct the host tensor rather than executing a device-to-host copy.
+            last_page_len_host = torch.ones(
+                (batch_size,), dtype=torch.int32, device="cpu"
+            )
+        else:
+            last_page_len_host = last_page_len.cpu()

        kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
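The `page_size == 1` branch avoids a device-to-host copy because every sequence's last page then holds exactly one token, so the host tensor is known in advance. A hedged sketch of that shortcut as a standalone helper (the function name is hypothetical):

```python
import torch


def last_page_len_to_host(last_page_len: torch.Tensor, page_size: int, batch_size: int) -> torch.Tensor:
    # With page_size == 1 every sequence's last page holds exactly one token, so the
    # host tensor can be built directly instead of copying from the device.
    if page_size == 1:
        return torch.ones((batch_size,), dtype=torch.int32, device="cpu")
    return last_page_len.cpu()


device_lens = torch.ones(3, dtype=torch.int32)  # stand-in for a CUDA tensor
host_lens = last_page_len_to_host(device_lens, page_size=1, batch_size=3)
assert host_lens.device.type == "cpu" and host_lens.tolist() == [1, 1, 1]
```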