sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -119,13 +120,14 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
 from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
 from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import validate_input_length
+from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -137,7 +139,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
-    DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
     configure_gc_logger,
@@ -203,6 +204,7 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
+        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -471,8 +473,10 @@ class Scheduler(
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
+        self.offload_tags = set()
         self.init_profier()

+        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
             SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
             if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
@@ -522,6 +526,15 @@ class Scheduler(
             ]
         )

+        self.balance_meta = dp_balance_meta
+        if (
+            server_args.enable_dp_attention
+            and server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+            self.recv_dp_balance_id_this_term = []
+
     def init_tokenizer(self):
         server_args = self.server_args

@@ -569,7 +582,23 @@ class Scheduler(
                 page_size=self.page_size,
             )
         else:
-            if self.enable_hierarchical_cache:
+            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
+                # lazy import to avoid JIT overhead
+                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
+
+                self.tree_cache = RadixCacheCpp(
+                    disable=False,
+                    use_hicache=self.enable_hierarchical_cache,
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_cpu_group,
+                    page_size=self.page_size,
+                    hicache_ratio=server_args.hicache_ratio,
+                    hicache_size=server_args.hicache_size,
+                    hicache_write_policy=server_args.hicache_write_policy,
+                    enable_kv_cache_events=self.enable_kv_cache_events,
+                )
+            elif self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
@@ -590,6 +619,7 @@ class Scheduler(
                     ),
                     hicache_mem_layout=server_args.hicache_mem_layout,
                     hicache_storage_backend=server_args.hicache_storage_backend,
+                    hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                 )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -920,6 +950,14 @@ class Scheduler(

     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
+
+        if self.recv_skipper is not None:
+            last_forward_mode = (
+                self.last_batch.forward_mode if self.last_batch is not None else None
+            )
+            if not self.recv_skipper.handle(last_forward_mode):
+                return []
+
         if self.pp_rank == 0:
             if self.attn_tp_rank == 0:
                 recv_reqs = []
@@ -1003,7 +1041,9 @@ class Scheduler(
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None or not self.running_batch.is_empty()
+                self.chunked_req is not None
+                or not self.running_batch.is_empty()
+                or len(self.offload_tags) > 0
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -1033,6 +1073,12 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
         # Create a new request
         if (
             recv_req.session_params is None
@@ -1058,7 +1104,7 @@ class Scheduler(
                 top_logprobs_num=recv_req.top_logprobs_num,
                 token_ids_logprob=recv_req.token_ids_logprob,
                 stream=recv_req.stream,
-                lora_path=recv_req.lora_path,
+                lora_id=recv_req.lora_id,
                 input_embeds=recv_req.input_embeds,
                 custom_logit_processor=recv_req.custom_logit_processor,
                 return_hidden_states=recv_req.return_hidden_states,
@@ -1443,6 +1489,11 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
+            if (
+                self.server_args.load_balance_method == "minimum_tokens"
+                and self.forward_ct % 40 == 0
+            ):
+                self.handle_dp_balance_data(ret)
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1497,18 +1548,15 @@ class Scheduler(
         self.chunked_req = adder.add_chunked_req(self.chunked_req)

         if self.enable_lora:
-            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+            lora_set = set([req.lora_id for req in self.running_batch.reqs])

         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
-            if (
-                self.enable_lora
-                and len(
-                    lora_set
-                    | set([req.lora_path for req in adder.can_run_list])
-                    | set([req.lora_path])
-                )
-                > self.max_loras_per_batch
+
+            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
+                lora_set
+                | set([req.lora_id for req in adder.can_run_list])
+                | set([req.lora_id])
             ):
                 self.running_batch.batch_is_full = True
                 break
@@ -1525,7 +1573,10 @@ class Scheduler(
                 break

             if self.enable_hicache_storage:
-                self.tree_cache.check_prefetch_progress(req.rid)
+                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+                if not prefetch_done:
+                    # skip staging requests that are ongoing prefetch
+                    continue

             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
@@ -1744,6 +1795,9 @@ class Scheduler(
         elif batch.forward_mode.is_dummy_first():
             self.set_next_batch_sampling_info_done(batch)

+        self.maybe_send_health_check_signal()
+
+    def maybe_send_health_check_signal(self):
         if self.return_health_check_ct:
             # Return some signal for the health check.
             # This is used to prevent the health check signal being blocked by long context prefill.
@@ -1762,12 +1816,94 @@ class Scheduler(
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=self.server_args.enable_deepep_moe,
-            deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+            enable_deepep_moe=MoeA2ABackend(
+                self.server_args.moe_a2a_backend
+            ).is_deepep(),
+            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

+    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
+        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
+            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+            recv_list = self.recv_dp_balance_id_this_term
+            assert len(recv_list) <= 511, (
+                "The number of requests received this round is too large. "
+                "Please increase gather_tensor_size and onfly_info_size."
+            )
+            # The maximum size of the tensor used for gathering data from all workers.
+            gather_tensor_size = 512
+
+            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+            recv_tensor[0] = holding_tokens_list
+            recv_tensor[1] = len(
+                recv_list
+            )  # The first element is the length of the list.
+            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
+                recv_list, dtype=torch.int32
+            )
+
+            if self.tp_rank == 0:
+                gathered_list = [
+                    torch.zeros(gather_tensor_size, dtype=torch.int32)
+                    for _ in range(self.balance_meta.num_workers)
+                ]
+            else:
+                gathered_list = None
+
+            torch.distributed.gather(
+                recv_tensor, gathered_list, group=self.tp_cpu_group
+            )
+
+            gathered_id_list_per_worker = None
+            if self.tp_rank == 0:
+                gathered_id_list_per_worker = []
+                holding_tokens_list = []
+                for tensor in gathered_list:
+                    holding_tokens_list.append(tensor[0].item())
+                    list_length = tensor[1].item()
+                    gathered_id_list_per_worker.append(
+                        tensor[2 : list_length + 2].tolist()
+                    )
+
+            return gathered_id_list_per_worker, holding_tokens_list
+
+        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
+            meta = self.balance_meta
+
+            with meta.mutex:
+                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+                assert len(new_recv_rid_lists) == len(
+                    onfly_list
+                ), "num_worker not equal"
+                # 1. Check if the rid received by each worker this round is present in onfly.
+                #    If it is, remove the corresponding onfly item.
+                worker_id = 0
+                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                    for new_recv_rid in new_recv_rids:
+                        assert (
+                            new_recv_rid in on_fly_reqs
+                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                        del on_fly_reqs[new_recv_rid]
+                    worker_id += 1
+                # 2. Atomically write local_tokens and onfly into shm under the mutex
+                meta.set_shared_onfly_info(onfly_list)
+                meta.set_shared_local_tokens(local_tokens)
+
+        holding_tokens = self.get_load()
+
+        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
+            holding_tokens
+        )
+
+        self.recv_dp_balance_id_this_term.clear()
+        if self.tp_rank == 0:  # only first worker write info
+            write_shared_dp_balance_info(
+                new_recv_dp_balance_id_list, holding_token_list
+            )
+
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
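
The `gather_dp_balance_info` helper above packs each worker's state into a fixed-size int32 buffer before a `torch.distributed.gather`: slot 0 carries the worker's token load, slot 1 the number of request ids received this term, and slots 2 onward the ids themselves. A minimal standalone sketch of that pack/unpack layout (illustrative names, no distributed calls; `GATHER_TENSOR_SIZE` here just mirrors the 512-element buffer used above):

```python
import torch

GATHER_TENSOR_SIZE = 512  # fixed buffer size, matching the gather tensor above


def pack_balance_info(holding_tokens, recv_ids):
    # Layout: | holding_tokens | len(recv_ids) | recv_ids ... | zero padding |
    assert len(recv_ids) <= GATHER_TENSOR_SIZE - 2
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : 2 + len(recv_ids)] = torch.tensor(recv_ids, dtype=torch.int32)
    return t


def unpack_balance_info(t):
    holding_tokens = int(t[0].item())
    n = int(t[1].item())
    return holding_tokens, t[2 : 2 + n].tolist()


# Example: a worker holding 1234 tokens that received request ids 7, 8, 9 this term.
packed = pack_balance_info(1234, [7, 8, 9])
assert unpack_balance_info(packed) == (1234, [7, 8, 9])
```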
@@ -2344,11 +2480,19 @@ class IdleSleeper:

    def __init__(self, sockets):
        self.poller = zmq.Poller()
+        self.last_empty_time = time.time()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        self.poller.poll(1000)
+        if (
+            global_config.torch_empty_cache_interval > 0
+            and time.time() - self.last_empty_time
+            > global_config.torch_empty_cache_interval
+        ):
+            self.last_empty_time = time.time()
+            torch.cuda.empty_cache()


 def is_health_check_generate_req(recv_req):
|
@@ -2368,6 +2512,7 @@ def run_scheduler_process(
|
|
2368
2512
|
pp_rank: int,
|
2369
2513
|
dp_rank: Optional[int],
|
2370
2514
|
pipe_writer,
|
2515
|
+
balance_meta: Optional[DPBalanceMeta] = None,
|
2371
2516
|
):
|
2372
2517
|
# Generate the prefix
|
2373
2518
|
prefix = ""
|
@@ -2401,7 +2546,14 @@ def run_scheduler_process(
    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(
-            server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+            server_args,
+            port_args,
+            gpu_id,
+            tp_rank,
+            moe_ep_rank,
+            pp_rank,
+            dp_rank,
+            dp_balance_meta=balance_meta,
        )
        pipe_writer.send(
            {
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin:

            req.send_decode_id_offset = len(decode_ids)
            read_offsets.append(read_offset)
-
-            output_ids.append(req.output_ids[send_token_offset:])
+            output_ids.append(req.output_ids[send_token_offset:])
            req.send_token_offset = len(req.output_ids)
            skip_special_tokens.append(req.sampling_params.skip_special_tokens)
            spaces_between_special_tokens.append(
sglang/srt/managers/scheduler_recv_skipper.py
ADDED
@@ -0,0 +1,37 @@
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.server_args import ServerArgs
+
+
+class SchedulerRecvSkipper:
+    @staticmethod
+    def maybe_create(server_args: ServerArgs):
+        if server_args.scheduler_recv_interval <= 1:
+            return None
+        return SchedulerRecvSkipper(server_args)
+
+    def __init__(self, server_args: ServerArgs):
+        # Can be supported if needed, but may need e.g. `global_forward_mode`
+        assert not server_args.enable_dp_attention
+        self._counter = 0
+        self._threshold = server_args.scheduler_recv_interval
+
+    def handle(self, last_forward_mode: ForwardMode):
+        should_recv = False
+
+        last_weight = _WEIGHT_OF_FORWARD_MODE.get(last_forward_mode, _DEFAULT_WEIGHT)
+        self._counter += last_weight
+
+        if self._counter >= self._threshold:
+            self._counter = 0
+            should_recv = True
+
+        return should_recv
+
+
+# All can be tuned if needed
+_DEFAULT_WEIGHT = 1000
+_WEIGHT_OF_FORWARD_MODE = {
+    ForwardMode.DECODE: 1,
+    ForwardMode.TARGET_VERIFY: 1,
+    None: 1,
+}
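
The skipper is a weighted counter: decode, target-verify, and idle steps (no last batch) each add weight 1, while any other forward mode adds the default weight of 1000, and the scheduler only polls its sockets once the accumulated weight reaches `server_args.scheduler_recv_interval`. A simplified standalone sketch of that counting behavior (plain strings instead of `ForwardMode`, for illustration only):

```python
# Illustrative only: mirrors the counter logic above without sglang imports.
DEFAULT_WEIGHT = 1000
WEIGHT = {"decode": 1, "target_verify": 1, None: 1}


def make_skipper(threshold):
    counter = 0

    def handle(last_mode):
        # Returns True when the scheduler should actually poll for new requests.
        nonlocal counter
        counter += WEIGHT.get(last_mode, DEFAULT_WEIGHT)
        if counter >= threshold:
            counter = 0
            return True
        return False

    return handle


handle = make_skipper(threshold=4)
# With a recv interval of 4, three decode iterations skip the recv,
# the fourth triggers it; a prefill-like step (weight 1000) always triggers.
assert [handle("decode") for _ in range(4)] == [False, False, False, True]
assert handle("extend") is True
```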
sglang/srt/managers/scheduler_update_weights_mixin.py
CHANGED
@@ -78,6 +78,9 @@ class SchedulerUpdateWeightsMixin:
        if tags is None or len(tags) == 0:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

+        for tag in tags:
+            self.offload_tags.add(tag)
+
        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
            self.flush_cache()
@@ -97,6 +100,9 @@ class SchedulerUpdateWeightsMixin:
        if tags is None or len(tags) == 0:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

+        for tag in tags:
+            self.offload_tags.remove(tag)
+
        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
            torch.distributed.barrier(self.tp_cpu_group)
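
Taken together with the scheduler change above (`or len(self.offload_tags) > 0` in the health-check handling of `recv_requests`), the new `offload_tags` set acts as a guard: while any memory type is paused for offload, incoming health-check generation requests are acknowledged without being scheduled. A minimal illustration of that guard (simplified stand-ins, not sglang's API):

```python
# Illustrative sketch only; simplified stand-ins for the scheduler state above.
GPU_MEMORY_TYPE_WEIGHTS = "weights"
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"


class MiniScheduler:
    def __init__(self):
        self.offload_tags = set()
        self.return_health_check_ct = 0

    def release_memory_occupation(self, tags):
        # Mirrors the mixin: remember which memory types are currently offloaded.
        for tag in tags:
            self.offload_tags.add(tag)

    def resume_memory_occupation(self, tags):
        for tag in tags:
            self.offload_tags.remove(tag)

    def handle_health_check(self):
        # Mirrors the recv_requests guard: while offloaded, count a pending
        # health-check ack instead of scheduling the generation request.
        if len(self.offload_tags) > 0:
            self.return_health_check_ct += 1
            return "acknowledged_without_running"
        return "scheduled"


s = MiniScheduler()
s.release_memory_occupation([GPU_MEMORY_TYPE_WEIGHTS])
assert s.handle_health_check() == "acknowledged_without_running"
s.resume_memory_occupation([GPU_MEMORY_TYPE_WEIGHTS])
assert s.handle_health_check() == "scheduled"
```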
sglang/srt/managers/template_manager.py
CHANGED
@@ -21,6 +21,7 @@ and code completion templates, eliminating global state and improving modularity
 import json
 import logging
 import os
+import re
 from typing import Optional

 from sglang.srt.code_completion_parser import (
@@ -54,6 +55,7 @@ class TemplateManager:
        self._chat_template_name: Optional[str] = None
        self._completion_template_name: Optional[str] = None
        self._jinja_template_content_format: Optional[str] = "openai"
+        self._force_reasoning: bool = False

    @property
    def chat_template_name(self) -> Optional[str]:
@@ -70,6 +72,31 @@ class TemplateManager:
        """Get the detected template content format ('string' or 'openai' or None)."""
        return self._jinja_template_content_format

+    @property
+    def force_reasoning(self) -> bool:
+        """
+        Check if the current chat template enforces reasoning/thinking.
+
+        Returns:
+            True if the template contains reasoning patterns like <think> tags
+        """
+        return self._force_reasoning
+
+    def _detect_reasoning_pattern(self, template: str) -> bool:
+        """
+        Detect if the chat template contains reasoning/thinking patterns.
+        """
+        if template is None:
+            return False
+
+        force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"
+        has_reasoning = re.search(force_reasoning_pattern, template) is not None
+
+        if has_reasoning:
+            logger.info("Detected the force reasoning pattern in chat template.")
+
+        return has_reasoning
+
    def load_chat_template(
        self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
    ) -> None:
@@ -84,26 +111,34 @@ class TemplateManager:
        if chat_template_arg:
            self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
        else:
-            # Try HuggingFace template first
-            hf_template = self._resolve_hf_chat_template(tokenizer_manager)
-            if hf_template:
-                self._jinja_template_content_format = (
-                    detect_jinja_template_content_format(hf_template)
-                )
-                logger.info(
-                    f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
-                )
-                return
-
-            # Fallback to SGLang template guessing
+            # Guess chat template from model path
            self.guess_chat_template_from_model_path(model_path)

-            # Set default format if no template was found
+            # If no pre-defined template was found, fallback to HuggingFace template
            if self._chat_template_name is None:
-                self._jinja_template_content_format = "string"
-                logger.info(
-                    "No chat template found, defaulting to 'string' content format"
-                )
+                # Try HuggingFace template first
+                hf_template = self._resolve_hf_chat_template(tokenizer_manager)
+                if hf_template:
+                    # override the chat template
+                    if tokenizer_manager.tokenizer:
+                        tokenizer_manager.tokenizer.chat_template = hf_template
+                    self._jinja_template_content_format = (
+                        detect_jinja_template_content_format(hf_template)
+                    )
+                    logger.info(
+                        f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
+                    )
+                    return
+
+                # Default to string content format if no template was found
+                self._jinja_template_content_format = "string"
+                logger.info("No chat template found, defaulting to 'string' content format")
+
+            # Detect reasoning pattern from chat template
+            if tokenizer_manager.tokenizer:
+                self._force_reasoning = self._detect_reasoning_pattern(
+                    tokenizer_manager.tokenizer.chat_template
+                )

    def _load_explicit_chat_template(
        self, tokenizer_manager, chat_template_arg: str
@@ -257,13 +292,15 @@ class TemplateManager:

        Returns the chat template string if found, None otherwise.
        """
-        tokenizer = tokenizer_manager.tokenizer
-
-        # Try to get AutoTokenizer chat template
        try:
-            return tokenizer.get_chat_template()
+            if processor := tokenizer_manager.processor:
+                if hasattr(processor, "chat_template") and processor.chat_template:
+                    return processor.chat_template
+            if tokenizer := tokenizer_manager.tokenizer:
+                if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+                    return tokenizer.chat_template
        except Exception as e:
-            logger.debug(f"Error getting chat template via get_chat_template(): {e}")
+            logger.debug(f"Error getting chat template: {e}")

        logger.debug("No HuggingFace chat template found")
        return None
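
The force-reasoning detection in `_detect_reasoning_pattern` looks for a literal `<|im_start|>assistant\n<think>\n` sequence written in the Jinja template source, where the newlines appear as the two-character escape `\n` (hence the doubled backslashes in the pattern). A quick standalone check of that same regex against made-up template snippets (the template strings below are illustrative, not taken from any particular model):

```python
import re

# Same pattern as _detect_reasoning_pattern above.
FORCE_REASONING_PATTERN = r"<\|im_start\|>assistant\\n<think>\\n"

# A template that writes an opening <think> tag right after the assistant header
# (as some Qwen-style "thinking" templates do) matches; a plain template does not.
thinking_template = (
    "{% if add_generation_prompt %}<|im_start|>assistant\\n<think>\\n{% endif %}"
)
plain_template = "{% if add_generation_prompt %}<|im_start|>assistant\\n{% endif %}"

assert re.search(FORCE_REASONING_PATTERN, thinking_template) is not None
assert re.search(FORCE_REASONING_PATTERN, plain_template) is None
```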