sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py
CHANGED
@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
|
|
13
13
|
ROW_PARALLELISM_LINEAR_LORA_NAMES,
|
14
14
|
LoRAType,
|
15
15
|
get_hidden_dim,
|
16
|
-
|
16
|
+
get_normalized_target_modules,
|
17
17
|
get_stacked_multiply,
|
18
|
-
|
18
|
+
get_target_module_name,
|
19
19
|
)
|
20
20
|
|
21
21
|
logger = logging.getLogger(__name__)
|
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
|
|
52
52
|
tp_size: int,
|
53
53
|
tp_rank: int,
|
54
54
|
max_lora_rank: int,
|
55
|
-
|
55
|
+
target_modules: Set[str],
|
56
56
|
base_model: torch.nn.Module,
|
57
57
|
):
|
58
58
|
self.base_hf_config: AutoConfig = base_hf_config
|
@@ -62,7 +62,7 @@ class LoRAMemoryPool:
|
|
62
62
|
self.tp_size: int = tp_size
|
63
63
|
self.tp_rank: int = tp_rank
|
64
64
|
self.max_lora_rank: int = max_lora_rank
|
65
|
-
self.
|
65
|
+
self.target_modules: Set[str] = target_modules
|
66
66
|
|
67
67
|
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
68
68
|
# A_buffer contains num_layer number of row-major tensors with shape
|
@@ -95,8 +95,8 @@ class LoRAMemoryPool:
|
|
95
95
|
"""
|
96
96
|
if config.r > self.max_lora_rank:
|
97
97
|
return False
|
98
|
-
|
99
|
-
return
|
98
|
+
target_module_names = get_normalized_target_modules(config.target_modules)
|
99
|
+
return target_module_names.issubset(self.target_modules)
|
100
100
|
|
101
101
|
if isinstance(config, LoRAConfig):
|
102
102
|
return _can_support(config)
|
@@ -139,10 +139,10 @@ class LoRAMemoryPool:
|
|
139
139
|
|
140
140
|
def init_buffer(
|
141
141
|
buffer: Dict[str, List[torch.Tensor]],
|
142
|
-
|
142
|
+
target_modules: Set[str],
|
143
143
|
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
|
144
144
|
):
|
145
|
-
for module_name in
|
145
|
+
for module_name in target_modules:
|
146
146
|
lora_shape = get_lora_shape_fn(
|
147
147
|
module_name, base_model, self.max_lora_rank
|
148
148
|
)
|
@@ -157,13 +157,13 @@ class LoRAMemoryPool:
|
|
157
157
|
|
158
158
|
init_buffer(
|
159
159
|
self.A_buffer,
|
160
|
-
self.
|
160
|
+
self.target_modules,
|
161
161
|
self.get_lora_A_shape,
|
162
162
|
)
|
163
163
|
|
164
164
|
init_buffer(
|
165
165
|
self.B_buffer,
|
166
|
-
self.
|
166
|
+
self.target_modules,
|
167
167
|
self.get_lora_B_shape,
|
168
168
|
)
|
169
169
|
|
@@ -242,32 +242,34 @@ class LoRAMemoryPool:
|
|
242
242
|
for layer_id in range(self.num_layer):
|
243
243
|
layer_weights = lora_adapter.layers[layer_id].weights
|
244
244
|
temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
|
245
|
-
|
245
|
+
target_module: None for target_module in self.A_buffer
|
246
246
|
}
|
247
247
|
temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
|
248
|
-
|
248
|
+
target_module: None for target_module in self.B_buffer
|
249
249
|
}
|
250
250
|
for name, weights in layer_weights.items():
|
251
|
-
|
251
|
+
target_module = get_target_module_name(name, self.target_modules)
|
252
252
|
if "lora_A" in name:
|
253
|
-
temp_A_buffer[
|
253
|
+
temp_A_buffer[target_module] = weights
|
254
254
|
else:
|
255
|
-
temp_B_buffer[
|
255
|
+
temp_B_buffer[target_module] = weights
|
256
256
|
|
257
257
|
if self.tp_size > 1:
|
258
258
|
cur_layer_modules = lora_modules[layer_id]
|
259
259
|
for module_name, module in cur_layer_modules.items():
|
260
|
-
|
260
|
+
target_module = get_target_module_name(
|
261
|
+
module_name, self.target_modules
|
262
|
+
)
|
261
263
|
|
262
|
-
if temp_A_buffer[
|
264
|
+
if temp_A_buffer[target_module] is None:
|
263
265
|
# Skip weight slicing if the weight is not present in the adapter
|
264
266
|
continue
|
265
267
|
|
266
|
-
temp_A_buffer[
|
267
|
-
temp_A_buffer[
|
268
|
+
temp_A_buffer[target_module] = module.slice_lora_a_weights(
|
269
|
+
temp_A_buffer[target_module], self.tp_rank
|
268
270
|
)
|
269
|
-
temp_B_buffer[
|
270
|
-
temp_B_buffer[
|
271
|
+
temp_B_buffer[target_module] = module.slice_lora_b_weights(
|
272
|
+
temp_B_buffer[target_module], self.tp_rank
|
271
273
|
)
|
272
274
|
|
273
275
|
for name, weights in temp_A_buffer.items():
|
@@ -282,12 +284,12 @@ class LoRAMemoryPool:
|
|
282
284
|
load_lora_weight_tensor(buffer_view, weights)
|
283
285
|
|
284
286
|
def get_tensor(
|
285
|
-
self,
|
287
|
+
self, target_module: str, layer_id: int, lora_type: LoRAType
|
286
288
|
) -> torch.Tensor:
|
287
289
|
if lora_type == LoRAType.LORA_A:
|
288
|
-
return self.A_buffer[
|
290
|
+
return self.A_buffer[target_module][layer_id]
|
289
291
|
|
290
|
-
return self.B_buffer[
|
292
|
+
return self.B_buffer[target_module][layer_id]
|
291
293
|
|
292
294
|
def get_buffer_id(self, lora_uid: str):
|
293
295
|
return self.uid_to_buffer_id[lora_uid]
|
sglang/srt/lora/utils.py
CHANGED
@@ -84,7 +84,7 @@ def get_hidden_dim(
|
|
84
84
|
raise NotImplementedError()
|
85
85
|
|
86
86
|
|
87
|
-
def
|
87
|
+
def get_normalized_target_modules(
|
88
88
|
target_modules: Iterable[str],
|
89
89
|
) -> set[str]:
|
90
90
|
"""
|
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(
|
|
100
100
|
|
101
101
|
result = set()
|
102
102
|
for name in target_modules:
|
103
|
-
|
104
|
-
result.add(
|
103
|
+
normalized_name = params_mapping.get(name, name)
|
104
|
+
result.add(normalized_name)
|
105
105
|
return result
|
106
106
|
|
107
107
|
|
@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
|
|
116
116
|
return stacked_rank[module_name] if module_name in stacked_rank else 1
|
117
117
|
|
118
118
|
|
119
|
-
def
|
120
|
-
target_name: str, lora_weight_names: Tuple[Set[str]]
|
121
|
-
) -> Optional[str]:
|
119
|
+
def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
|
122
120
|
"""
|
123
|
-
Get the
|
121
|
+
Get the target module name in target_modules that can match full_module_name.
|
124
122
|
|
125
|
-
If there is a
|
123
|
+
If there is a target module name in target_modules that can match full_module_name, return this name
|
126
124
|
Else raise ValueError.
|
127
125
|
"""
|
128
|
-
for
|
129
|
-
if
|
130
|
-
return
|
126
|
+
for target_module in target_modules:
|
127
|
+
if target_module in full_module_name:
|
128
|
+
return target_module
|
131
129
|
raise ValueError(
|
132
|
-
f"Cannot find
|
130
|
+
f"Cannot find target module name for {full_module_name} in {target_modules}"
|
133
131
|
)
|
134
132
|
|
135
133
|
|
@@ -26,6 +26,8 @@ if TYPE_CHECKING:
|
|
26
26
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
27
27
|
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
|
28
28
|
|
29
|
+
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
30
|
+
from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
|
29
31
|
|
30
32
|
logger = logging.getLogger(__name__)
|
31
33
|
|
@@ -238,13 +240,14 @@ class HiCacheController:
|
|
238
240
|
self.io_backend = io_backend
|
239
241
|
|
240
242
|
self.enable_storage = False
|
243
|
+
self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
|
241
244
|
# todo: move backend initialization to storage backend module
|
242
245
|
if storage_backend is not None:
|
243
246
|
self.storage_backend_type = storage_backend
|
244
247
|
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
|
245
248
|
|
246
249
|
if storage_backend == "file":
|
247
|
-
self.storage_backend = HiCacheFile()
|
250
|
+
self.storage_backend = HiCacheFile(is_mla=self.is_mla)
|
248
251
|
self.get_hash_str = get_hash_str
|
249
252
|
elif storage_backend == "nixl":
|
250
253
|
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
@@ -257,23 +260,26 @@ class HiCacheController:
|
|
257
260
|
get_hash_str_mooncake,
|
258
261
|
)
|
259
262
|
|
260
|
-
self.storage_backend = MooncakeStore()
|
263
|
+
self.storage_backend = MooncakeStore(is_mla=self.is_mla)
|
261
264
|
self.get_hash_str = get_hash_str_mooncake
|
262
265
|
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
263
266
|
assert self.mem_pool_host.layout == "page_first"
|
264
267
|
elif storage_backend == "hf3fs":
|
265
|
-
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
266
268
|
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
267
269
|
HiCacheHF3FS,
|
268
270
|
)
|
269
271
|
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
272
|
+
if self.mem_pool_host.layout == "page_first":
|
273
|
+
bytes_per_page = (
|
274
|
+
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
275
|
+
)
|
276
|
+
elif self.mem_pool_host.layout == "layer_first":
|
277
|
+
bytes_per_page = (
|
278
|
+
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
279
|
+
)
|
274
280
|
dtype = mem_pool_host.dtype
|
275
281
|
self.storage_backend = HiCacheHF3FS.from_env_config(
|
276
|
-
|
282
|
+
bytes_per_page, dtype
|
277
283
|
)
|
278
284
|
self.get_hash_str = get_hash_str
|
279
285
|
else:
|
@@ -296,6 +302,9 @@ class HiCacheController:
|
|
296
302
|
self.prefetch_tp_group = torch.distributed.new_group(
|
297
303
|
group_ranks, backend="gloo"
|
298
304
|
)
|
305
|
+
self.prefetch_io_tp_group = torch.distributed.new_group(
|
306
|
+
group_ranks, backend="gloo"
|
307
|
+
)
|
299
308
|
self.backup_tp_group = torch.distributed.new_group(
|
300
309
|
group_ranks, backend="gloo"
|
301
310
|
)
|
@@ -391,6 +400,15 @@ class HiCacheController:
|
|
391
400
|
self.prefetch_thread.start()
|
392
401
|
self.backup_thread.start()
|
393
402
|
|
403
|
+
@property
|
404
|
+
def backup_skip(self):
|
405
|
+
return (
|
406
|
+
self.is_mla
|
407
|
+
and get_tensor_model_parallel_rank() != 0
|
408
|
+
# todo: only support file and mooncake
|
409
|
+
and self.storage_backend_type in ["file", "mooncake"]
|
410
|
+
)
|
411
|
+
|
394
412
|
def write(
|
395
413
|
self,
|
396
414
|
device_indices: torch.Tensor,
|
@@ -552,13 +570,34 @@ class HiCacheController:
|
|
552
570
|
operation.mark_done()
|
553
571
|
return operation.completed_tokens, operation.hash_value
|
554
572
|
|
573
|
+
def zerocopy_page_transfer(self, operation, batch_size=8):
|
574
|
+
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
575
|
+
operation.hash_value, operation.host_indices
|
576
|
+
)
|
577
|
+
for i in range(0, len(hashes), batch_size):
|
578
|
+
page_hashes = hashes[i : i + batch_size]
|
579
|
+
page_dsts = dsts[i : i + batch_size]
|
580
|
+
page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
|
581
|
+
if page_data is None:
|
582
|
+
logger.warning(
|
583
|
+
f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
|
584
|
+
)
|
585
|
+
break
|
586
|
+
completed_tokens = operation.completed_tokens
|
587
|
+
if operation.increment(self.page_size * len(page_hashes)):
|
588
|
+
for i in range(len(page_hashes)):
|
589
|
+
completed_tokens += self.page_size
|
590
|
+
else:
|
591
|
+
break
|
592
|
+
|
555
593
|
def generic_page_transfer(self, operation, batch_size=8):
|
556
594
|
for i in range(0, len(operation.hash_value), batch_size):
|
557
595
|
page_hashes = operation.hash_value[i : i + batch_size]
|
558
596
|
# todo: zero copy
|
559
|
-
dummy_page_dst = [
|
560
|
-
|
561
|
-
|
597
|
+
dummy_page_dst = [
|
598
|
+
self.mem_pool_host.get_dummy_flat_data_page()
|
599
|
+
for _ in range(len(page_hashes))
|
600
|
+
]
|
562
601
|
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
|
563
602
|
if page_data is None:
|
564
603
|
logger.warning(
|
@@ -596,13 +635,16 @@ class HiCacheController:
|
|
596
635
|
if self.is_mooncake_backend():
|
597
636
|
self.mooncake_page_transfer(operation)
|
598
637
|
elif self.storage_backend_type == "hf3fs":
|
599
|
-
self.
|
638
|
+
if self.mem_pool_host.layout == "page_first":
|
639
|
+
self.zerocopy_page_transfer(operation, batch_size=128)
|
640
|
+
elif self.mem_pool_host.layout == "layer_first":
|
641
|
+
self.generic_page_transfer(operation, batch_size=128)
|
600
642
|
else:
|
601
643
|
self.generic_page_transfer(operation)
|
602
644
|
|
603
645
|
if self.tp_world_size > 1:
|
604
646
|
# to ensure all TP workers release the host memory at the same time
|
605
|
-
torch.distributed.barrier(group=self.
|
647
|
+
torch.distributed.barrier(group=self.prefetch_io_tp_group)
|
606
648
|
# operation terminated by controller, release pre-allocated memory
|
607
649
|
self.mem_pool_host.free(
|
608
650
|
operation.host_indices[operation.completed_tokens :]
|
@@ -713,6 +755,19 @@ class HiCacheController:
|
|
713
755
|
self.backup_queue.put(operation)
|
714
756
|
return operation.id
|
715
757
|
|
758
|
+
def zerocopy_page_backup(self, operation, batch_size=8):
|
759
|
+
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
760
|
+
operation.hash_value, operation.host_indices
|
761
|
+
)
|
762
|
+
for i in range(0, len(hashes), batch_size):
|
763
|
+
page_hashes = hashes[i : i + batch_size]
|
764
|
+
page_data = dsts[i : i + batch_size]
|
765
|
+
success = self.storage_backend.batch_set(page_hashes, page_data)
|
766
|
+
if not success:
|
767
|
+
logger.warning(f"Failed to write page {page_hashes} to storage.")
|
768
|
+
break
|
769
|
+
operation.completed_tokens += self.page_size * len(page_hashes)
|
770
|
+
|
716
771
|
def generic_page_backup(self, operation, batch_size=8):
|
717
772
|
for i in range(0, len(operation.hash_value), batch_size):
|
718
773
|
page_hashes = operation.hash_value[i : i + batch_size]
|
@@ -764,14 +819,20 @@ class HiCacheController:
|
|
764
819
|
if operation is None:
|
765
820
|
continue
|
766
821
|
|
767
|
-
if self.
|
768
|
-
self.
|
769
|
-
|
770
|
-
self.
|
822
|
+
if not self.backup_skip:
|
823
|
+
if self.is_mooncake_backend():
|
824
|
+
self.mooncake_page_backup(operation)
|
825
|
+
elif self.storage_backend_type == "hf3fs":
|
826
|
+
if self.mem_pool_host.layout == "page_first":
|
827
|
+
self.zerocopy_page_backup(operation, batch_size=128)
|
828
|
+
elif self.mem_pool_host.layout == "layer_first":
|
829
|
+
self.generic_page_backup(operation, batch_size=128)
|
830
|
+
else:
|
831
|
+
self.generic_page_backup(operation)
|
832
|
+
min_completed_tokens = operation.completed_tokens
|
771
833
|
else:
|
772
|
-
|
834
|
+
min_completed_tokens = len(operation.token_ids)
|
773
835
|
|
774
|
-
min_completed_tokens = operation.completed_tokens
|
775
836
|
if self.tp_world_size > 1:
|
776
837
|
completed_tokens_tensor = torch.tensor(
|
777
838
|
min_completed_tokens, dtype=torch.int
|
@@ -31,10 +31,12 @@ from sglang.srt.managers.io_struct import (
|
|
31
31
|
BatchMultimodalOut,
|
32
32
|
BatchStrOut,
|
33
33
|
BatchTokenIDOut,
|
34
|
+
FreezeGCReq,
|
34
35
|
)
|
35
36
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
36
37
|
from sglang.srt.utils import (
|
37
38
|
configure_logger,
|
39
|
+
freeze_gc,
|
38
40
|
get_zmq_socket,
|
39
41
|
kill_itself_when_parent_died,
|
40
42
|
)
|
@@ -100,6 +102,7 @@ class DetokenizerManager:
|
|
100
102
|
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
101
103
|
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
102
104
|
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
105
|
+
(FreezeGCReq, self.handle_freeze_gc_req),
|
103
106
|
]
|
104
107
|
)
|
105
108
|
|
@@ -108,7 +111,8 @@ class DetokenizerManager:
|
|
108
111
|
while True:
|
109
112
|
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
110
113
|
output = self._request_dispatcher(recv_obj)
|
111
|
-
|
114
|
+
if output is not None:
|
115
|
+
self.send_to_tokenizer.send_pyobj(output)
|
112
116
|
|
113
117
|
def trim_matched_stop(
|
114
118
|
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
@@ -216,7 +220,7 @@ class DetokenizerManager:
|
|
216
220
|
rids=recv_obj.rids,
|
217
221
|
finished_reasons=recv_obj.finished_reasons,
|
218
222
|
output_strs=output_strs,
|
219
|
-
output_ids=recv_obj.
|
223
|
+
output_ids=recv_obj.decode_ids,
|
220
224
|
prompt_tokens=recv_obj.prompt_tokens,
|
221
225
|
completion_tokens=recv_obj.completion_tokens,
|
222
226
|
cached_tokens=recv_obj.cached_tokens,
|
@@ -247,6 +251,10 @@ class DetokenizerManager:
|
|
247
251
|
cached_tokens=recv_obj.cached_tokens,
|
248
252
|
)
|
249
253
|
|
254
|
+
def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
|
255
|
+
freeze_gc("Detokenizer Manager")
|
256
|
+
return None
|
257
|
+
|
250
258
|
|
251
259
|
class LimitedCapacityDict(OrderedDict):
|
252
260
|
def __init__(self, capacity: int, *args, **kwargs):
|
sglang/srt/managers/io_struct.py
CHANGED
@@ -612,6 +612,8 @@ class EmbeddingReqInput:
|
|
612
612
|
|
613
613
|
if self.sampling_params is None:
|
614
614
|
self.sampling_params = [{}] * self.batch_size
|
615
|
+
elif isinstance(self.sampling_params, dict):
|
616
|
+
self.sampling_params = [self.sampling_params] * self.batch_size
|
615
617
|
for i in range(self.batch_size):
|
616
618
|
self.sampling_params[i]["max_new_tokens"] = 0
|
617
619
|
|
@@ -660,6 +662,8 @@ class TokenizedEmbeddingReqInput:
|
|
660
662
|
token_type_ids: List[int]
|
661
663
|
# Dummy sampling params for compatibility
|
662
664
|
sampling_params: SamplingParams
|
665
|
+
# For data parallel rank routing
|
666
|
+
data_parallel_rank: Optional[int] = None
|
663
667
|
# For dp balance
|
664
668
|
dp_balance_id: int = -1
|
665
669
|
|
@@ -798,6 +802,8 @@ class UpdateWeightFromDiskReqInput:
|
|
798
802
|
load_format: Optional[str] = None
|
799
803
|
# Whether to abort all requests before updating weights
|
800
804
|
abort_all_requests: bool = False
|
805
|
+
# Optional: Update weight version along with weights
|
806
|
+
weight_version: Optional[str] = None
|
801
807
|
|
802
808
|
|
803
809
|
@dataclass
|
@@ -819,6 +825,8 @@ class UpdateWeightsFromDistributedReqInput:
|
|
819
825
|
flush_cache: bool = True
|
820
826
|
# Whether to abort all requests before updating weights
|
821
827
|
abort_all_requests: bool = False
|
828
|
+
# Optional: Update weight version along with weights
|
829
|
+
weight_version: Optional[str] = None
|
822
830
|
|
823
831
|
|
824
832
|
@dataclass
|
@@ -842,6 +850,8 @@ class UpdateWeightsFromTensorReqInput:
|
|
842
850
|
flush_cache: bool = True
|
843
851
|
# Whether to abort all requests before updating weights
|
844
852
|
abort_all_requests: bool = False
|
853
|
+
# Optional: Update weight version along with weights
|
854
|
+
weight_version: Optional[str] = None
|
845
855
|
|
846
856
|
|
847
857
|
@dataclass
|
@@ -872,6 +882,14 @@ class InitWeightsUpdateGroupReqOutput:
|
|
872
882
|
message: str
|
873
883
|
|
874
884
|
|
885
|
+
@dataclass
|
886
|
+
class UpdateWeightVersionReqInput:
|
887
|
+
# The new weight version
|
888
|
+
new_version: str
|
889
|
+
# Whether to abort all running requests before updating
|
890
|
+
abort_all_requests: bool = True
|
891
|
+
|
892
|
+
|
875
893
|
@dataclass
|
876
894
|
class GetWeightsByNameReqInput:
|
877
895
|
name: str
|
@@ -987,6 +1005,11 @@ class ProfileReqOutput:
|
|
987
1005
|
message: str
|
988
1006
|
|
989
1007
|
|
1008
|
+
@dataclass
|
1009
|
+
class FreezeGCReq:
|
1010
|
+
pass
|
1011
|
+
|
1012
|
+
|
990
1013
|
@dataclass
|
991
1014
|
class ConfigureLoggingReq:
|
992
1015
|
log_requests: Optional[bool] = None
|
sglang/srt/managers/mm_utils.py
CHANGED
@@ -560,7 +560,7 @@ def embed_mm_inputs(
|
|
560
560
|
]
|
561
561
|
items_size[i + 1] = len(mm_items)
|
562
562
|
items_offsets.append(
|
563
|
-
flatten_nested_list([item.offsets for item in
|
563
|
+
flatten_nested_list([item.offsets for item in mm_items])
|
564
564
|
)
|
565
565
|
items_size = torch.cumsum(items_size, dim=0).tolist()
|
566
566
|
|
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
|
52
52
|
ScheduleBatchDisaggregationDecodeMixin,
|
53
53
|
)
|
54
54
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
|
55
|
+
from sglang.srt.layers.moe import is_tbo_enabled
|
55
56
|
from sglang.srt.mem_cache.allocator import (
|
56
57
|
BaseTokenToKVPoolAllocator,
|
57
58
|
SWATokenToKVPoolAllocator,
|
@@ -83,19 +84,13 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|
83
84
|
"chunked_prefill_size",
|
84
85
|
"device",
|
85
86
|
"disable_chunked_prefix_cache",
|
87
|
+
"disable_flashinfer_cutlass_moe_fp4_allgather",
|
86
88
|
"disable_radix_cache",
|
87
|
-
"enable_dp_attention",
|
88
|
-
"enable_two_batch_overlap",
|
89
|
-
"tbo_token_distribution_threshold",
|
90
89
|
"enable_dp_lm_head",
|
91
|
-
"
|
92
|
-
"deepep_mode",
|
93
|
-
"enable_flashinfer_cutlass_moe",
|
94
|
-
"enable_flashinfer_trtllm_moe",
|
90
|
+
"flashinfer_mxfp4_moe_precision",
|
95
91
|
"enable_flashinfer_allreduce_fusion",
|
96
92
|
"moe_dense_tp_size",
|
97
93
|
"ep_dispatch_algorithm",
|
98
|
-
"deepep_config",
|
99
94
|
"ep_num_redundant_experts",
|
100
95
|
"enable_nan_detection",
|
101
96
|
"flashinfer_mla_disable_ragged",
|
@@ -108,11 +103,11 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|
108
103
|
"triton_attention_reduce_in_fp32",
|
109
104
|
"num_reserved_decode_tokens",
|
110
105
|
"weight_loader_disable_mmap",
|
111
|
-
"enable_triton_kernel_moe",
|
112
|
-
"enable_flashinfer_mxfp4_moe",
|
113
106
|
"enable_multimodal",
|
114
107
|
"enable_symm_mem",
|
115
108
|
"quantization",
|
109
|
+
"enable_custom_logit_processor",
|
110
|
+
"disaggregation_mode",
|
116
111
|
]
|
117
112
|
|
118
113
|
# Put some global args for easy access
|
@@ -909,12 +904,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
909
904
|
spec_algorithm: SpeculativeAlgorithm = None
|
910
905
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
|
911
906
|
|
912
|
-
# Enable custom logit processor
|
913
|
-
enable_custom_logit_processor: bool = False
|
914
|
-
|
915
907
|
# Whether to return hidden states
|
916
908
|
return_hidden_states: bool = False
|
917
909
|
|
910
|
+
# Whether this batch is prefill-only (no token generation needed)
|
911
|
+
is_prefill_only: bool = False
|
912
|
+
|
918
913
|
# hicache pointer for synchronizing data loading from CPU to GPU
|
919
914
|
hicache_consumer_index: int = 0
|
920
915
|
|
@@ -928,7 +923,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
928
923
|
model_config: ModelConfig,
|
929
924
|
enable_overlap: bool,
|
930
925
|
spec_algorithm: SpeculativeAlgorithm,
|
931
|
-
enable_custom_logit_processor: bool,
|
932
926
|
chunked_req: Optional[Req] = None,
|
933
927
|
):
|
934
928
|
return_logprob = any(req.return_logprob for req in reqs)
|
@@ -955,8 +949,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
955
949
|
has_grammar=any(req.grammar for req in reqs),
|
956
950
|
device=req_to_token_pool.device,
|
957
951
|
spec_algorithm=spec_algorithm,
|
958
|
-
enable_custom_logit_processor=enable_custom_logit_processor,
|
959
952
|
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
953
|
+
is_prefill_only=all(
|
954
|
+
req.sampling_params.max_new_tokens == 0 for req in reqs
|
955
|
+
),
|
960
956
|
chunked_req=chunked_req,
|
961
957
|
)
|
962
958
|
|
@@ -1009,6 +1005,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1009
1005
|
extend_num_tokens: int,
|
1010
1006
|
backup_state: bool = False,
|
1011
1007
|
):
|
1008
|
+
# Over estimate the number of tokens: assume each request needs a new page.
|
1012
1009
|
num_tokens = (
|
1013
1010
|
extend_num_tokens
|
1014
1011
|
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
@@ -1041,8 +1038,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1041
1038
|
last_loc: torch.Tensor,
|
1042
1039
|
backup_state: bool = False,
|
1043
1040
|
):
|
1041
|
+
# Over estimate the number of tokens: assume each request needs a new page.
|
1044
1042
|
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
1045
|
-
|
1046
1043
|
self._evict_tree_cache_if_needed(num_tokens)
|
1047
1044
|
|
1048
1045
|
if backup_state:
|
@@ -1721,38 +1718,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1721
1718
|
extend_prefix_lens = self.prefix_lens
|
1722
1719
|
extend_logprob_start_lens = self.extend_logprob_start_lens
|
1723
1720
|
|
1724
|
-
if self.forward_mode.is_decode_or_idle():
|
1725
|
-
attention_backend_str = global_server_args_dict["decode_attention_backend"]
|
1726
|
-
else:
|
1727
|
-
attention_backend_str = global_server_args_dict["prefill_attention_backend"]
|
1728
|
-
# Create seq_lens_cpu when needed
|
1729
|
-
if (
|
1730
|
-
attention_backend_str
|
1731
|
-
in [
|
1732
|
-
"fa3",
|
1733
|
-
"flashinfer",
|
1734
|
-
"flashmla",
|
1735
|
-
"cutlass_mla",
|
1736
|
-
"ascend",
|
1737
|
-
"trtllm_mha",
|
1738
|
-
"aiter",
|
1739
|
-
]
|
1740
|
-
or global_server_args_dict["enable_two_batch_overlap"]
|
1741
|
-
):
|
1742
|
-
seq_lens_cpu = (
|
1743
|
-
seq_lens_cpu_cache
|
1744
|
-
if seq_lens_cpu_cache is not None
|
1745
|
-
else self.seq_lens.cpu()
|
1746
|
-
)
|
1747
|
-
else:
|
1748
|
-
seq_lens_cpu = None
|
1749
|
-
|
1750
1721
|
if self.sampling_info:
|
1751
1722
|
if self.has_grammar:
|
1752
1723
|
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
1753
1724
|
else:
|
1754
1725
|
self.sampling_info.grammars = None
|
1755
1726
|
|
1727
|
+
seq_lens_cpu = (
|
1728
|
+
seq_lens_cpu_cache
|
1729
|
+
if seq_lens_cpu_cache is not None
|
1730
|
+
else self.seq_lens.cpu()
|
1731
|
+
)
|
1732
|
+
|
1756
1733
|
global bid
|
1757
1734
|
bid += 1
|
1758
1735
|
return ModelWorkerBatch(
|
@@ -1815,18 +1792,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|
1815
1792
|
return_logprob=self.return_logprob,
|
1816
1793
|
decoding_reqs=self.decoding_reqs,
|
1817
1794
|
spec_algorithm=self.spec_algorithm,
|
1818
|
-
enable_custom_logit_processor=self.enable_custom_logit_processor,
|
1819
1795
|
global_num_tokens=self.global_num_tokens,
|
1820
1796
|
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
1821
1797
|
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
1822
1798
|
is_extend_in_batch=self.is_extend_in_batch,
|
1799
|
+
is_prefill_only=self.is_prefill_only,
|
1823
1800
|
)
|
1824
1801
|
|
1825
|
-
def _evict_tree_cache_if_needed(
|
1826
|
-
self,
|
1827
|
-
num_tokens: int,
|
1828
|
-
) -> None:
|
1829
|
-
if isinstance(self.tree_cache, SWAChunkCache):
|
1802
|
+
def _evict_tree_cache_if_needed(self, num_tokens: int):
|
1803
|
+
if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
|
1830
1804
|
return
|
1831
1805
|
|
1832
1806
|
if self.is_hybrid:
|