sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -45,17 +45,20 @@ import triton
|
|
45
45
|
import triton.language as tl
|
46
46
|
|
47
47
|
from sglang.global_config import global_config
|
48
|
-
from sglang.srt.configs.model_config import ModelConfig
|
49
48
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
50
49
|
from sglang.srt.disaggregation.base import BaseKVSender
|
51
50
|
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
52
51
|
ScheduleBatchDisaggregationDecodeMixin,
|
53
52
|
)
|
54
53
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
55
|
-
from sglang.srt.mem_cache.allocator import
|
54
|
+
from sglang.srt.mem_cache.allocator import (
|
55
|
+
BaseTokenToKVPoolAllocator,
|
56
|
+
SWATokenToKVPoolAllocator,
|
57
|
+
)
|
56
58
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
57
59
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
58
60
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
61
|
+
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
59
62
|
from sglang.srt.metrics.collector import TimeStats
|
60
63
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
61
64
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
@@ -64,6 +67,7 @@ from sglang.srt.server_args import ServerArgs
|
|
64
67
|
from sglang.srt.utils import flatten_nested_list, support_triton
|
65
68
|
|
66
69
|
if TYPE_CHECKING:
|
70
|
+
from sglang.srt.configs.model_config import ModelConfig
|
67
71
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
68
72
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
69
73
|
|
@@ -102,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|
102
106
|
"num_reserved_decode_tokens",
|
103
107
|
"weight_loader_disable_mmap",
|
104
108
|
"enable_triton_kernel_moe",
|
109
|
+
"enable_multimodal",
|
105
110
|
]
|
106
111
|
|
107
112
|
# Put some global args for easy access
|
@@ -197,45 +202,41 @@ class MultimodalDataItem:
|
|
197
202
|
For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
|
198
203
|
One for images and one for audio.
|
199
204
|
|
200
|
-
We put the common fields first and the model-specific fields
|
205
|
+
We put the common fields first and the model-specific fields in model_specific_data.
|
201
206
|
"""
|
202
207
|
|
203
208
|
modality: Modality
|
204
209
|
hash: int = None
|
205
210
|
pad_value: int = None
|
206
|
-
image_sizes: Tuple[int, int] = None
|
207
211
|
offsets: Optional[list] = None
|
212
|
+
# the raw features returned by processor, e.g. pixel_values or audio_features
|
213
|
+
feature: Union[torch.Tensor, np.ndarray] = None
|
208
214
|
|
209
|
-
# the
|
210
|
-
|
211
|
-
pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
|
212
|
-
audio_features: Union[torch.Tensor, np.ndarray] = None
|
213
|
-
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
214
|
-
audio_offsets: Optional[List[Tuple[int, int]]] = None
|
215
|
-
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
216
|
-
|
217
|
-
# For qwen-vl
|
218
|
-
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
|
219
|
-
second_per_grid_ts: Optional[List[torch.Tensor]] = None
|
215
|
+
# the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
|
216
|
+
precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
|
220
217
|
|
221
|
-
#
|
222
|
-
|
223
|
-
image_spatial_crop: Optional[torch.Tensor] = None
|
218
|
+
# Model-specific data stored in a dictionary
|
219
|
+
model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)
|
224
220
|
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
221
|
+
def __getattr__(self, name: str):
|
222
|
+
if (
|
223
|
+
"model_specific_data" in self.__dict__
|
224
|
+
and name in self.__dict__["model_specific_data"]
|
225
|
+
):
|
226
|
+
return self.__dict__["model_specific_data"][name]
|
227
|
+
else:
|
228
|
+
raise AttributeError(
|
229
|
+
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
230
|
+
)
|
232
231
|
|
233
|
-
|
234
|
-
|
232
|
+
def __setitem__(self, key: str, value: Any):
|
233
|
+
if key in self.__dict__:
|
234
|
+
self.__dict__[key] = value
|
235
|
+
else:
|
236
|
+
self.model_specific_data[key] = value
|
235
237
|
|
236
|
-
|
237
|
-
|
238
|
-
input_features_mask: Optional[torch.Tensor] = None
|
238
|
+
def set(self, key: str, value: Any):
|
239
|
+
self.__setitem__(key, value)
|
239
240
|
|
240
241
|
@staticmethod
|
241
242
|
def is_empty_list(l):
|
@@ -250,18 +251,11 @@ class MultimodalDataItem:
|
|
250
251
|
from sglang.srt.managers.mm_utils import hash_feature
|
251
252
|
|
252
253
|
if self.hash is None:
|
253
|
-
if self.
|
254
|
-
|
255
|
-
elif self.is_audio():
|
256
|
-
if self.audio_features is not None:
|
257
|
-
self.hash = hash_feature(self.audio_features)
|
258
|
-
elif self.input_features is not None:
|
259
|
-
self.hash = hash_feature(self.input_features)
|
260
|
-
elif self.is_video():
|
261
|
-
self.hash = hash_feature(self.pixel_values_videos)
|
254
|
+
if self.feature is not None:
|
255
|
+
hashed_feature = self.feature
|
262
256
|
else:
|
263
|
-
|
264
|
-
|
257
|
+
hashed_feature = self.precomputed_embeddings
|
258
|
+
self.hash = hash_feature(hashed_feature)
|
265
259
|
assert self.hash is not None
|
266
260
|
self.pad_value = self.hash % (1 << 30)
|
267
261
|
|
@@ -269,25 +263,13 @@ class MultimodalDataItem:
|
|
269
263
|
return self.modality == modality
|
270
264
|
|
271
265
|
def is_audio(self):
|
272
|
-
return
|
273
|
-
self.precomputed_features is not None
|
274
|
-
or not MultimodalDataItem.is_empty_list(self.audio_features)
|
275
|
-
or not MultimodalDataItem.is_empty_list(self.input_features)
|
276
|
-
)
|
266
|
+
return self.modality == Modality.AUDIO
|
277
267
|
|
278
268
|
def is_image(self):
|
279
|
-
return
|
280
|
-
self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
|
281
|
-
) and (
|
282
|
-
self.precomputed_features is not None
|
283
|
-
or not MultimodalDataItem.is_empty_list(self.pixel_values)
|
284
|
-
)
|
269
|
+
return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]
|
285
270
|
|
286
271
|
def is_video(self):
|
287
|
-
return
|
288
|
-
self.precomputed_features is not None
|
289
|
-
or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
|
290
|
-
)
|
272
|
+
return self.modality == Modality.VIDEO
|
291
273
|
|
292
274
|
def is_valid(self) -> bool:
|
293
275
|
return self.is_image() or self.is_video() or self.is_audio()
|
@@ -307,9 +289,8 @@ class MultimodalDataItem:
|
|
307
289
|
return ret
|
308
290
|
|
309
291
|
def merge(self, other):
|
310
|
-
self.
|
311
|
-
self.
|
312
|
-
self.image_offsets += other.image_offsets
|
292
|
+
self.feature += other.feature
|
293
|
+
self.offsets += other.offsets
|
313
294
|
self.hash = hash((self.hash, other.hash))
|
314
295
|
self.set_pad_value()
|
315
296
|
|
@@ -350,7 +331,6 @@ class MultimodalInputs:
|
|
350
331
|
|
351
332
|
assert isinstance(ret.mm_items, list)
|
352
333
|
ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
|
353
|
-
|
354
334
|
for item in ret.mm_items:
|
355
335
|
item.set_pad_value()
|
356
336
|
|
@@ -451,6 +431,7 @@ class Req:
|
|
451
431
|
bootstrap_port: Optional[int] = None,
|
452
432
|
bootstrap_room: Optional[int] = None,
|
453
433
|
data_parallel_rank: Optional[int] = None,
|
434
|
+
vocab_size: Optional[int] = None,
|
454
435
|
):
|
455
436
|
# Input and output info
|
456
437
|
self.rid = rid
|
@@ -500,6 +481,7 @@ class Req:
|
|
500
481
|
self.to_abort_message: str = None
|
501
482
|
self.stream = stream
|
502
483
|
self.eos_token_ids = eos_token_ids
|
484
|
+
self.vocab_size = vocab_size
|
503
485
|
|
504
486
|
# For incremental decoding
|
505
487
|
# ----- | --------- read_ids -------|
|
@@ -527,6 +509,8 @@ class Req:
|
|
527
509
|
self.last_node: Any = None
|
528
510
|
self.last_host_node: Any = None
|
529
511
|
self.host_hit_length = 0
|
512
|
+
# The node to lock until for swa radix tree lock ref
|
513
|
+
self.swa_uuid_for_lock: Optional[int] = None
|
530
514
|
|
531
515
|
# Whether or not if it is chunked. It increments whenever
|
532
516
|
# it is chunked, and decrement whenever chunked request is
|
@@ -731,6 +715,14 @@ class Req:
|
|
731
715
|
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
732
716
|
return
|
733
717
|
|
718
|
+
if last_token_id > self.vocab_size or last_token_id < 0:
|
719
|
+
if self.sampling_params.stop_token_ids:
|
720
|
+
self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
|
721
|
+
if self.eos_token_ids:
|
722
|
+
self.output_ids[-1] = next(iter(self.eos_token_ids))
|
723
|
+
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
|
724
|
+
return
|
725
|
+
|
734
726
|
# Check stop strings
|
735
727
|
if len(self.sampling_params.stop_strs) > 0:
|
736
728
|
tail_str = self.tokenizer.decode(
|
@@ -745,6 +737,7 @@ class Req:
|
|
745
737
|
def reset_for_retract(self):
|
746
738
|
self.prefix_indices = []
|
747
739
|
self.last_node = None
|
740
|
+
self.swa_uuid_for_lock = None
|
748
741
|
self.extend_input_len = 0
|
749
742
|
self.is_retracted = True
|
750
743
|
self.input_token_logprobs = None
|
@@ -813,6 +806,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
813
806
|
req_to_token_pool: ReqToTokenPool = None
|
814
807
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
|
815
808
|
tree_cache: BasePrefixCache = None
|
809
|
+
is_hybrid: bool = False
|
816
810
|
|
817
811
|
# Batch configs
|
818
812
|
model_config: ModelConfig = None
|
@@ -918,11 +912,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
918
912
|
):
|
919
913
|
return_logprob = any(req.return_logprob for req in reqs)
|
920
914
|
|
915
|
+
is_hybrid = False
|
916
|
+
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
|
917
|
+
assert isinstance(tree_cache, SWARadixCache) or isinstance(
|
918
|
+
tree_cache, SWAChunkCache
|
919
|
+
), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
|
920
|
+
is_hybrid = True
|
921
|
+
|
921
922
|
return cls(
|
922
923
|
reqs=reqs,
|
923
924
|
req_to_token_pool=req_to_token_pool,
|
924
925
|
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
|
925
926
|
tree_cache=tree_cache,
|
927
|
+
is_hybrid=is_hybrid,
|
926
928
|
model_config=model_config,
|
927
929
|
enable_overlap=enable_overlap,
|
928
930
|
return_logprob=return_logprob,
|
@@ -953,9 +955,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
953
955
|
return req_pool_indices
|
954
956
|
|
955
957
|
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
|
956
|
-
|
957
|
-
if self.tree_cache is not None:
|
958
|
-
self.tree_cache.evict(num_tokens)
|
958
|
+
self._evict_tree_cache_if_needed(num_tokens)
|
959
959
|
|
960
960
|
if backup_state:
|
961
961
|
state = self.token_to_kv_pool_allocator.backup_state()
|
@@ -966,7 +966,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
966
966
|
error_msg = (
|
967
967
|
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
968
968
|
f"Try to allocate {num_tokens} tokens.\n"
|
969
|
-
f"
|
969
|
+
f"{self._available_and_evictable_str()}"
|
970
970
|
)
|
971
971
|
logger.error(error_msg)
|
972
972
|
if self.tree_cache is not None:
|
@@ -986,16 +986,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
986
986
|
extend_num_tokens: int,
|
987
987
|
backup_state: bool = False,
|
988
988
|
):
|
989
|
-
|
990
|
-
|
991
|
-
< extend_num_tokens
|
989
|
+
num_tokens = (
|
990
|
+
extend_num_tokens
|
992
991
|
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
993
|
-
)
|
994
|
-
|
995
|
-
self.tree_cache.evict(
|
996
|
-
extend_num_tokens
|
997
|
-
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
998
|
-
)
|
992
|
+
)
|
993
|
+
self._evict_tree_cache_if_needed(num_tokens)
|
999
994
|
|
1000
995
|
if backup_state:
|
1001
996
|
state = self.token_to_kv_pool_allocator.backup_state()
|
@@ -1007,9 +1002,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1007
1002
|
error_msg = (
|
1008
1003
|
f"Prefill out of memory. Try to lower your batch size.\n"
|
1009
1004
|
f"Try to allocate {extend_num_tokens} tokens.\n"
|
1010
|
-
f"
|
1011
|
-
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
1012
|
-
f"{self.tree_cache.evictable_size()=}\n"
|
1005
|
+
f"{self._available_and_evictable_str()}"
|
1013
1006
|
)
|
1014
1007
|
logger.error(error_msg)
|
1015
1008
|
raise RuntimeError(error_msg)
|
@@ -1025,14 +1018,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1025
1018
|
last_loc: torch.Tensor,
|
1026
1019
|
backup_state: bool = False,
|
1027
1020
|
):
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
1032
|
-
):
|
1033
|
-
self.tree_cache.evict(
|
1034
|
-
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
1035
|
-
)
|
1021
|
+
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
1022
|
+
|
1023
|
+
self._evict_tree_cache_if_needed(num_tokens)
|
1036
1024
|
|
1037
1025
|
if backup_state:
|
1038
1026
|
state = self.token_to_kv_pool_allocator.backup_state()
|
@@ -1042,9 +1030,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1042
1030
|
error_msg = (
|
1043
1031
|
f"Decode out of memory. Try to lower your batch size.\n"
|
1044
1032
|
f"Try to allocate {len(seq_lens)} tokens.\n"
|
1045
|
-
f"
|
1046
|
-
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
1047
|
-
f"{self.tree_cache.evictable_size()=}\n"
|
1033
|
+
f"{self._available_and_evictable_str()}"
|
1048
1034
|
)
|
1049
1035
|
logger.error(error_msg)
|
1050
1036
|
raise RuntimeError(error_msg)
|
@@ -1181,7 +1167,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1181
1167
|
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
|
1182
1168
|
)
|
1183
1169
|
if isinstance(self.tree_cache, SWAChunkCache):
|
1184
|
-
self.tree_cache.
|
1170
|
+
self.tree_cache.evict_swa(
|
1185
1171
|
req, pre_len, self.model_config.attention_chunk_size
|
1186
1172
|
)
|
1187
1173
|
|
@@ -1278,11 +1264,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1278
1264
|
if mm_input is None:
|
1279
1265
|
continue
|
1280
1266
|
for mm_item in mm_input.mm_items:
|
1281
|
-
pixel_values = getattr(mm_item, "
|
1267
|
+
pixel_values = getattr(mm_item, "feature", None)
|
1282
1268
|
if isinstance(pixel_values, torch.Tensor):
|
1283
|
-
mm_item.
|
1284
|
-
self.device, non_blocking=True
|
1285
|
-
)
|
1269
|
+
mm_item.feature = pixel_values.to(self.device, non_blocking=True)
|
1286
1270
|
self.multimodal_inputs = multimodal_inputs
|
1287
1271
|
self.token_type_ids = token_type_ids_tensor
|
1288
1272
|
self.seq_lens_sum = sum(seq_lens)
|
@@ -1328,6 +1312,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1328
1312
|
self.model_config.vocab_size,
|
1329
1313
|
)
|
1330
1314
|
|
1315
|
+
def prepare_for_split_prefill(self):
|
1316
|
+
self.prepare_for_extend()
|
1317
|
+
# For split prefill, we need to set the forward mode to SPLIT_PREFILL
|
1318
|
+
self.forward_mode = ForwardMode.SPLIT_PREFILL
|
1319
|
+
|
1331
1320
|
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
1332
1321
|
self.forward_mode = ForwardMode.MIXED
|
1333
1322
|
running_bs = running_batch.batch_size()
|
@@ -1371,17 +1360,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1371
1360
|
)
|
1372
1361
|
|
1373
1362
|
def check_decode_mem(self, buf_multiplier=1):
|
1374
|
-
|
1363
|
+
num_tokens = (
|
1375
1364
|
self.new_page_count_next_decode()
|
1376
1365
|
* buf_multiplier
|
1377
1366
|
* self.token_to_kv_pool_allocator.page_size
|
1378
1367
|
)
|
1379
|
-
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
|
1380
|
-
return True
|
1381
1368
|
|
1382
|
-
self.
|
1383
|
-
|
1384
|
-
return self.token_to_kv_pool_allocator.available_size() >= tokens_required
|
1369
|
+
self._evict_tree_cache_if_needed(num_tokens)
|
1370
|
+
return self._is_available_size_sufficient(num_tokens)
|
1385
1371
|
|
1386
1372
|
def retract_decode(self, server_args: ServerArgs):
|
1387
1373
|
"""Retract the decoding requests when there is not enough memory."""
|
@@ -1414,19 +1400,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1414
1400
|
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
1415
1401
|
)
|
1416
1402
|
|
1403
|
+
def _get_available_size():
|
1404
|
+
if self.is_hybrid:
|
1405
|
+
return min(
|
1406
|
+
self.token_to_kv_pool_allocator.full_available_size(),
|
1407
|
+
self.token_to_kv_pool_allocator.swa_available_size(),
|
1408
|
+
)
|
1409
|
+
else:
|
1410
|
+
return self.token_to_kv_pool_allocator.available_size()
|
1411
|
+
|
1417
1412
|
retracted_reqs = []
|
1418
1413
|
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
1419
1414
|
first_iter = True
|
1420
1415
|
while (
|
1421
|
-
|
1422
|
-
< get_required_tokens(len(sorted_indices))
|
1416
|
+
_get_available_size() < get_required_tokens(len(sorted_indices))
|
1423
1417
|
or first_iter
|
1424
1418
|
):
|
1425
1419
|
if len(sorted_indices) == 1:
|
1426
1420
|
# Corner case: only one request left
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1421
|
+
if self.is_hybrid:
|
1422
|
+
full_available_size = (
|
1423
|
+
self.token_to_kv_pool_allocator.full_available_size()
|
1424
|
+
)
|
1425
|
+
swa_available_size = (
|
1426
|
+
self.token_to_kv_pool_allocator.swa_available_size()
|
1427
|
+
)
|
1428
|
+
assert (
|
1429
|
+
full_available_size > 0 and swa_available_size > 0
|
1430
|
+
), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
|
1431
|
+
else:
|
1432
|
+
assert (
|
1433
|
+
self.token_to_kv_pool_allocator.available_size() > 0
|
1434
|
+
), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
|
1430
1435
|
break
|
1431
1436
|
|
1432
1437
|
first_iter = False
|
@@ -1458,15 +1463,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1458
1463
|
self.req_to_token_pool.free(req.req_pool_idx)
|
1459
1464
|
|
1460
1465
|
# release the last node
|
1461
|
-
self.
|
1466
|
+
if self.is_hybrid:
|
1467
|
+
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
1468
|
+
else:
|
1469
|
+
self.tree_cache.dec_lock_ref(req.last_node)
|
1462
1470
|
|
1463
1471
|
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
1464
|
-
|
1465
|
-
|
1466
|
-
- self.token_to_kv_pool_allocator.available_size()
|
1467
|
-
)
|
1468
|
-
residual_size = max(0, residual_size)
|
1469
|
-
self.tree_cache.evict(residual_size)
|
1472
|
+
num_tokens = len(sorted_indices) * global_config.retract_decode_steps
|
1473
|
+
self._evict_tree_cache_if_needed(num_tokens)
|
1470
1474
|
|
1471
1475
|
req.reset_for_retract()
|
1472
1476
|
|
@@ -1559,7 +1563,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1559
1563
|
# free memory
|
1560
1564
|
if isinstance(self.tree_cache, SWAChunkCache):
|
1561
1565
|
for req in self.reqs:
|
1562
|
-
self.tree_cache.
|
1566
|
+
self.tree_cache.evict_swa(
|
1563
1567
|
req, req.seqlen - 1, self.model_config.attention_chunk_size
|
1564
1568
|
)
|
1565
1569
|
|
@@ -1778,6 +1782,53 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1778
1782
|
is_extend_in_batch=self.is_extend_in_batch,
|
1779
1783
|
)
|
1780
1784
|
|
1785
|
+
def _evict_tree_cache_if_needed(
|
1786
|
+
self,
|
1787
|
+
num_tokens: int,
|
1788
|
+
) -> None:
|
1789
|
+
if isinstance(self.tree_cache, SWAChunkCache):
|
1790
|
+
return
|
1791
|
+
|
1792
|
+
if self.is_hybrid:
|
1793
|
+
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
1794
|
+
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
1795
|
+
|
1796
|
+
if full_available_size < num_tokens or swa_available_size < num_tokens:
|
1797
|
+
if self.tree_cache is not None:
|
1798
|
+
full_num_tokens = max(0, num_tokens - full_available_size)
|
1799
|
+
swa_num_tokens = max(0, num_tokens - swa_available_size)
|
1800
|
+
self.tree_cache.evict(full_num_tokens, swa_num_tokens)
|
1801
|
+
else:
|
1802
|
+
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
1803
|
+
if self.tree_cache is not None:
|
1804
|
+
self.tree_cache.evict(num_tokens)
|
1805
|
+
|
1806
|
+
def _is_available_size_sufficient(self, num_tokens: int) -> bool:
|
1807
|
+
if self.is_hybrid:
|
1808
|
+
return (
|
1809
|
+
self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
|
1810
|
+
and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
|
1811
|
+
)
|
1812
|
+
else:
|
1813
|
+
return self.token_to_kv_pool_allocator.available_size() >= num_tokens
|
1814
|
+
|
1815
|
+
def _available_and_evictable_str(self) -> str:
|
1816
|
+
if self.is_hybrid:
|
1817
|
+
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
1818
|
+
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
1819
|
+
full_evictable_size = self.tree_cache.full_evictable_size()
|
1820
|
+
swa_evictable_size = self.tree_cache.swa_evictable_size()
|
1821
|
+
return (
|
1822
|
+
f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
|
1823
|
+
f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
|
1824
|
+
f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
|
1825
|
+
f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
|
1826
|
+
)
|
1827
|
+
else:
|
1828
|
+
available_size = self.token_to_kv_pool_allocator.available_size()
|
1829
|
+
evictable_size = self.tree_cache.evictable_size()
|
1830
|
+
return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
|
1831
|
+
|
1781
1832
|
def __str__(self):
|
1782
1833
|
return (
|
1783
1834
|
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
|
@@ -1839,7 +1890,7 @@ class ModelWorkerBatch:
|
|
1839
1890
|
sampling_info: SamplingBatchInfo
|
1840
1891
|
|
1841
1892
|
# The input Embeds
|
1842
|
-
input_embeds: Optional[torch.
|
1893
|
+
input_embeds: Optional[torch.Tensor] = None
|
1843
1894
|
|
1844
1895
|
# For corss-encoder model
|
1845
1896
|
token_type_ids: Optional[torch.Tensor] = None
|
@@ -1849,7 +1900,6 @@ class ModelWorkerBatch:
|
|
1849
1900
|
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
1850
1901
|
# If set, the output of the batch contains the hidden states of the run.
|
1851
1902
|
capture_hidden_mode: CaptureHiddenMode = None
|
1852
|
-
spec_num_draft_tokens: Optional[int] = None
|
1853
1903
|
hicache_consumer_index: int = 0
|
1854
1904
|
|
1855
1905
|
# Overlap event
|
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
|
|
25
25
|
import torch
|
26
26
|
|
27
27
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
28
|
+
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
28
29
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
29
30
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
30
31
|
|
@@ -311,21 +312,43 @@ class PrefillAdder:
|
|
311
312
|
]
|
312
313
|
)
|
313
314
|
|
315
|
+
self.is_hybrid = isinstance(
|
316
|
+
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
317
|
+
)
|
318
|
+
|
314
319
|
@property
|
315
320
|
def rem_total_tokens(self):
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
+
if self.is_hybrid:
|
322
|
+
available_and_evictable = min(
|
323
|
+
self.token_to_kv_pool_allocator.full_available_size()
|
324
|
+
+ self.tree_cache.full_evictable_size(),
|
325
|
+
self.token_to_kv_pool_allocator.swa_available_size()
|
326
|
+
+ self.tree_cache.swa_evictable_size(),
|
327
|
+
)
|
328
|
+
else:
|
329
|
+
available_and_evictable = (
|
330
|
+
self.token_to_kv_pool_allocator.available_size()
|
331
|
+
+ self.tree_cache.evictable_size()
|
332
|
+
)
|
333
|
+
|
334
|
+
return available_and_evictable - self.rem_total_token_offset
|
321
335
|
|
322
336
|
@property
|
323
337
|
def cur_rem_tokens(self):
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
338
|
+
if self.is_hybrid:
|
339
|
+
available_and_evictable = min(
|
340
|
+
self.token_to_kv_pool_allocator.full_available_size()
|
341
|
+
+ self.tree_cache.full_evictable_size(),
|
342
|
+
self.token_to_kv_pool_allocator.swa_available_size()
|
343
|
+
+ self.tree_cache.swa_evictable_size(),
|
344
|
+
)
|
345
|
+
else:
|
346
|
+
available_and_evictable = (
|
347
|
+
self.token_to_kv_pool_allocator.available_size()
|
348
|
+
+ self.tree_cache.evictable_size()
|
349
|
+
)
|
350
|
+
|
351
|
+
return available_and_evictable - self.cur_rem_token_offset
|
329
352
|
|
330
353
|
def ceil_paged_tokens(self, tokens: int) -> int:
|
331
354
|
return -(-tokens // self.page_size) * self.page_size
|
@@ -376,11 +399,18 @@ class PrefillAdder:
|
|
376
399
|
|
377
400
|
@contextmanager
|
378
401
|
def _lock_node(self, last_node: TreeNode):
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
402
|
+
if self.is_hybrid:
|
403
|
+
try:
|
404
|
+
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node)
|
405
|
+
yield None
|
406
|
+
finally:
|
407
|
+
self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock)
|
408
|
+
else:
|
409
|
+
try:
|
410
|
+
self.tree_cache.inc_lock_ref(last_node)
|
411
|
+
yield None
|
412
|
+
finally:
|
413
|
+
self.tree_cache.dec_lock_ref(last_node)
|
384
414
|
|
385
415
|
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
|
386
416
|
# Early exit if no enough tokens for the input tokens
|
@@ -422,16 +452,19 @@ class PrefillAdder:
|
|
422
452
|
else:
|
423
453
|
add_req_state(req, insert_sort=True)
|
424
454
|
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
455
|
+
if not self.is_hybrid:
|
456
|
+
# Skip this logic for swa. The SWA has different memory management, and
|
457
|
+
# this mechanism is underestimating the memory usage.
|
458
|
+
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
|
459
|
+
tokens_freed = 0
|
460
|
+
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
461
|
+
# tokens_left gives a reservative calculation as the last token is not stored
|
462
|
+
bs = len(self.req_states) - i
|
463
|
+
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
|
464
|
+
# reserve tokens for corner cases
|
465
|
+
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
|
466
|
+
return AddReqResult.NO_TOKEN
|
467
|
+
tokens_freed += tokens_occupied
|
435
468
|
|
436
469
|
if (
|
437
470
|
self.rem_chunk_tokens is None # chunked prefill is disabled
|
@@ -499,7 +532,11 @@ class PrefillAdder:
|
|
499
532
|
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
500
533
|
# Non-chunked prefill
|
501
534
|
self.can_run_list.append(req)
|
502
|
-
self.
|
535
|
+
if self.is_hybrid:
|
536
|
+
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
|
537
|
+
req.swa_uuid_for_lock = swa_uuid_for_lock
|
538
|
+
else:
|
539
|
+
self.tree_cache.inc_lock_ref(req.last_node)
|
503
540
|
self._update_prefill_budget(
|
504
541
|
prefix_len,
|
505
542
|
input_tokens,
|
@@ -520,7 +557,11 @@ class PrefillAdder:
|
|
520
557
|
|
521
558
|
self.can_run_list.append(req)
|
522
559
|
self.new_chunked_req = req
|
523
|
-
self.
|
560
|
+
if self.is_hybrid:
|
561
|
+
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
|
562
|
+
req.swa_uuid_for_lock = swa_uuid_for_lock
|
563
|
+
else:
|
564
|
+
self.tree_cache.inc_lock_ref(req.last_node)
|
524
565
|
self._update_prefill_budget(prefix_len, trunc_len, 0)
|
525
566
|
|
526
567
|
return self.budget_state()
|