sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their public registry.
- sglang/bench_one_batch.py +3 -11
- sglang/bench_serving.py +149 -1
- sglang/check_env.py +3 -3
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +32 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +151 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +58 -24
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +22 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +129 -94
- sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +6 -1
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +81 -35
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +44 -16
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +291 -72
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +60 -28
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +159 -90
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +2 -277
- sglang/srt/models/deepseek_v2.py +132 -37
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +93 -31
- sglang/srt/models/llama4.py +54 -7
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +4 -16
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +58 -62
- sglang/srt/openai_api/protocol.py +38 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +93 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +123 -10
- sglang/test/runners.py +4 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +32 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/kimi_vl.py (new file)
@@ -0,0 +1,73 @@
+import asyncio
+import math
+from typing import List, Union
+
+import torch
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
+
+
+# Compatible with KimiVLForConditionalGeneration
+class KimiVLImageProcessor(SGLangBaseProcessor):
+    models = [KimiVLForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "<|media_pad|>"
+        self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
+
+        self.im_start = "<|media_start|>"
+        self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
+
+        self.im_end = "<|media_end|>"
+        self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
+
+        self.im_content = "<|media_content|>"
+        self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
+            max_req_input_len=max_req_input_len,
+        )
+        ret = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+        )
+        return {
+            "input_ids": ret["input_ids"].flatten().tolist(),
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=ret["image_grid_hws"],
+                    modality=Modality.IMAGE,
+                )
+            ],
+            "im_token_id": self.im_token_id,
+            "im_start_id": self.im_start_id,
+            "im_end_id": self.im_end_id,
+            "im_content_id": self.im_content_id,
+        }
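Note: the new processor binds itself to KimiVLForConditionalGeneration via the `models` attribute and exposes `process_mm_data_async`, which returns the flattened `input_ids` plus an `mm_items` list carrying `pixel_values` and `image_grid_hws`. A minimal sketch of driving it directly is below; the checkpoint name, the `AutoProcessor` wiring, the `ServerArgs(model_path=...)` construction, and the `max_req_input_len` value are illustrative assumptions, not taken from the diff.

    import asyncio

    from transformers import AutoConfig, AutoProcessor

    from sglang.srt.managers.multimodal_processors.kimi_vl import KimiVLImageProcessor
    from sglang.srt.server_args import ServerArgs

    def build_processor(model_path: str) -> KimiVLImageProcessor:
        # The HF processor bundles the tokenizer that resolves the <|media_*|> token ids.
        hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        hf_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        server_args = ServerArgs(model_path=model_path)
        return KimiVLImageProcessor(hf_config, server_args, hf_processor)

    async def preprocess(processor: KimiVLImageProcessor, image_path: str, prompt: str):
        # Returns None when no image is given, otherwise input_ids plus mm_items.
        return await processor.process_mm_data_async(
            image_data=[image_path],
            input_text=prompt,
            request_obj=None,
            max_req_input_len=4096,
        )

    # Hypothetical usage:
    # result = asyncio.run(preprocess(build_processor("moonshotai/Kimi-VL-A3B-Instruct"),
    #                                 "demo.png", "<|media_pad|> describe this image"))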
sglang/srt/managers/schedule_batch.py
@@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 # Put some global args for easy access
 global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
-    "sampling_backend": ServerArgs.sampling_backend,
-    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "torchao_config": ServerArgs.torchao_config,
-    "enable_nan_detection": ServerArgs.enable_nan_detection,
-    "enable_dp_attention": ServerArgs.enable_dp_attention,
-    "enable_ep_moe": ServerArgs.enable_ep_moe,
-    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
-    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
-    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
-    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "torchao_config": ServerArgs.torchao_config,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
 }
 
 logger = logging.getLogger(__name__)
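Note: this hunk re-sorts the dict alphabetically and adds `max_micro_batch_size`; every entry simply mirrors a `ServerArgs` class default under the attribute's own name, which is what makes the reordering safe. A hedged sketch of how a launcher could copy a parsed `ServerArgs` instance into the dict (the `sync_global_args` helper is hypothetical, not part of sglang):

    from sglang.srt.managers.schedule_batch import global_server_args_dict
    from sglang.srt.server_args import ServerArgs

    def sync_global_args(server_args: ServerArgs) -> None:
        # Every key in the dict matches a ServerArgs attribute name, so getattr is safe.
        for key in global_server_args_dict:
            global_server_args_dict[key] = getattr(server_args, key)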
@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Events
     launch_done: Optional[threading.Event] = None
 
+    # For chunked prefill in PP
+    chunked_req: Optional[Req] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -741,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
 
+    # For multimodal inputs
+    multimodal_inputs: Optional[List] = None
+
     # The sum of all sequence lengths
     seq_lens_sum: int = None
 
@@ -761,7 +768,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
-    extend_num_tokens: int = None
+    extend_num_tokens: Optional[int] = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
     # It comes empty list if logprob is not required.
@@ -803,6 +810,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)
 
@@ -820,6 +828,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            chunked_req=chunked_req,
         )
 
     def batch_size(self):
@@ -1044,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
+        multimodal_inputs = []
 
         for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
@@ -1059,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             # If req.input_embeds is already a list, append its content directly
             input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
+            multimodal_inputs.append(req.multimodal_inputs)
+
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1141,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if input_embeds
             else None
         )
+        for mm_input in multimodal_inputs:
+            if mm_input is None:
+                continue
+            for mm_item in mm_input.mm_items:
+                pixel_values = getattr(mm_item, "pixel_values", None)
+                if isinstance(pixel_values, torch.Tensor):
+                    mm_item.pixel_values = pixel_values.to(
+                        self.device, non_blocking=True
+                    )
+        self.multimodal_inputs = multimodal_inputs
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
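Note: the added loop moves each request's `pixel_values` tensor onto the batch device with `non_blocking=True` when the batch is prepared. This mirrors the standard PyTorch asynchronous host-to-device copy; the snippet below is a generic illustration (not sglang code, and the tensor shape is made up). `non_blocking` only overlaps the copy with compute when the source tensor lives in pinned host memory.

    import torch

    pixel_values = torch.randn(1, 3, 336, 336).pin_memory()  # pinned host memory
    if torch.cuda.is_available():
        pixel_values = pixel_values.to("cuda", non_blocking=True)  # async H2D copy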
@@ -1236,7 +1258,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
-        sorted_indices = [i for i in range(len(self.reqs))]
+        sorted_indices = list(range(len(self.reqs)))
 
         # TODO(lsyin): improve retraction policy for radix cache
         # For spec decoding, filter_batch API can only filter
@@ -1413,15 +1435,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     def filter_batch(
         self,
-        chunked_req_to_exclude: Optional[Req] = None,
+        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
+            if isinstance(chunked_req_to_exclude, Req):
+                chunked_req_to_exclude = [chunked_req_to_exclude]
+            elif chunked_req_to_exclude is None:
+                chunked_req_to_exclude = []
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not chunked_req_to_exclude
+                and not self.reqs[i] in chunked_req_to_exclude
             ]
 
         if keep_indices is None or len(keep_indices) == 0:
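Note: `filter_batch` now accepts either a single chunked request or a list of them to exclude, normalizing both (and `None`) to a list before building `keep_indices`. A self-contained sketch of that normalization with a stand-in `Req` class (illustrative only, not the sglang `Req`):

    from typing import List, Optional, Union

    class Req:
        def __init__(self, rid: str, done: bool = False):
            self.rid, self._done = rid, done

        def finished(self) -> bool:
            return self._done

    def keep_indices_for(reqs: List[Req], exclude: Optional[Union[Req, List[Req]]]) -> List[int]:
        # Normalize None / single Req / list of Reqs, then keep unfinished, non-excluded requests.
        if isinstance(exclude, Req):
            exclude = [exclude]
        elif exclude is None:
            exclude = []
        return [i for i, r in enumerate(reqs) if not r.finished() and r not in exclude]

    reqs = [Req("a"), Req("b", done=True), Req("c")]
    assert keep_indices_for(reqs, reqs[2]) == [0]
    assert keep_indices_for(reqs, None) == [0, 2]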
@@ -1442,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
+        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1490,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
+        self.multimodal_inputs.extend(other.multimodal_inputs)
 
         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
@@ -1548,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
+            multimodal_inputs=self.multimodal_inputs,
             encoder_cached=self.encoder_cached,
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
sglang/srt/managers/schedule_policy.py
@@ -455,7 +455,10 @@ class PrefillAdder:
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens = req.extend_input_len
+        input_tokens = (
+            -(-req.extend_input_len // self.tree_cache.page_size)
+            * self.tree_cache.page_size
+        )
         prefix_len = len(req.prefix_indices)
 
         if total_tokens >= self.rem_total_tokens:
@@ -477,7 +480,10 @@ class PrefillAdder:
             req.last_node_global, req.prefix_indices
         )
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-        input_tokens = req.extend_input_len
+        input_tokens = (
+            -(-req.extend_input_len // self.tree_cache.page_size)
+            * self.tree_cache.page_size
+        )
         prefix_len = len(req.prefix_indices)
 
         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
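Note: both hunks above change the prefill budget accounting so that `input_tokens` is `req.extend_input_len` rounded up to a whole number of tree-cache pages, using the negative-floor-division idiom `-(-n // page) * page`. A quick numeric check (the page size here is an illustrative value, not taken from the diff):

    page_size = 16
    for extend_input_len in (1, 16, 17, 33):
        input_tokens = -(-extend_input_len // page_size) * page_size
        print(extend_input_len, "->", input_tokens)  # 1->16, 16->16, 17->32, 33->48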
@@ -493,12 +499,12 @@ class PrefillAdder:
                 ),
             )
         else:
-            if self.rem_chunk_tokens == 0:
+            # Make sure at least one page is available
+            trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+            if trunc_len <= 0:
                 return AddReqResult.OTHER
 
             # Chunked prefill
-            trunc_len = self.rem_chunk_tokens
-
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
 