sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -169,12 +169,13 @@ class StorageOperation:
         host_indices: torch.Tensor,
         token_ids: List[int],
         last_hash: Optional[str] = None,
+        hash_value: Optional[List[str]] = None,
     ):
         self.host_indices = host_indices
         self.token_ids = token_ids
         self.last_hash = last_hash
         self.completed_tokens = 0
-        self.hash_value = []
+        self.hash_value = hash_value if hash_value is not None else []
 
         self.id = StorageOperation.counter
         StorageOperation.counter += 1
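The constructor change lets callers hand over precomputed page hashes instead of always starting from an empty list. A minimal sketch of the two call patterns (class and field names are from the diff; the tensor values and hash strings are hypothetical):

import torch

# Backup operation whose page hashes were already computed upstream;
# the controller can use them directly instead of rehashing.
op = StorageOperation(
    host_indices=torch.arange(64),
    token_ids=list(range(64)),
    hash_value=["page0_hash", "page1_hash"],  # hypothetical hash strings
)

# Legacy call sites omit hash_value and still get an empty list, not None.
legacy_op = StorageOperation(host_indices=torch.arange(64), token_ids=list(range(64)))
assert legacy_op.hash_value == []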
@@ -259,6 +260,7 @@ class HiCacheController:
             self.storage_backend = MooncakeStore()
             self.get_hash_str = get_hash_str_mooncake
             self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+            assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
             from sglang.srt.distributed import get_tensor_model_parallel_rank
             from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
@@ -294,6 +296,9 @@ class HiCacheController:
             self.prefetch_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
            )
+            self.prefetch_io_tp_group = torch.distributed.new_group(
+                group_ranks, backend="gloo"
+            )
             self.backup_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
@@ -433,7 +438,9 @@ class HiCacheController:
         if self.io_backend == "kernel":
             return host_indices.to(self.mem_pool_device.device), device_indices
         elif self.io_backend == "direct":
-
+            device_indices = device_indices.cpu()
+            host_indices, idx = host_indices.sort()
+            return host_indices, device_indices.index_select(0, idx)
         else:
             raise ValueError(f"Unsupported io backend")
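For the "direct" IO backend, the new code sorts the host indices and applies the same permutation to the device indices, so each (host, device) pair stays matched while host-side accesses become sequential. A small sketch of why the pairing is preserved (tensor values are arbitrary examples):

import torch

host_indices = torch.tensor([7, 2, 5])
device_indices = torch.tensor([100, 101, 102])  # pairwise-aligned with host_indices

# torch.Tensor.sort returns (sorted_values, permutation); applying the same
# permutation to device_indices keeps pairs (7, 100), (2, 101), (5, 102) intact.
sorted_host, idx = host_indices.sort()
aligned_device = device_indices.index_select(0, idx)

assert list(zip(sorted_host.tolist(), aligned_device.tolist())) == [
    (2, 101), (5, 102), (7, 100),
]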
@@ -570,10 +577,6 @@ class HiCacheController:
                     )
                     completed_tokens += self.page_size
                 else:
-                    # operation terminated by controller, release pre-allocated memory
-                    self.mem_pool_host.free(
-                        operation.host_indices[operation.completed_tokens :]
-                    )
                     break
 
     def mooncake_page_transfer(self, operation):
@@ -599,6 +602,14 @@ class HiCacheController:
                     self.generic_page_transfer(operation, batch_size=128)
                 else:
                     self.generic_page_transfer(operation)
+
+                if self.tp_world_size > 1:
+                    # to ensure all TP workers release the host memory at the same time
+                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
+                # operation terminated by controller, release pre-allocated memory
+                self.mem_pool_host.free(
+                    operation.host_indices[operation.completed_tokens :]
+                )
             except Empty:
                 continue
 
@@ -626,7 +637,9 @@ class HiCacheController:
                 continue
 
             storage_hit_count = 0
-            if
+            if (
+                operation.host_indices is not None
+            ) and self.prefetch_rate_limit_check():
                 last_hash = operation.last_hash
                 tokens_to_fetch = operation.token_ids
 
@@ -670,7 +683,8 @@ class HiCacheController:
             if storage_hit_count < self.prefetch_threshold:
                 # not to prefetch if not enough benefits
                 self.prefetch_revoke_queue.put(operation.request_id)
-
+                if operation.host_indices is not None:
+                    self.mem_pool_host.free(operation.host_indices)
                 logger.debug(
                     f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                 )
@@ -693,12 +707,12 @@ class HiCacheController:
         self,
         host_indices: torch.Tensor,
         token_ids: List[int],
-
+        hash_value: Optional[List[str]] = None,
     ) -> int:
         """
         Write KV caches from host memory to storage backend.
         """
-        operation = StorageOperation(host_indices, token_ids,
+        operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
         self.backup_queue.put(operation)
         return operation.id
 
@@ -753,24 +767,6 @@ class HiCacheController:
             if operation is None:
                 continue
 
-            last_hash = operation.last_hash
-            tokens_to_backup = operation.token_ids
-
-            backup_hit_count = 0
-            remaining_tokens = len(tokens_to_backup)
-            hash_value = []
-            while remaining_tokens >= self.page_size:
-                last_hash = self.get_hash_str(
-                    tokens_to_backup[
-                        backup_hit_count : backup_hit_count + self.page_size
-                    ],
-                    last_hash,
-                )
-                backup_hit_count += self.page_size
-                hash_value.append(last_hash)
-                remaining_tokens -= self.page_size
-            operation.hash_value = hash_value
-
             if self.is_mooncake_backend():
                 self.mooncake_page_backup(operation)
             elif self.storage_backend_type == "hf3fs":
@@ -793,7 +789,6 @@
                 self.ack_backup_queue.put(
                     (
                         operation.id,
-                        operation.hash_value[: min_completed_tokens // self.page_size],
                         min_completed_tokens,
                     )
                 )
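The removed loop chained page hashes: each page's hash was derived from its tokens plus the previous page's hash, so a single hash identifies the whole prefix up to that page. With hash_value now passed in, the backup worker no longer recomputes this. A standalone sketch of the chaining (hashlib is a stand-in for the controller's configured get_hash_str; names here are hypothetical):

import hashlib
from typing import List, Optional

def get_hash_str(tokens: List[int], prior_hash: Optional[str]) -> str:
    # Stand-in hash function: the prior page's hash is folded in, so equal
    # pages appearing in different prefixes hash differently.
    h = hashlib.sha256()
    if prior_hash is not None:
        h.update(prior_hash.encode())
    h.update(",".join(map(str, tokens)).encode())
    return h.hexdigest()

def hash_pages(token_ids: List[int], page_size: int) -> List[str]:
    hashes, last_hash = [], None
    for start in range(0, len(token_ids) - len(token_ids) % page_size, page_size):
        last_hash = get_hash_str(token_ids[start : start + page_size], last_hash)
        hashes.append(last_hash)
    return hashes

# Two full pages get hashed; the trailing partial page is skipped, mirroring
# the `while remaining_tokens >= self.page_size` guard in the removed loop.
assert len(hash_pages(list(range(10)), page_size=4)) == 2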
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -216,7 +216,7 @@ class DetokenizerManager:
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             output_strs=output_strs,
-            output_ids=recv_obj.
+            output_ids=recv_obj.output_ids,
             prompt_tokens=recv_obj.prompt_tokens,
             completion_tokens=recv_obj.completion_tokens,
             cached_tokens=recv_obj.cached_tokens,
sglang/srt/managers/io_struct.py
CHANGED
@@ -99,25 +99,24 @@ class GenerateReqInput:
     stream: bool = False
     # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
     log_metrics: bool = True
+    # Whether to return hidden states
+    return_hidden_states: Union[List[bool], bool] = False
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None
+
     # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     # The uid of LoRA adaptors, should be initialized by tokenizer manager
     lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Session info for continual prompting
-    session_params: Optional[Union[List[Dict], Dict]] = None
-
     # Custom logit processor for advanced sampling control. Must be a serialized instance
     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
-    # Whether to return hidden states
-    return_hidden_states: Union[List[bool], bool] = False
-
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
@@ -456,6 +455,7 @@ class GenerateReqInput:
            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+           lora_id=self.lora_id[i] if self.lora_id is not None else None,
            custom_logit_processor=(
                self.custom_logit_processor[i]
                if self.custom_logit_processor is not None
@@ -798,6 +798,8 @@ class UpdateWeightFromDiskReqInput:
     load_format: Optional[str] = None
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
 
 
 @dataclass
@@ -819,6 +821,8 @@ class UpdateWeightsFromDistributedReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
 
 
 @dataclass
@@ -842,6 +846,8 @@ class UpdateWeightsFromTensorReqInput:
     flush_cache: bool = True
     # Whether to abort all requests before updating weights
     abort_all_requests: bool = False
+    # Optional: Update weight version along with weights
+    weight_version: Optional[str] = None
 
 
 @dataclass
@@ -872,6 +878,14 @@ class InitWeightsUpdateGroupReqOutput:
     message: str
 
 
+@dataclass
+class UpdateWeightVersionReqInput:
+    # The new weight version
+    new_version: str
+    # Whether to abort all running requests before updating
+    abort_all_requests: bool = True
+
+
 @dataclass
 class GetWeightsByNameReqInput:
     name: str
sglang/srt/managers/mm_utils.py
CHANGED
@@ -614,8 +614,7 @@ def general_mm_embed_routine(
         input_ids: Input token IDs tensor
         forward_batch: Batch information for model forward pass
         language_model: Base language model to use
-
-        audio_data_embedding_func: Function to embed audio data
+        data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
         placeholder_tokens: Token IDs for multimodal placeholders
         **kwargs: Additional arguments passed to language model
 
sglang/srt/managers/multimodal_processor.py
CHANGED
@@ -20,7 +20,7 @@ def import_processors():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}:
+            logger.warning(f"Ignore import error when loading {name}: {e}")
             continue
         all_members = inspect.getmembers(module, inspect.isclass)
         classes = [
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -37,6 +37,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
+from itertools import chain
 from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -57,6 +58,7 @@ from sglang.srt.mem_cache.allocator import (
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import TimeStats
@@ -82,7 +84,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_dp_attention",
     "enable_two_batch_overlap",
     "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
@@ -111,6 +112,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
+    "enable_custom_logit_processor",
 ]
 
 # Put some global args for easy access
@@ -638,14 +640,26 @@ class Req:
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
-            (
-                self.prefix_indices,
-                self.last_node,
-                self.last_host_node,
-                self.host_hit_length,
-            ) = tree_cache.match_prefix(
-                key=self.adjust_max_prefix_ids(),
-            )
+            if isinstance(tree_cache, LoRARadixCache):
+                (
+                    self.prefix_indices,
+                    self.last_node,
+                    self.last_host_node,
+                    self.host_hit_length,
+                ) = tree_cache.match_prefix_with_lora_id(
+                    key=LoRAKey(
+                        lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
+                    ),
+                )
+            else:
+                (
+                    self.prefix_indices,
+                    self.last_node,
+                    self.last_host_node,
+                    self.host_hit_length,
+                ) = tree_cache.match_prefix(
+                    key=self.adjust_max_prefix_ids(),
+                )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
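Keying radix-tree lookups on (lora_id, token_ids) keeps prefixes cached under one adapter from being served to requests using another, since identical token prefixes produce different KV states under different LoRA weights. A toy sketch of the keying idea (a plain dict stands in for the radix tree; only the LoRAKey field names come from the diff):

from dataclasses import dataclass
from typing import Dict, Optional, Tuple

@dataclass(frozen=True)
class ToyLoRAKey:
    # Mirrors the LoRAKey fields used in the diff: adapter id plus token prefix.
    lora_id: Optional[str]
    token_ids: Tuple[int, ...]

# A plain dict stands in for the radix tree: the same token prefix under
# different adapters resolves to distinct cached KV entries.
cache: Dict[ToyLoRAKey, str] = {
    ToyLoRAKey("adapter_a", (1, 2, 3)): "kv computed with adapter_a",
    ToyLoRAKey("adapter_b", (1, 2, 3)): "kv computed with adapter_b",
}
assert cache[ToyLoRAKey("adapter_a", (1, 2, 3))] != cache[ToyLoRAKey("adapter_b", (1, 2, 3))]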
@@ -895,12 +909,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
 
-    # Enable custom logit processor
-    enable_custom_logit_processor: bool = False
-
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # hicache pointer for synchronizing data loading from CPU to GPU
     hicache_consumer_index: int = 0
 
@@ -914,7 +928,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
-        enable_custom_logit_processor: bool,
         chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -941,8 +954,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
-            enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            is_prefill_only=all(
+                req.sampling_params.max_new_tokens == 0 for req in reqs
+            ),
             chunked_req=chunked_req,
         )
 
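A batch is flagged prefill-only when every request asks for zero new tokens, which lets the scheduler skip merging it into the decode loop later in this diff. A minimal sketch of the predicate (the request classes here are hypothetical stand-ins for the real Req and SamplingParams):

from dataclasses import dataclass
from typing import List

@dataclass
class FakeSamplingParams:
    max_new_tokens: int

@dataclass
class FakeReq:
    sampling_params: FakeSamplingParams

def is_prefill_only(reqs: List[FakeReq]) -> bool:
    # Mirrors the all(...) expression in the diff: one request that wants even
    # a single generated token makes the whole batch a decode batch.
    return all(req.sampling_params.max_new_tokens == 0 for req in reqs)

assert is_prefill_only([FakeReq(FakeSamplingParams(0)), FakeReq(FakeSamplingParams(0))])
assert not is_prefill_only([FakeReq(FakeSamplingParams(0)), FakeReq(FakeSamplingParams(8))])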
@@ -995,6 +1010,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_num_tokens: int,
         backup_state: bool = False,
     ):
+        # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = (
             extend_num_tokens
             + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
@@ -1027,8 +1043,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
+        # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-
         self._evict_tree_cache_if_needed(num_tokens)
 
         if backup_state:
@@ -1145,9 +1161,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        input_ids_tensor = torch.tensor(
-
-        )
+        input_ids_tensor = torch.tensor(
+            list(chain.from_iterable(input_ids)), dtype=torch.int64
+        ).to(self.device, non_blocking=True)
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
@@ -1707,37 +1723,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_prefix_lens = self.prefix_lens
         extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.forward_mode.is_decode_or_idle():
-            attention_backend_str = global_server_args_dict["decode_attention_backend"]
-        else:
-            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
-        # Create seq_lens_cpu when needed
-        if (
-            attention_backend_str == "fa3"
-            or (
-                global_server_args_dict["use_mla_backend"]
-                and attention_backend_str == "flashinfer"
-            )
-            or attention_backend_str == "flashmla"
-            or attention_backend_str == "cutlass_mla"
-            or attention_backend_str == "ascend"
-            or attention_backend_str == "trtllm_mha"
-            or global_server_args_dict["enable_two_batch_overlap"]
-        ):
-            seq_lens_cpu = (
-                seq_lens_cpu_cache
-                if seq_lens_cpu_cache is not None
-                else self.seq_lens.cpu()
-            )
-        else:
-            seq_lens_cpu = None
-
         if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
             else:
                 self.sampling_info.grammars = None
 
+        seq_lens_cpu = (
+            seq_lens_cpu_cache
+            if seq_lens_cpu_cache is not None
+            else self.seq_lens.cpu()
+        )
+
         global bid
         bid += 1
         return ModelWorkerBatch(
@@ -1800,18 +1797,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
-            enable_custom_logit_processor=self.enable_custom_logit_processor,
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             is_extend_in_batch=self.is_extend_in_batch,
+            is_prefill_only=self.is_prefill_only,
         )
 
-    def _evict_tree_cache_if_needed(
-        self,
-        num_tokens: int,
-    ) -> None:
-        if isinstance(self.tree_cache, SWAChunkCache):
+    def _evict_tree_cache_if_needed(self, num_tokens: int):
+        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
             return
 
         if self.is_hybrid:
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-
+CLIP_MAX_NEW_TOKENS = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )
 
@@ -305,7 +305,7 @@ class PrefillAdder:
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
-
+                        CLIP_MAX_NEW_TOKENS,
                     )
                     * self.new_token_ratio
                     for r in running_batch.reqs
@@ -388,7 +388,7 @@ class PrefillAdder:
                 0,
                 req.extend_input_len,
                 (
-                    min(req.sampling_params.max_new_tokens,
+                    min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
                     if not truncated
                     else 0
                 ),
@@ -477,7 +477,7 @@ class PrefillAdder:
             self._update_prefill_budget(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
             )
         else:
             if self.rem_chunk_tokens == 0:
@@ -499,7 +499,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req, has_chunked_req)
 
         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens,
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
         )
 
         # adjusting the input_tokens based on host_hit_length and page_size
@@ -544,7 +544,7 @@ class PrefillAdder:
                     input_tokens,
                     min(
                         req.sampling_params.max_new_tokens,
-
+                        CLIP_MAX_NEW_TOKENS,
                     ),
                 )
             else:
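Throughout PrefillAdder, the per-request budget is estimated as extend_input_len plus max_new_tokens clipped to an environment-tunable ceiling, so a request declaring an enormous max_new_tokens does not block admission of others. A minimal sketch of the clipping (constant name and env var from the diff; the helper function is a simplified stand-in):

import os

# Same default and env var as the diff; the scheduler reads this once at import.
CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096"))

def estimated_total_tokens(extend_input_len: int, max_new_tokens: int) -> int:
    # Only the admission estimate is clipped; the real stop condition still
    # honors the request's full max_new_tokens.
    return extend_input_len + min(max_new_tokens, CLIP_MAX_NEW_TOKENS)

# With the env var unset, a 128k-token declaration is budgeted as 4096.
assert estimated_total_tokens(1000, 128_000) == 1000 + 4096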
sglang/srt/managers/scheduler.py
CHANGED
@@ -130,6 +130,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
+from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
@@ -611,12 +612,7 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
-                hicache_io_backend=(
-                    "direct"
-                    if server_args.attention_backend
-                    == "fa3"  # hot fix for incompatibility
-                    else server_args.hicache_io_backend
-                ),
+                hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
@@ -635,7 +631,19 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
-
+        elif self.enable_lora:
+            assert (
+                not self.enable_hierarchical_cache
+            ), "LoRA radix cache doesn't support hierarchical cache"
+            assert (
+                self.schedule_policy == "fcfs"
+            ), "LoRA radix cache only supports FCFS policy"
+            self.tree_cache = LoRARadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+            )
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
@@ -1458,8 +1466,9 @@ class Scheduler(
             if self.last_batch.batch_size() < last_bs:
                 self.running_batch.batch_is_full = False
 
-            # Merge the new batch into the running batch
-
+            # Merge the new batch into the running batch.
+            # For prefill-only batch, we can avoid going through decoding step.
+            if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
                 if self.running_batch.is_empty():
                     self.running_batch = self.last_batch
                 else:
@@ -1626,7 +1635,6 @@ class Scheduler(
                 self.model_config,
                 self.enable_overlap,
                 self.spec_algorithm,
-                self.server_args.enable_custom_logit_processor,
                 chunked_req=self.chunked_req,
             )
             if self.enable_hierarchical_cache:
@@ -2023,7 +2031,6 @@ class Scheduler(
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
sglang/srt/managers/scheduler_profiler_mixin.py
CHANGED
@@ -8,6 +8,18 @@ import torch
 
 from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu
+
+    patches = [
+        ["profiler.profile", torch_npu.profiler.profile],
+        ["profiler.ProfilerActivity.CUDA", torch_npu.profiler.ProfilerActivity.NPU],
+        ["profiler.ProfilerActivity.CPU", torch_npu.profiler.ProfilerActivity.CPU],
+    ]
+    torch_npu._apply_patches(patches)
 
 logger = logging.getLogger(__name__)
 
@@ -136,6 +148,13 @@ class SchedulerProfilerMixin:
             activities=torchprof_activities,
             with_stack=with_stack if with_stack is not None else True,
             record_shapes=record_shapes if record_shapes is not None else False,
+            on_trace_ready=(
+                None
+                if not _is_npu
+                else torch_npu.profiler.tensorboard_trace_handler(
+                    self.torch_profiler_output_dir
+                )
+            ),
         )
         self.torch_profiler.start()
         self.profile_in_progress = True
@@ -166,15 +185,16 @@ class SchedulerProfilerMixin:
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
-            self.torch_profiler.export_chrome_trace(
-                os.path.join(
-                    self.torch_profiler_output_dir,
-                    self.profile_id
-                    + f"-TP-{self.tp_rank}"
-                    + stage_suffix
-                    + ".trace.json.gz",
+            if not _is_npu:
+                self.torch_profiler.export_chrome_trace(
+                    os.path.join(
+                        self.torch_profiler_output_dir,
+                        self.profile_id
+                        + f"-TP-{self.tp_rank}"
+                        + stage_suffix
+                        + ".trace.json.gz",
+                    )
                 )
-            )
 
         torch.distributed.barrier(self.tp_cpu_group)
 
         if self.rpd_profiler is not None:
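With these patches, the same torch.profiler call sites serve both backends: on NPU, torch_npu's tensorboard_trace_handler writes traces as they become ready, while on CUDA the mixin exports a gzipped Chrome trace after stop(). A simplified sketch of the resulting stop path (function name and parameters are hypothetical; the _is_npu split mirrors the diff):

import os

def export_profile(profiler, is_npu: bool, output_dir: str, profile_id: str, tp_rank: int) -> None:
    # NPU traces were already written by the on_trace_ready handler, so only
    # the CUDA path exports a Chrome trace here.
    profiler.stop()
    if not is_npu:
        profiler.export_chrome_trace(
            os.path.join(output_dir, f"{profile_id}-TP-{tp_rank}.trace.json.gz")
        )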