sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/mm_utils.py
CHANGED
@@ -76,7 +76,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        print(f"{mm_inputs.mm_items=}")
         data_token_pairs = self.data_token_id_pairs
         mm_inputs.data_offsets = []
         if data_token_pairs is None:
@@ -222,17 +221,17 @@ def _get_precomputed_embedding(
     items: List[MultimodalDataItem],
 ) -> Optional[torch.Tensor]:
     """
-    If all items have precomputed_features, return their concatenation.
-    If some but not all have precomputed_features, raise NotImplementedError.
-    If none have precomputed_features, return None.
+    If all items have precomputed_embeddings, return their concatenation.
+    If some but not all have precomputed_embeddings, raise NotImplementedError.
+    If none have precomputed_embeddings, return None.
     """
-    precomputed_features = [item.precomputed_features for item in items]
-    if any(feature is not None for feature in precomputed_features):
-        if not all(feature is not None for feature in precomputed_features):
+    precomputed_embeddings = [item.precomputed_embeddings for item in items]
+    if any(feature is not None for feature in precomputed_embeddings):
+        if not all(feature is not None for feature in precomputed_embeddings):
             raise NotImplementedError(
                 "MM inputs where only some items are precomputed."
             )
-        result = torch.concat(precomputed_features)
+        result = torch.concat(precomputed_embeddings)
         # some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
         result = result.reshape(-1, result.shape[-1])
         return result
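The rewrite above collapses several per-modality fields into a single `precomputed_embeddings` attribute and gates on it with an all-or-nothing check. A minimal standalone sketch of that gating pattern, using a hypothetical `Item` stand-in rather than the real `MultimodalDataItem`:

```python
# Standalone sketch of the all-or-nothing gating used above; `Item` is a
# stand-in for MultimodalDataItem, not the real sglang class.
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class Item:
    precomputed_embeddings: Optional[torch.Tensor] = None


def get_precomputed_embedding(items: List[Item]) -> Optional[torch.Tensor]:
    embeddings = [it.precomputed_embeddings for it in items]
    if any(e is not None for e in embeddings):
        # Mixing precomputed and raw items in one batch is not supported.
        if not all(e is not None for e in embeddings):
            raise NotImplementedError("only some items are precomputed")
        result = torch.concat(embeddings)
        # Collapse a possible 3-dim (batch, seq, hidden) layout to 2-dim.
        return result.reshape(-1, result.shape[-1])
    return None


print(get_precomputed_embedding([Item(torch.ones(2, 4)), Item(torch.ones(3, 4))]).shape)
# torch.Size([5, 4])
```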
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -52,10 +52,14 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    SWATokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -197,45 +201,41 @@ class MultimodalDataItem:
     For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
     One for images and one for audio.
 
-    We put the common fields first and the model-specific fields
+    We put the common fields first and the model-specific fields in model_specific_data.
     """
 
     modality: Modality
     hash: int = None
     pad_value: int = None
-    image_sizes: Tuple[int, int] = None
     offsets: Optional[list] = None
+    # the raw features returned by processor, e.g. pixel_values or audio_features
+    feature: Union[torch.Tensor, np.ndarray] = None
 
-    # the
-
-    pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
-    audio_features: Union[torch.Tensor, np.ndarray] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
-
-    # For qwen-vl
-    image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
-
-    # For deepseek-vl
-    image_emb_mask: Optional[torch.Tensor] = None
-    image_spatial_crop: Optional[torch.Tensor] = None
+    # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
+    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
 
-    #
-
-    tgt_size: Tuple[int, int] = None
+    # Model-specific data stored in a dictionary
+    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)
 
-
-
-
+    def __getattr__(self, name: str):
+        if (
+            "model_specific_data" in self.__dict__
+            and name in self.__dict__["model_specific_data"]
+        ):
+            return self.__dict__["model_specific_data"][name]
+        else:
+            raise AttributeError(
+                f"'{self.__class__.__name__}' object has no attribute '{name}'"
+            )
 
-
-
+    def __setitem__(self, key: str, value: Any):
+        if key in self.__dict__:
+            self.__dict__[key] = value
+        else:
+            self.model_specific_data[key] = value
 
-
-
-    input_features_mask: Optional[torch.Tensor] = None
+    def set(self, key: str, value: Any):
+        self.__setitem__(key, value)
 
     @staticmethod
     def is_empty_list(l):
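The refactor above replaces the long tail of model-specific dataclass fields with one `model_specific_data` dict, reachable transparently through `__getattr__`/`__setitem__`. A self-contained sketch of that pattern (the `ItemSketch` class is illustrative, not the sglang type):

```python
# Minimal sketch of the catch-all-attribute pattern introduced above: unknown
# fields land in a dict, while attribute reads fall back to it transparently.
import dataclasses
from typing import Any


@dataclasses.dataclass
class ItemSketch:
    modality: str
    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __getattr__(self, name: str):
        # Only called when normal attribute lookup fails, so declared
        # dataclass fields are never shadowed by the dict.
        if "model_specific_data" in self.__dict__ and name in self.__dict__["model_specific_data"]:
            return self.__dict__["model_specific_data"][name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __setitem__(self, key: str, value: Any):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self.model_specific_data[key] = value


item = ItemSketch(modality="image")
item["image_grid_thw"] = (1, 2, 2)   # stored in model_specific_data
print(item.image_grid_thw)           # (1, 2, 2), resolved via __getattr__
```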
@@ -250,18 +250,11 @@ class MultimodalDataItem:
         from sglang.srt.managers.mm_utils import hash_feature
 
         if self.hash is None:
-            if self.
-
-            elif self.is_audio():
-                if self.audio_features is not None:
-                    self.hash = hash_feature(self.audio_features)
-                elif self.input_features is not None:
-                    self.hash = hash_feature(self.input_features)
-            elif self.is_video():
-                self.hash = hash_feature(self.pixel_values_videos)
+            if self.feature is not None:
+                hashed_feature = self.feature
             else:
-
-
+                hashed_feature = self.precomputed_embeddings
+            self.hash = hash_feature(hashed_feature)
         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)
 
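After the consolidation above, `set_pad_value` hashes whichever payload exists (`feature` or `precomputed_embeddings`) and folds the hash into a pad token value. A rough sketch of that derivation; `hash_feature` here is a stand-in for sglang's helper, not its actual implementation:

```python
# Sketch of the pad-value derivation above. hash_feature is a stand-in for
# sglang's helper; any stable hash of the tensor bytes works for the demo.
import hashlib

import torch


def hash_feature(t: torch.Tensor) -> int:
    digest = hashlib.sha256(t.cpu().numpy().tobytes()).digest()
    return int.from_bytes(digest[:8], "little")


feature = torch.arange(16, dtype=torch.float32)
h = hash_feature(feature)
# The pad value must fit comfortably in the token-id space, hence the modulo.
pad_value = h % (1 << 30)
print(h, pad_value)
```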
@@ -269,25 +262,13 @@ class MultimodalDataItem:
         return self.modality == modality
 
     def is_audio(self):
-        return (
-            self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.audio_features)
-            or not MultimodalDataItem.is_empty_list(self.input_features)
-        )
+        return self.modality == Modality.AUDIO
 
     def is_image(self):
-        return (
-            self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
-        ) and (
-            self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values)
-        )
+        return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]
 
     def is_video(self):
-        return (
-            self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
-        )
+        return self.modality == Modality.VIDEO
 
     def is_valid(self) -> bool:
         return self.is_image() or self.is_video() or self.is_audio()
@@ -307,9 +288,8 @@ class MultimodalDataItem:
         return ret
 
     def merge(self, other):
-        self.
-        self.
-        self.image_offsets += other.image_offsets
+        self.feature += other.feature
+        self.offsets += other.offsets
         self.hash = hash((self.hash, other.hash))
         self.set_pad_value()
 
@@ -350,7 +330,6 @@ class MultimodalInputs:
 
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
-
         for item in ret.mm_items:
             item.set_pad_value()
 
@@ -527,6 +506,8 @@ class Req:
         self.last_node: Any = None
         self.last_host_node: Any = None
         self.host_hit_length = 0
+        # The node to lock until for swa radix tree lock ref
+        self.swa_uuid_for_lock: Optional[int] = None
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -745,6 +726,7 @@ class Req:
     def reset_for_retract(self):
         self.prefix_indices = []
         self.last_node = None
+        self.swa_uuid_for_lock = None
         self.extend_input_len = 0
         self.is_retracted = True
         self.input_token_logprobs = None
@@ -813,6 +795,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     req_to_token_pool: ReqToTokenPool = None
     token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None
+    is_hybrid: bool = False
 
     # Batch configs
     model_config: ModelConfig = None
@@ -918,11 +901,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     ):
         return_logprob = any(req.return_logprob for req in reqs)
 
+        is_hybrid = False
+        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
+            assert isinstance(tree_cache, SWARadixCache) or isinstance(
+                tree_cache, SWAChunkCache
+            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
+            is_hybrid = True
+
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool_allocator=token_to_kv_pool_allocator,
             tree_cache=tree_cache,
+            is_hybrid=is_hybrid,
             model_config=model_config,
             enable_overlap=enable_overlap,
             return_logprob=return_logprob,
@@ -953,9 +944,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         return req_pool_indices
 
     def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
-
-        if self.tree_cache is not None:
-            self.tree_cache.evict(num_tokens)
+        self._evict_tree_cache_if_needed(num_tokens)
 
         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
@@ -966,7 +955,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             error_msg = (
                 f"{phase_str} out of memory. Try to lower your batch size.\n"
                 f"Try to allocate {num_tokens} tokens.\n"
-                f"
+                f"{self._available_and_evictable_str()}"
             )
             logger.error(error_msg)
             if self.tree_cache is not None:
@@ -986,16 +975,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_num_tokens: int,
         backup_state: bool = False,
     ):
-        if (
-            self.token_to_kv_pool_allocator.available_size()
-            < extend_num_tokens
+        num_tokens = (
+            extend_num_tokens
             + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-        ):
-
-            self.tree_cache.evict(
-                extend_num_tokens
-                + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
-            )
+        )
+        self._evict_tree_cache_if_needed(num_tokens)
 
         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
@@ -1007,9 +991,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             error_msg = (
                 f"Prefill out of memory. Try to lower your batch size.\n"
                 f"Try to allocate {extend_num_tokens} tokens.\n"
-                f"
-                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
-                f"{self.tree_cache.evictable_size()=}\n"
+                f"{self._available_and_evictable_str()}"
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
@@ -1025,14 +1007,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
-
-        if (
-            self.token_to_kv_pool_allocator.available_size()
-            < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-        ):
-            self.tree_cache.evict(
-                len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
-            )
+        num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+
+        self._evict_tree_cache_if_needed(num_tokens)
 
         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
@@ -1042,9 +1019,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
                 f"Try to allocate {len(seq_lens)} tokens.\n"
-                f"
-                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
-                f"{self.tree_cache.evictable_size()=}\n"
+                f"{self._available_and_evictable_str()}"
            )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
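The three allocation paths above now delegate their out-of-memory reporting to `_available_and_evictable_str`, whose output leans on the f-string `=` specifier (Python 3.8+). A quick illustration with made-up numbers:

```python
# The error strings above use the f-string "=" specifier (Python 3.8+), which
# prints both the expression text and its value.
available_size = 1024
evictable_size = 256
msg = f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})"
print(msg)
# Available tokens: 1280 (available_size=1024 + evictable_size=256)
```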
@@ -1181,7 +1156,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
         )
         if isinstance(self.tree_cache, SWAChunkCache):
-            self.tree_cache.
+            self.tree_cache.evict_swa(
                 req, pre_len, self.model_config.attention_chunk_size
             )
 
@@ -1278,11 +1253,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if mm_input is None:
                 continue
             for mm_item in mm_input.mm_items:
-                pixel_values = getattr(mm_item, "pixel_values", None)
+                pixel_values = getattr(mm_item, "feature", None)
                 if isinstance(pixel_values, torch.Tensor):
-                    mm_item.pixel_values = pixel_values.to(
-                        self.device, non_blocking=True
-                    )
+                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)
         self.multimodal_inputs = multimodal_inputs
         self.token_type_ids = token_type_ids_tensor
         self.seq_lens_sum = sum(seq_lens)
@@ -1328,6 +1301,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.model_config.vocab_size,
         )
 
+    def prepare_for_split_prefill(self):
+        self.prepare_for_extend()
+        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
+        self.forward_mode = ForwardMode.SPLIT_PREFILL
+
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
         running_bs = running_batch.batch_size()
@@ -1371,17 +1349,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )
 
     def check_decode_mem(self, buf_multiplier=1):
-        tokens_required = (
+        num_tokens = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )
-        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
-            return True
-
-        self.tree_cache.evict(tokens_required)
 
-        return self.token_to_kv_pool_allocator.available_size() >= tokens_required
+        self._evict_tree_cache_if_needed(num_tokens)
+        return self._is_available_size_sufficient(num_tokens)
 
     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
@@ -1414,19 +1389,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
         )
 
+        def _get_available_size():
+            if self.is_hybrid:
+                return min(
+                    self.token_to_kv_pool_allocator.full_available_size(),
+                    self.token_to_kv_pool_allocator.swa_available_size(),
+                )
+            else:
+                return self.token_to_kv_pool_allocator.available_size()
+
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
         while (
-            self.token_to_kv_pool_allocator.available_size()
-            < get_required_tokens(len(sorted_indices))
+            _get_available_size() < get_required_tokens(len(sorted_indices))
             or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
-                assert (
-                    self.token_to_kv_pool_allocator.available_size() > 0
-                ), "No space left for only one request"
+                if self.is_hybrid:
+                    full_available_size = (
+                        self.token_to_kv_pool_allocator.full_available_size()
+                    )
+                    swa_available_size = (
+                        self.token_to_kv_pool_allocator.swa_available_size()
+                    )
+                    assert (
+                        full_available_size > 0 and swa_available_size > 0
+                    ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
+                else:
+                    assert (
+                        self.token_to_kv_pool_allocator.available_size() > 0
+                    ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                 break
 
             first_iter = False
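The loop above keeps retracting requests until the allocator can cover the per-step decode budget, with a corner-case assert once only one request remains. A much-simplified standalone skeleton of that shape; `Allocator`, the constant, and the eviction order are illustrative assumptions, not sglang's actual policy:

```python
# Simplified skeleton of the retraction loop above: evict whole requests until
# the allocator can cover the per-step decode budget. Stand-in types only.
RETRACT_DECODE_STEPS = 20


class Allocator:
    def __init__(self, size: int) -> None:
        self.free_tokens = size

    def available_size(self) -> int:
        return self.free_tokens

    def release(self, n: int) -> None:
        self.free_tokens += n


def retract_until_feasible(allocator: Allocator, seq_lens: list[int]) -> list[int]:
    retracted = []
    # For this sketch, retract longest-first so each eviction frees the most KV cache.
    order = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
    alive = set(order)
    for idx in order:
        required = len(alive) * RETRACT_DECODE_STEPS
        if allocator.available_size() >= required or len(alive) == 1:
            break
        alive.remove(idx)
        retracted.append(idx)
        allocator.release(seq_lens[idx])  # the request's KV slots become free
    return retracted


print(retract_until_feasible(Allocator(size=10), [300, 120, 80]))  # [0]
```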
@@ -1458,15 +1452,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_to_token_pool.free(req.req_pool_idx)
 
             # release the last node
-            self.tree_cache.dec_lock_ref(req.last_node)
+            if self.is_hybrid:
+                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
+            else:
+                self.tree_cache.dec_lock_ref(req.last_node)
 
             # NOTE(lsyin): we should use the newly evictable memory instantly.
-            residual_size = (
-                len(sorted_indices) * global_config.retract_decode_steps
-                - self.token_to_kv_pool_allocator.available_size()
-            )
-            residual_size = max(0, residual_size)
-            self.tree_cache.evict(residual_size)
+            num_tokens = len(sorted_indices) * global_config.retract_decode_steps
+            self._evict_tree_cache_if_needed(num_tokens)
 
             req.reset_for_retract()
 
@@ -1559,7 +1552,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # free memory
         if isinstance(self.tree_cache, SWAChunkCache):
             for req in self.reqs:
-                self.tree_cache.
+                self.tree_cache.evict_swa(
                     req, req.seqlen - 1, self.model_config.attention_chunk_size
                 )
 
@@ -1778,6 +1771,53 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             is_extend_in_batch=self.is_extend_in_batch,
         )
 
+    def _evict_tree_cache_if_needed(
+        self,
+        num_tokens: int,
+    ) -> None:
+        if isinstance(self.tree_cache, SWAChunkCache):
+            return
+
+        if self.is_hybrid:
+            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
+            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
+
+            if full_available_size < num_tokens or swa_available_size < num_tokens:
+                if self.tree_cache is not None:
+                    full_num_tokens = max(0, num_tokens - full_available_size)
+                    swa_num_tokens = max(0, num_tokens - swa_available_size)
+                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
+        else:
+            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
+                if self.tree_cache is not None:
+                    self.tree_cache.evict(num_tokens)
+
+    def _is_available_size_sufficient(self, num_tokens: int) -> bool:
+        if self.is_hybrid:
+            return (
+                self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
+                and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
+            )
+        else:
+            return self.token_to_kv_pool_allocator.available_size() >= num_tokens
+
+    def _available_and_evictable_str(self) -> str:
+        if self.is_hybrid:
+            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
+            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
+            full_evictable_size = self.tree_cache.full_evictable_size()
+            swa_evictable_size = self.tree_cache.swa_evictable_size()
+            return (
+                f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
+                f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
+                f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
+                f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
+            )
+        else:
+            available_size = self.token_to_kv_pool_allocator.available_size()
+            evictable_size = self.tree_cache.evictable_size()
+            return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
+
     def __str__(self):
         return (
             f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
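The helpers above encode the hybrid-cache rule: with a sliding-window (SWA) allocator, a token budget is feasible only when both the full-attention pool and the SWA pool can cover it, and each pool evicts only its own shortfall. A condensed sketch of that rule with plain integers standing in for the real allocator and tree cache:

```python
# Condensed sketch of the hybrid budget rule introduced above: a request fits
# only if BOTH pools can hold it, and eviction is sized per pool.
from dataclasses import dataclass


@dataclass
class HybridBudget:
    full_available: int
    swa_available: int

    def is_sufficient(self, num_tokens: int) -> bool:
        # min() of the two pools is the binding constraint.
        return min(self.full_available, self.swa_available) >= num_tokens

    def eviction_targets(self, num_tokens: int) -> tuple[int, int]:
        # Each pool evicts only its own shortfall.
        full_deficit = max(0, num_tokens - self.full_available)
        swa_deficit = max(0, num_tokens - self.swa_available)
        return full_deficit, swa_deficit


budget = HybridBudget(full_available=4096, swa_available=1024)
print(budget.is_sufficient(2048))     # False: the SWA pool is the bottleneck
print(budget.eviction_targets(2048))  # (0, 1024)
```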
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
 import torch
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
@@ -311,21 +312,43 @@ class PrefillAdder:
             ]
         )
 
+        self.is_hybrid = isinstance(
+            self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+        )
+
     @property
     def rem_total_tokens(self):
-        return (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
-            - self.rem_total_token_offset
-        )
+        if self.is_hybrid:
+            available_and_evictable = min(
+                self.token_to_kv_pool_allocator.full_available_size()
+                + self.tree_cache.full_evictable_size(),
+                self.token_to_kv_pool_allocator.swa_available_size()
+                + self.tree_cache.swa_evictable_size(),
+            )
+        else:
+            available_and_evictable = (
+                self.token_to_kv_pool_allocator.available_size()
+                + self.tree_cache.evictable_size()
+            )
+
+        return available_and_evictable - self.rem_total_token_offset
 
     @property
     def cur_rem_tokens(self):
-        return (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
-            - self.cur_rem_token_offset
-        )
+        if self.is_hybrid:
+            available_and_evictable = min(
+                self.token_to_kv_pool_allocator.full_available_size()
+                + self.tree_cache.full_evictable_size(),
+                self.token_to_kv_pool_allocator.swa_available_size()
+                + self.tree_cache.swa_evictable_size(),
+            )
+        else:
+            available_and_evictable = (
+                self.token_to_kv_pool_allocator.available_size()
+                + self.tree_cache.evictable_size()
+            )
+
+        return available_and_evictable - self.cur_rem_token_offset
 
     def ceil_paged_tokens(self, tokens: int) -> int:
         return -(-tokens // self.page_size) * self.page_size
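`ceil_paged_tokens` in the context lines above rounds a token count up to a whole page using negated floor division, which is exact integer ceiling division. A tiny worked example:

```python
# ceil_paged_tokens rounds up to a whole page with the negated-floor-division
# trick: -(-a // b) is ceil(a / b) for positive ints, with no floats involved.
def ceil_paged_tokens(tokens: int, page_size: int) -> int:
    return -(-tokens // page_size) * page_size


assert ceil_paged_tokens(100, 16) == 112  # 7 pages of 16
assert ceil_paged_tokens(96, 16) == 96    # already page-aligned
```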
@@ -376,11 +399,18 @@ class PrefillAdder:
 
     @contextmanager
     def _lock_node(self, last_node: TreeNode):
-        try:
-            self.tree_cache.inc_lock_ref(last_node)
-            yield None
-        finally:
-            self.tree_cache.dec_lock_ref(last_node)
+        if self.is_hybrid:
+            try:
+                swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node)
+                yield None
+            finally:
+                self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock)
+        else:
+            try:
+                self.tree_cache.inc_lock_ref(last_node)
+                yield None
+            finally:
+                self.tree_cache.dec_lock_ref(last_node)
 
     def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
         # Early exit if no enough tokens for the input tokens
@@ -422,16 +452,19 @@ class PrefillAdder:
         else:
             add_req_state(req, insert_sort=True)
 
-        cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
-        tokens_freed = 0
-        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-            # tokens_left gives a reservative calculation as the last token is not stored
-            bs = len(self.req_states) - i
-            min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
-            # reserve tokens for corner cases
-            if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
-                return AddReqResult.NO_TOKEN
-            tokens_freed += tokens_occupied
+        if not self.is_hybrid:
+            # Skip this logic for swa. The SWA has different memory management, and
+            # this mechanism is underestimating the memory usage.
+            cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
+            tokens_freed = 0
+            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                # tokens_left gives a reservative calculation as the last token is not stored
+                bs = len(self.req_states) - i
+                min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
+                # reserve tokens for corner cases
+                if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
+                    return AddReqResult.NO_TOKEN
+                tokens_freed += tokens_occupied
 
         if (
             self.rem_chunk_tokens is None  # chunked prefill is disabled
@@ -499,7 +532,11 @@ class PrefillAdder:
         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
             # Non-chunked prefill
            self.can_run_list.append(req)
-            self.tree_cache.inc_lock_ref(req.last_node)
+            if self.is_hybrid:
+                swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
+                req.swa_uuid_for_lock = swa_uuid_for_lock
+            else:
+                self.tree_cache.inc_lock_ref(req.last_node)
             self._update_prefill_budget(
                 prefix_len,
                 input_tokens,
@@ -520,7 +557,11 @@ class PrefillAdder:
 
         self.can_run_list.append(req)
         self.new_chunked_req = req
-        self.tree_cache.inc_lock_ref(req.last_node)
+        if self.is_hybrid:
+            swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
+            req.swa_uuid_for_lock = swa_uuid_for_lock
+        else:
+            self.tree_cache.inc_lock_ref(req.last_node)
         self._update_prefill_budget(prefix_len, trunc_len, 0)
 
         return self.budget_state()
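Both `_lock_node` branches above follow the same shape: increment a lock reference on a tree node when entering the context, and guarantee the matching decrement on exit even if the admission logic raises. A generic sketch of that pattern; `RefLocked` is a stand-in for the radix-tree node, not sglang's class:

```python
# Generic sketch of the _lock_node pattern above: a @contextmanager that pins
# a node on entry and is guaranteed to unpin it on exit, even on exceptions.
from contextlib import contextmanager


class RefLocked:
    def __init__(self) -> None:
        self.lock_ref = 0

    def inc_lock_ref(self) -> None:
        self.lock_ref += 1

    def dec_lock_ref(self) -> None:
        assert self.lock_ref > 0
        self.lock_ref -= 1


@contextmanager
def lock_node(node: RefLocked):
    try:
        node.inc_lock_ref()
        yield None
    finally:
        node.dec_lock_ref()


node = RefLocked()
with lock_node(node):
    assert node.lock_ref == 1  # node cannot be evicted while locked
assert node.lock_ref == 0
```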