sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe
+from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
     ExpertDistributionReqOutput,
     FlushCacheReqInput,
     FlushCacheReqOutput,
+    FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -145,6 +146,7 @@ from sglang.srt.utils import (
     configure_gc_logger,
     configure_logger,
     disable_request_logging,
+    freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
     get_zmq_socket,
@@ -245,6 +247,9 @@ class Scheduler(
             )
         )

+        # Init model config
+        self.model_config = ModelConfig.from_server_args(server_args)
+
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
@@ -292,6 +297,9 @@ class Scheduler(
         # Init tokenizer
         self.init_tokenizer()

+        # Init moe config
+        self.init_moe_config()
+
         # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
         if self.server_args.reasoning_parser and self.tokenizer:
             reasoning_parser = ReasoningParser(
@@ -518,6 +526,7 @@ class Scheduler(
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (SlowDownReqInput, self.slow_down),
                 (ProfileReq, self.profile),
+                (FreezeGCReq, self.handle_freeze_gc),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
@@ -538,8 +547,6 @@ class Scheduler(

     def init_tokenizer(self):
         server_args = self.server_args
-
-        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation

         if server_args.skip_tokenizer_init:
@@ -761,6 +768,10 @@ class Scheduler(
         # The prefill requests that are in the middle of kv sending
         self.disagg_prefill_inflight_queue: List[Req] = []

+    def init_moe_config(self):
+        if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
+            initialize_moe_config(self.server_args)
+
     @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
@@ -1133,7 +1144,7 @@
                     f"boostrap room id. {req.rid=}"
                 )
                 logger.error(error_msg)
-                prepare_abort(req, error_msg)
+                prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
                 self.stream_output([req], req.return_logprob)
                 return

@@ -1466,8 +1477,9 @@
            if self.last_batch.batch_size() < last_bs:
                self.running_batch.batch_is_full = False

-        # Merge the new batch into the running batch
-        if not self.last_batch.is_empty():
+        # Merge the new batch into the running batch.
+        # For prefill-only batch, we can avoid going through decoding step.
+        if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
             if self.running_batch.is_empty():
                 self.running_batch = self.last_batch
             else:
@@ -1634,7 +1646,6 @@
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
             chunked_req=self.chunked_req,
         )
         if self.enable_hierarchical_cache:
@@ -1823,11 +1834,6 @@
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
-            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=MoeA2ABackend(
-                self.server_args.moe_a2a_backend
-            ).is_deepep(),
-            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )
@@ -1922,9 +1928,6 @@
         disable_cuda_graph: bool,
         spec_algorithm,
         speculative_num_draft_tokens,
-        enable_two_batch_overlap: bool,
-        enable_deepep_moe: bool,
-        deepep_mode: DeepEPMode,
         require_mlp_tp_gather: bool,
         disable_overlap_schedule: bool,
     ):
@@ -1972,9 +1975,6 @@
                 is_extend_in_batch,
                 *tbo_preparer.prepare_all_gather(
                     local_batch,
-                    deepep_mode,
-                    enable_deepep_moe,
-                    enable_two_batch_overlap,
                 ),
             ],
             dtype=torch.int64,
@@ -2031,7 +2031,6 @@
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
-            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -2473,6 +2472,12 @@
         if self.idle_sleeper is not None:
             self.idle_sleeper.maybe_sleep()

+    def handle_freeze_gc(self, recv_req: FreezeGCReq):
+        """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
+        freeze_gc("Scheduler")
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return None
+

 class IdleSleeper:
     """
@@ -2583,7 +2588,10 @@ def run_scheduler_process(
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap_disagg_prefill()
        else:
-            scheduler.event_loop_normal_disagg_prefill()
+            if server_args.pp_size > 1:
+                scheduler.event_loop_pp_disagg_prefill()
+            else:
+                scheduler.event_loop_normal_disagg_prefill()

    elif disaggregation_mode == DisaggregationMode.DECODE:
        if scheduler.enable_overlap:
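The new FreezeGCReq plumbing lets the server freeze Python's garbage collector in each worker process once startup is complete, so long-lived objects stop being rescanned on every collection; the scheduler freezes its own GC and forwards the request to the detokenizer. The freeze_gc helper imported above lives in sglang/srt/utils.py and is not shown in this diff; a minimal sketch of what such a helper could look like, using only the standard-library gc module:

    import gc
    import logging

    logger = logging.getLogger(__name__)


    def freeze_gc(context: str) -> None:
        # Collect once, then move every surviving tracked object into the
        # permanent generation so future collections skip them entirely.
        collected = gc.collect()
        gc.freeze()
        logger.info(
            "Froze GC in %s: collected %d objects, %d objects now frozen",
            context,
            collected,
            gc.get_freeze_count(),
        )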
sglang/srt/managers/session_controller.py
CHANGED
@@ -54,7 +54,7 @@ class SessionReqNode:
         prefix += " -- " + self.childs[0].req.rid
         ret = self.childs[0]._str_helper(prefix)
         for child in self.childs[1:]:
-            prefix = " " * len(origin_prefix) + "
+            prefix = " " * len(origin_prefix) + " \\- " + child.req.rid
             ret += child._str_helper(prefix)
         return ret

sglang/srt/managers/template_manager.py
CHANGED
@@ -89,6 +89,7 @@ class TemplateManager:
         if template is None:
             return False

+        # TODO: remove this hard code the reasoning pattern
         force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"
         has_reasoning = re.search(force_reasoning_pattern, template) is not None

@@ -128,11 +129,12 @@ class TemplateManager:
             logger.info(
                 f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
             )
-
-
-
-
-
+        else:
+            # Default to string content format if no template was found
+            self._jinja_template_content_format = "string"
+            logger.info(
+                "No chat template found, defaulting to 'string' content format"
+            )

         # Detect reasoning pattern from chat template
         if tokenizer_manager.tokenizer:
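For reference, the "string" content format that the manager now falls back to differs from the "openai" format in how each chat message's content field is shaped; a hedged illustration (the message values are made up, and the actual detection logic lives in the template manager, not in this diff):

    # "string" content format: each message's content is a plain string.
    string_style = {"role": "user", "content": "What is the capital of France?"}

    # "openai" content format: content is a list of typed parts.
    openai_style = {
        "role": "user",
        "content": [{"type": "text", "text": "What is the capital of France?"}],
    }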
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
     ExpertDistributionReqOutput,
     FlushCacheReqInput,
     FlushCacheReqOutput,
+    FreezeGCReq,
     GenerateReqInput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -122,7 +123,9 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    configure_gc_warning,
     dataclass_to_string_truncated,
+    freeze_gc,
     get_bool_env_var,
     get_zmq_socket,
     kill_process_tree,
@@ -298,7 +301,7 @@ class TokenizerManager:
         # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
         # serves as the source of truth for available adapters and maps user-friendly LoRA names
         # to internally used unique LoRA IDs.
-        self.lora_registry = LoRARegistry(self.server_args.lora_paths
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths)
         # Lock to serialize LoRA update operations.
         # Please note that, unlike `model_update_lock`, this does not block inference, allowing
         # LoRA updates and inference to overlap.
@@ -352,6 +355,10 @@ class TokenizerManager:
                 collect_tokens_histogram=self.server_args.collect_tokens_histogram,
             )

+        # Configure GC warning
+        if self.server_args.gc_warning_threshold_secs > 0.0:
+            configure_gc_warning(self.server_args.gc_warning_threshold_secs)
+
         # Communicators
         self.init_weights_update_group_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
@@ -446,6 +453,10 @@ class TokenizerManager:
                     ProfileReqOutput,
                     self.profile_communicator.handle_recv,
                 ),
+                (
+                    FreezeGCReq,
+                    lambda x: None,
+                ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
                 (
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
@@ -565,14 +576,24 @@
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
+        # FIXME: unify the length validation logic with the one in the scheduler.
+        _max_req_len = self.context_len

         input_token_num = len(input_ids) if input_ids is not None else 0
-        # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
-
-
-
-
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens). "
+                    "Truncating the input."
+                )
+                del input_ids[_max_req_len:]
+                input_token_num = len(input_ids)
+            else:
+                raise ValueError(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
+                )

         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -584,17 +605,27 @@
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
             max_new_tokens is not None
-            and (max_new_tokens + input_token_num) >=
+            and (max_new_tokens + input_token_num) >= _max_req_len
         ):
-
-
-
-
-
-
-
-
-
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
+                    f"exceeds the model's context length ({self.context_len} tokens). "
+                    "Truncating max_new_tokens."
+                )
+                obj.sampling_params["max_new_tokens"] = max(
+                    0, _max_req_len - input_token_num
+                )
+            else:
+                total_tokens = max_new_tokens + input_token_num
+                error_msg = (
+                    f"Requested token count exceeds the model's maximum context length "
+                    f"of {self.context_len} tokens. You requested a total of {total_tokens} "
+                    f"tokens: {input_token_num} tokens from the input messages and "
+                    f"{max_new_tokens} tokens for the completion. Please reduce the number "
+                    f"of tokens in the input messages or the completion to fit within the limit."
+                )
+                raise ValueError(error_msg)

         if isinstance(obj, GenerateReqInput):
             if (
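The two hunks above rework request-length validation so that --allow-auto-truncate has two effects: an over-long prompt is cut to the context window, and max_new_tokens is capped so prompt plus completion still fits; otherwise a ValueError is raised as before. A standalone sketch of that policy, detached from the TokenizerManager internals (the function and argument names here are illustrative, not sglang APIs):

    from typing import List, Optional, Tuple


    def clamp_to_context(
        input_ids: List[int],
        max_new_tokens: Optional[int],
        context_len: int,
        allow_auto_truncate: bool,
    ) -> Tuple[List[int], Optional[int]]:
        # The prompt alone is longer than the context window: truncate or reject.
        if len(input_ids) >= context_len:
            if not allow_auto_truncate:
                raise ValueError(
                    f"The input ({len(input_ids)} tokens) is longer than the "
                    f"model's context length ({context_len} tokens)."
                )
            input_ids = input_ids[:context_len]
        # Prompt plus completion budget would overflow: cap max_new_tokens or reject.
        if max_new_tokens is not None and len(input_ids) + max_new_tokens >= context_len:
            if not allow_auto_truncate:
                raise ValueError(
                    "Requested input plus max_new_tokens exceeds the context length."
                )
            max_new_tokens = max(0, context_len - len(input_ids))
        return input_ids, max_new_tokens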
@@ -699,7 +730,7 @@
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
-            self.
+            self._validate_one_request(obj[i], input_ids_list[i])
             tokenized_objs.append(
                 self._create_tokenized_object(
                     req, req.text, input_ids_list[i], None, None
@@ -782,15 +813,17 @@
        ):
            raise ValueError(finish_reason["message"])

-        if (
-
-
-
+        if finish_reason.get("type") == "abort" and finish_reason.get(
+            "status_code"
+        ) in (
+            HTTPStatus.SERVICE_UNAVAILABLE,
+            HTTPStatus.INTERNAL_SERVER_ERROR,
        ):
            # This is an abort request initiated by scheduler.
            # Delete the key to prevent resending abort request to the scheduler and
            # to ensure aborted request state is cleaned up.
-
+            if state.obj.rid in self.rid_to_state:
+                del self.rid_to_state[state.obj.rid]

            # Mark ongoing LoRA request as finished.
            if self.server_args.enable_lora and state.obj.lora_path:
@@ -1337,6 +1370,12 @@
             logging.info(f"Config logging: {obj=}")
         self.log_request_metadata = self.get_log_request_metadata()

+    async def freeze_gc(self):
+        """Send a freeze_gc message to the scheduler first, then freeze locally."""
+        self.send_to_scheduler.send_pyobj(FreezeGCReq())
+        freeze_gc("Tokenizer Manager")
+        return None
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
@@ -1529,6 +1568,7 @@
                 "id": rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
+                "weight_version": self.server_args.weight_version,
             }

             if getattr(state.obj, "return_logprob", False):
@@ -1892,6 +1932,13 @@
                     f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                 )

+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
         # Handle string or tokenized query/items
         if isinstance(query, str) and (
             isinstance(items, str)
@@ -1903,13 +1950,9 @@
                 prompts = [f"{item}{query}" for item in items_list]
             else:
                 prompts = [f"{query}{item}" for item in items_list]
-
-
-
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.text = prompts
+
         elif (
             isinstance(query, list)
             and isinstance(items, list)
@@ -1921,13 +1964,8 @@
                 input_ids_list = [item + query for item in items]
             else:
                 input_ids_list = [query + item for item in items]
-
-
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
@@ -1939,9 +1977,20 @@
         for result in results:
             # Get logprobs for each token
             logprobs = {}
-
-
-
+
+            # For scoring requests, we read from output_token_ids_logprobs since we want
+            # the logprobs for specific tokens mentioned in the label_token_ids at
+            # the next position after the last token in the prompt
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            # Throw an error here if output_logprobs is None
+            if output_logprobs is None:
+                raise RuntimeError(
+                    f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
+                    "This usually indicates a problem with the scoring request or the backend output."
+                )
+
+            for logprob, token_id, _ in output_logprobs[0]:
                 if token_id in label_token_ids:
                     logprobs[token_id] = logprob

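The scoring path now builds a single GenerateReqInput up front with max_new_tokens set to 0, fills in text or input_ids per branch, and reads the label-token logprobs from output_token_ids_logprobs, where each entry is a (logprob, token_id, ...) tuple. A small, self-contained example of turning per-label logprobs gathered this way into normalized label probabilities, a typical consumer-side step rather than code from this package:

    import math
    from typing import Dict, List, Tuple


    def label_probabilities(
        first_position_logprobs: List[Tuple[float, int, str]],
        label_token_ids: List[int],
    ) -> Dict[int, float]:
        # Keep only the candidate label tokens' logprobs at the first generated
        # position, then renormalize with a softmax so they sum to 1 over labels.
        logprobs = {
            token_id: logprob
            for logprob, token_id, _ in first_position_logprobs
            if token_id in label_token_ids
        }
        denom = sum(math.exp(lp) for lp in logprobs.values())
        return {tid: math.exp(lp) / denom for tid, lp in logprobs.items()}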
sglang/srt/managers/tp_worker.py
CHANGED
sglang/srt/managers/utils.py
CHANGED
@@ -1,9 +1,16 @@
+from __future__ import annotations
+
 import logging
 import multiprocessing as mp
 from http import HTTPStatus
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional

+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import GenerationBatchResult

 logger = logging.getLogger(__name__)

@@ -41,6 +48,57 @@ def validate_input_length(
     return None


+def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
+
+    logits_output = result.logits_output
+    assert logits_output is not None
+
+    return {
+        "extend_input_len_per_req": result.extend_input_len_per_req,
+        "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+        "next_token_logprobs": result.logits_output.next_token_logprobs,
+        "next_token_top_logprobs_val": result.logits_output.next_token_top_logprobs_val,
+        "next_token_top_logprobs_idx": result.logits_output.next_token_top_logprobs_idx,
+        "next_token_token_ids_logprobs_val": result.logits_output.next_token_token_ids_logprobs_val,
+        "next_token_token_ids_logprobs_idx": result.logits_output.next_token_token_ids_logprobs_idx,
+        "input_token_logprobs": result.logits_output.input_token_logprobs,
+        "input_top_logprobs_val": result.logits_output.input_top_logprobs_val,
+        "input_top_logprobs_idx": result.logits_output.input_top_logprobs_idx,
+        "input_token_ids_logprobs_val": result.logits_output.input_token_ids_logprobs_val,
+        "input_token_ids_logprobs_idx": result.logits_output.input_token_ids_logprobs_idx,
+    }
+
+
+def get_logprob_from_pp_outputs(
+    next_pp_outputs: PPProxyTensors,
+) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
+    logits_output = LogitsProcessorOutput(
+        # Do not send logits and hidden states because they are large
+        next_token_logits=None,
+        hidden_states=None,
+        next_token_logprobs=next_pp_outputs["next_token_logprobs"],
+        next_token_top_logprobs_val=next_pp_outputs["next_token_top_logprobs_val"],
+        next_token_top_logprobs_idx=next_pp_outputs["next_token_top_logprobs_idx"],
+        next_token_token_ids_logprobs_val=next_pp_outputs[
+            "next_token_token_ids_logprobs_val"
+        ],
+        next_token_token_ids_logprobs_idx=next_pp_outputs[
+            "next_token_token_ids_logprobs_idx"
+        ],
+        input_token_logprobs=next_pp_outputs["input_token_logprobs"],
+        input_top_logprobs_val=next_pp_outputs["input_top_logprobs_val"],
+        input_top_logprobs_idx=next_pp_outputs["input_top_logprobs_idx"],
+        input_token_ids_logprobs_val=next_pp_outputs["input_token_ids_logprobs_val"],
+        input_token_ids_logprobs_idx=next_pp_outputs["input_token_ids_logprobs_idx"],
+    )
+    extend_input_len_per_req = next_pp_outputs["extend_input_len_per_req"]
+    extend_logprob_start_len_per_req = next_pp_outputs[
+        "extend_logprob_start_len_per_req"
+    ]
+
+    return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
+
+
 class DPBalanceMeta:
     """
     This class will be use in scheduler and dp controller
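Together, these two helpers flatten per-request logprob fields into a plain dict on the last pipeline-parallel stage and rebuild a LogitsProcessorOutput (without the large logits and hidden states) from the received PPProxyTensors on the receiving side. A rough usage sketch under the assumption of a sender/receiver pair; batch_result, other_stage_outputs, and next_pp_outputs stand in for the scheduler's real objects:

    # Sender side (last pipeline stage): flatten the logprob fields so they can
    # ride along with the other inter-stage tensors instead of the full logits.
    logprob_dict = get_logprob_dict_from_result(batch_result)  # batch_result: GenerationBatchResult
    pp_payload = {**other_stage_outputs, **logprob_dict}       # assumed merge into the PP payload

    # Receiver side: rebuild a logits-free LogitsProcessorOutput plus the
    # per-request extend lengths needed to attribute input logprobs.
    logits_output, extend_lens, logprob_start_lens = get_logprob_from_pp_outputs(next_pp_outputs)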