sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend):
                 o = result
             else:
                 if (
-                    not global_server_args_dict["disable_chunked_prefix_cache"]
-                    and forward_batch.attn_attend_prefix_cache is not None
+                    forward_batch.attn_attend_prefix_cache is not None
                     and not forward_batch.forward_mode.is_target_verify()
                     and not forward_batch.forward_mode.is_draft_extend()
                 ):
                     # Do multi-head attention with chunked prefix cache
-
                     if forward_batch.attn_attend_prefix_cache:
+                        assert not global_server_args_dict["disable_chunked_prefix_cache"]
                         # MHA for chunked prefix kv cache when running model with MLA
                         assert forward_batch.prefix_chunk_idx is not None
                         assert forward_batch.prefix_chunk_cu_seq_lens is not None
@@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend):
                         chunk_idx = forward_batch.prefix_chunk_idx
                         assert chunk_idx >= 0
 
-                        output, lse, *rest = flash_attn_varlen_func(
+                        assert forward_batch.mha_return_lse
+                        output = flash_attn_varlen_func(
                             q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                             k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                             v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
@@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend):
                         )
                     else:
                         # MHA for extend part of sequence without attending prefix kv cache
-                        output, lse, *rest = flash_attn_varlen_func(
+                        output = flash_attn_varlen_func(
                             q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                             k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                             v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
@@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
                             max_seqlen_k=metadata.max_seq_len_q,
                             softmax_scale=layer.scaling,
                             causal=True,
-                            return_softmax_lse=True,
+                            return_softmax_lse=forward_batch.mha_return_lse,
                         )
-                    return output, torch.transpose(lse, 0, 1).contiguous()
+                    if forward_batch.mha_return_lse:
+                        output, lse, *rest = output
+                        lse = torch.transpose(lse, 0, 1).contiguous()
+                        return output, lse
+                    return output
                 else:
                     # Do absorbed multi-latent attention
                     kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
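Context on the `mha_return_lse` change above: when `return_softmax_lse` is set, the flash-attention call returns the attention output together with its log-sum-exp (LSE) normalizer, which the new code transposes to token-major layout before returning. The LSE is what allows partial attention results computed over disjoint KV chunks (the chunked prefix cache plus the extend tokens) to be combined afterwards. Below is a minimal sketch of that merge rule for illustration only; it is not sglang's actual merge kernel, and the function name and shapes are assumptions.

```python
import torch

def merge_attn_partials(o1, lse1, o2, lse2):
    """Combine two partial attention outputs computed over disjoint key chunks.

    o1, o2:     [tokens, heads, head_dim] softmax-normalized partial outputs
    lse1, lse2: [tokens, heads] natural-log softmax denominators (LSE)
    """
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)  # un-normalized weight of chunk 1
    w2 = torch.exp(lse2 - max_lse)
    denom = w1 + w2
    merged = o1 * (w1 / denom).unsqueeze(-1) + o2 * (w2 / denom).unsqueeze(-1)
    merged_lse = max_lse + torch.log(denom)
    return merged, merged_lse
```

Under this convention, merging the chunked-prefix partial with the extend-token partial reproduces attention over the full key set, which is why the MHA path needs to hand the LSE back to its caller.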
@@ -1163,6 +1167,8 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
+
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1174,13 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
             ),
             "page_table": torch.zeros(
                 max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
-                dtype=torch.int32,
-                device=self.device,
-            ),
-            "page_table_draft_decode": torch.zeros(
-                max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -1188,7 +1188,6 @@ class FlashAttentionBackend(AttentionBackend):
                 0, self.max_context_len, self.page_size, device=self.device
             ),
         }
-
         # Only allocate local attention buffers if local attention is enabled
         # This prevents OOM errors when local attention is not being used
         if self.attention_chunk_size is not None:
@@ -1274,6 +1273,14 @@ class FlashAttentionBackend(AttentionBackend):
             self.speculative_num_draft_tokens is not None
             and self.speculative_num_draft_tokens > 0
         ):
+            # "page_table_draft_decode" will be set only when spec decoding enabled to save memory
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
             self.target_verify_metadata = {
                 "cache_seqlens": torch.zeros(
                     max_bs, dtype=torch.int32, device=self.device
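The refactor above hoists the page count into a single `max_num_pages` (a ceiling division of the context length by the page size) and allocates `page_table_draft_decode` only when speculative decoding is actually configured. A rough sketch of the buffer sizing involved, with illustrative numbers rather than sglang defaults:

```python
def page_table_bytes(max_bs: int, max_context_len: int, page_size: int) -> int:
    """Size of one int32 page-table buffer of shape [max_bs, max_num_pages]."""
    max_num_pages = (max_context_len + page_size - 1) // page_size  # ceiling division
    return max_bs * max_num_pages * 4  # int32 -> 4 bytes per entry

# Example: batch 256, 128K context, page size 1.
# Skipping the draft-decode page table when speculative decoding is off
# saves one such buffer (~128 MiB with these numbers).
print(page_table_bytes(256, 128 * 1024, 1) / 2**20, "MiB")
```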
@@ -1290,7 +1297,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 "page_table": torch.zeros(
                     max_bs,
-                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    max_num_pages,
                     dtype=torch.int32,
                     device=self.device,
                 ),
@@ -1313,7 +1320,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 "page_table": torch.zeros(
                     max_bs,
-                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    max_num_pages,
                     dtype=torch.int32,
                     device=self.device,
                 ),
sglang/srt/layers/attention/flashinfer_backend.py

@@ -122,6 +122,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -870,6 +871,8 @@ class FlashInferIndicesUpdaterPrefill:
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        if use_ragged:
+            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
+            # and forward_batch.extend_seq_lens_cpu
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
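The TODO added above concerns `paged_kernel_lens.sum().item()`: calling `.item()` on a CUDA tensor forces a blocking device-to-host synchronization. If the batch already carries the lengths on the host, as the comment suggests with `extend_prefix_lens_cpu` / `extend_seq_lens_cpu`, the sum can be computed without touching the GPU. A hedged sketch of the difference (the function names here are illustrative, not sglang APIs):

```python
import torch

def sum_with_sync(prefix_lens_gpu: torch.Tensor) -> int:
    # Current path: reduce on the GPU, then .item() blocks until the value
    # has been copied back to the host (a device synchronization point).
    return int(prefix_lens_gpu.sum().item())

def sum_without_sync(prefix_lens_cpu) -> int:
    # What the TODO points at: if the scheduler already keeps a host-side copy
    # of the lengths (e.g. a plain Python list), no GPU sync is needed at all.
    return sum(prefix_lens_cpu)
```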
@@ -1260,11 +1263,12 @@ def should_use_tensor_core(
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
 
-    # Determine based on dtype and GQA group size
+    # For Flashinfer, a GQA group size of at least 4 is needed to efficiently
+    # use Tensor Cores, as it fuses the head group with the token dimension in MMA.
     if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         return True
     elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
-        return gqa_group_size > 4
+        return gqa_group_size >= 4
     else:
         return False
 
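The behavioral change above is the threshold: with fp16/bf16 KV caches, a GQA group size of exactly 4 (e.g. 32 query heads over 8 KV heads) now takes the tensor-core decode path. A small standalone sketch of the updated heuristic applied to a few illustrative head layouts (hypothetical shapes, not tied to specific models):

```python
import torch

def use_tensor_cores(num_attention_heads: int, num_kv_heads: int, kv_cache_dtype: torch.dtype) -> bool:
    # Mirrors the updated heuristic: fp8 KV caches always use tensor cores,
    # fp16/bf16 KV caches use them once the GQA group size reaches 4.
    gqa_group_size = num_attention_heads // num_kv_heads
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    if kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
        return gqa_group_size >= 4
    return False

for q_heads, kv_heads in [(32, 32), (32, 8), (64, 8)]:
    print(q_heads, kv_heads, use_tensor_cores(q_heads, kv_heads, torch.bfloat16))
# Group sizes 1, 4, 8 -> False, True, True; group size 4 was False under the old "> 4" rule.
```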
@@ -1369,7 +1373,14 @@ def fast_decode_plan(
 
     if self.use_tensor_cores:
         # ALSO convert last_page_len to CPU
-        last_page_len_host = last_page_len.cpu()
+        if page_size == 1:
+            # When page size is 1, last_page_len is always 1.
+            # Directly construct the host tensor rather than executing a device-to-host copy.
+            last_page_len_host = torch.ones(
+                (batch_size,), dtype=torch.int32, device="cpu"
+            )
+        else:
+            last_page_len_host = last_page_len.cpu()
 
         kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
 
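The `fast_decode_plan` change avoids a blocking device-to-host copy in the common `page_size == 1` case, where the last page of every sequence trivially holds one token. A minimal standalone sketch of the same idea; the helper name is hypothetical and not part of the sglang API:

```python
import torch

def last_page_len_to_host(last_page_len: torch.Tensor, page_size: int, batch_size: int) -> torch.Tensor:
    """Return last_page_len as a CPU tensor, skipping the D2H copy when it is statically known."""
    if page_size == 1:
        # With one token per page, every sequence's last page holds exactly one token,
        # so the host tensor can be built without reading GPU memory (no sync point).
        return torch.ones((batch_size,), dtype=torch.int32, device="cpu")
    # Otherwise the values live on the GPU and a blocking copy is required.
    return last_page_len.cpu()
```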