sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
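
Note: the OpenAI-compatible layer moved from sglang/srt/openai_api/ into sglang/srt/entrypoints/openai/, and the monolithic openai_api/adapter.py was removed. A hedged sketch of how downstream imports of the protocol models would need to change; the class names shown are the ones commonly imported from protocol.py, so check the module for the exact symbols you use:

```python
# Before (0.4.7.post1) -- module removed in 0.4.8:
# from sglang.srt.openai_api.protocol import ChatCompletionRequest, CompletionRequest

# After (0.4.8) -- new location per the file list above:
from sglang.srt.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
)
```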
sglang/srt/managers/scheduler.py
CHANGED
@@ -23,7 +23,6 @@ import time
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
-from http import HTTPStatus
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union
@@ -36,6 +35,7 @@ from torch.distributed import barrier

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
     create_grammar_backend,
@@ -140,6 +140,7 @@ from sglang.srt.utils import (
     DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
+    configure_gc_logger,
     configure_logger,
     disable_request_logging,
     get_available_gpu_memory,
@@ -148,6 +149,8 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
@@ -450,8 +453,6 @@ class Scheduler(
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
         self.parent_process = psutil.Process().parent()
-
-        # Init memory saver
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
@@ -508,6 +509,9 @@ class Scheduler(
         )
         self.init_disaggregation()

+        if get_bool_env_var("SGLANG_GC_LOG"):
+            configure_gc_logger()
+
     def maybe_sleep_on_idle(self):
         if self.idle_sleeper is not None:
             self.idle_sleeper.maybe_sleep()
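
Note: the new GC logging hook is opt-in via the SGLANG_GC_LOG environment variable, checked once at scheduler startup. A minimal sketch of enabling it before constructing an offline engine, assuming get_bool_env_var in sglang.srt.utils treats "true"/"1" as truthy (check that helper for the exact accepted values):

```python
import os

# Set the flag before the Engine is created so the spawned scheduler
# process inherits it and get_bool_env_var("SGLANG_GC_LOG") sees it.
os.environ["SGLANG_GC_LOG"] = "true"  # assumption: "true"/"1" are accepted

import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")  # any model path
    print(llm.generate("Hello", {"max_new_tokens": 8}))
    llm.shutdown()
```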
@@ -559,12 +563,20 @@ class Scheduler(
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                tp_cache_group=
+                tp_cache_group=(
+                    self.attn_tp_cpu_group
+                    if self.server_args.enable_dp_attention
+                    else self.tp_cpu_group
+                ),
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
             )
+            self.tp_worker.register_hicache_layer_transfer_counter(
+                self.tree_cache.cache_controller.layer_done_counter
+            )
+
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
@@ -622,7 +634,12 @@ class Scheduler(
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +686,12 @@ class Scheduler(
             self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            self.disagg_metadata_buffers = MetadataBuffers(
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )

             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -795,11 +817,28 @@ class Scheduler(
                     result.next_token_ids,
                     result.bid,
                 )
-
-
-
-
-
+                if self.cur_batch.return_logprob:
+                    pp_outputs = PPProxyTensors(
+                        {
+                            "next_token_ids": next_token_ids,
+                            "extend_input_len_per_req": result.extend_input_len_per_req,
+                            "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+                        }
+                        | (
+                            {
+                                f"logits_output.{k}": v
+                                for k, v in result.logits_output.__dict__.items()
+                            }
+                            if result.logits_output is not None
+                            else {}
+                        )
+                    )
+                else:
+                    pp_outputs = PPProxyTensors(
+                        {
+                            "next_token_ids": next_token_ids,
+                        }
+                    )
                 # send the output from the last round to let the next stage worker run post processing
                 self.pp_group.send_tensor_dict(
                     pp_outputs.tensors,
@@ -816,12 +855,25 @@ class Scheduler(
                     )
                 )
                 mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                logits_output_args = {
+                    k[len("logits_output.") :]: v
+                    for k, v in next_pp_outputs.tensors.items()
+                    if k.startswith("logits_output.")
+                }
+                if len(logits_output_args) > 0:
+                    logits_output = LogitsProcessorOutput(**logits_output_args)
+                else:
+                    logits_output = None
                 output_result = GenerationBatchResult(
-                    logits_output=
+                    logits_output=logits_output,
                     pp_hidden_states_proxy_tensors=None,
                     next_token_ids=next_pp_outputs["next_token_ids"],
-                    extend_input_len_per_req=
-
+                    extend_input_len_per_req=next_pp_outputs.tensors.get(
+                        "extend_input_len_per_req", None
+                    ),
+                    extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
+                        "extend_logprob_start_len_per_req", None
+                    ),
                     bid=bids[next_mb_id],
                     can_run_cuda_graph=result.can_run_cuda_graph,
                 )
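
Note: the pipeline-parallel path now forwards per-request logprob metadata by flattening the LogitsProcessorOutput fields into the proxy tensor dict under a "logits_output." prefix and rebuilding the object on the receiving rank. A minimal sketch of that round trip using a stand-in dataclass; the real LogitsProcessorOutput lives in sglang.srt.layers.logits_processor and carries more fields than shown here:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeLogitsOutput:  # stand-in for LogitsProcessorOutput (illustrative fields)
    next_token_logprobs: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None


def flatten(logits_output, next_token_ids):
    # Sender side: prefix every field so it can travel in one tensor dict.
    return {"next_token_ids": next_token_ids} | {
        f"logits_output.{k}": v for k, v in logits_output.__dict__.items()
    }


def unflatten(tensors):
    # Receiver side: strip the prefix and rebuild the object, or None if absent.
    args = {
        k[len("logits_output.") :]: v
        for k, v in tensors.items()
        if k.startswith("logits_output.")
    }
    return FakeLogitsOutput(**args) if args else None


payload = flatten(FakeLogitsOutput(next_token_logprobs=torch.zeros(4)), torch.arange(4))
print(unflatten(payload))
```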
@@ -1322,7 +1374,14 @@ class Scheduler(
             )
             raise ValueError(msg)

-        if
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            req_total_size = (
+                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
+            )
+        else:
+            req_total_size = self.req_to_token_pool.size
+
+        if len(self.req_to_token_pool.free_slots) != req_total_size:
             msg = (
                 "req_to_token_pool memory leak detected!"
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
@@ -1383,6 +1442,15 @@ class Scheduler(
             self.running_batch.merge_batch(self.last_batch)

         new_batch = self.get_new_batch_prefill()
+
+        need_dp_attn_preparation = require_mlp_sync(self.server_args)
+
+        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
+            # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
+            # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
+            new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+            need_dp_attn_preparation = new_batch is None
+
         if new_batch is not None:
             # Run prefill first if possible
             ret = new_batch
@@ -1395,8 +1463,8 @@ class Scheduler(
                 ret = None

         # Handle DP attention
-        if
-            ret, _ = self.
+        if need_dp_attn_preparation:
+            ret, _ = self.prepare_mlp_sync_batch(ret)

         return ret

@@ -1428,15 +1496,14 @@ class Scheduler(
             return None

         if self.enable_hierarchical_cache:
-
-            self.tree_cache.writing_check()
-            self.tree_cache.loading_check()
+            self.tree_cache.check_hicache_events()

         # Get priority queue
-
+        self.policy.calc_priority(self.waiting_queue)

         # Prefill policy
         adder = PrefillAdder(
+            self.page_size,
             self.tree_cache,
             self.token_to_kv_pool_allocator,
             self.running_batch,
@@ -1478,14 +1545,8 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break

-            req.init_next_round_input(
-
-                self.enable_hierarchical_cache,
-            )
-
-            res = adder.add_one_req(
-                req, self.chunked_req, self.enable_hierarchical_cache
-            )
+            req.init_next_round_input(self.tree_cache)
+            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
@@ -1512,9 +1573,6 @@ class Scheduler(
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]

-        if self.enable_hierarchical_cache:
-            self.tree_cache.ready_to_load_cache()
-
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
@@ -1538,6 +1596,12 @@ class Scheduler(
             self.server_args.enable_custom_logit_processor,
             chunked_req=self.chunked_req,
         )
+        if self.enable_hierarchical_cache:
+            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
+            new_batch.hicache_consumer_index = (
+                self.tree_cache.ready_to_load_host_cache()
+            )
+
         new_batch.prepare_for_extend()

         # Mixed-style chunked prefill
@@ -1613,6 +1677,11 @@ class Scheduler(
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
+
+                # update the consumer index of hicache to the running batch
+                self.tp_worker.set_hicache_consumer(
+                    model_worker_batch.hicache_consumer_index
+                )
                 if self.pp_group.is_last_rank:
                     logits_output, next_token_ids, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1641,13 +1710,15 @@ class Scheduler(
         # These 2 values are needed for processing the output, but the values can be
         # modified by overlap schedule. So we have to copy them here so that
         # we can use the correct values in output processing.
-        if batch.return_logprob:
+        if batch.return_logprob or self.spec_algorithm.is_eagle():
             extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+        else:
+            extend_input_len_per_req = None
+        if batch.return_logprob:
             extend_logprob_start_len_per_req = [
                 req.extend_logprob_start_len for req in batch.reqs
             ]
         else:
-            extend_input_len_per_req = None
             extend_logprob_start_len_per_req = None

         ret = GenerationBatchResult(
@@ -1695,12 +1766,11 @@ class Scheduler(
             self.return_health_check_ct -= 1
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

-    def
-        return self.
+    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_mlp_sync_batch_raw(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             tp_cpu_group=self.tp_cpu_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1709,14 +1779,14 @@ class Scheduler(
             enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
             enable_deepep_moe=self.server_args.enable_deepep_moe,
             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
         )

     @staticmethod
-    def
+    def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
-        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,
@@ -1725,6 +1795,7 @@ class Scheduler(
         enable_two_batch_overlap: bool,
         enable_deepep_moe: bool,
         deepep_mode: DeepEPMode,
+        require_mlp_tp_gather: bool,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -1732,8 +1803,6 @@ class Scheduler(
             num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
-                num_tokens = num_tokens * speculative_num_draft_tokens
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1752,11 +1821,6 @@ class Scheduler(
         else:
             can_cuda_graph = 0

-        if not spec_algorithm.is_none():
-            # TODO(sang): Support cuda graph when idle batch is there.
-            if local_batch is None or local_batch.forward_mode.is_idle():
-                can_cuda_graph = 0
-
         is_extend_in_batch = (
             local_batch.forward_mode.is_extend() if local_batch else False
         )
@@ -1801,7 +1865,7 @@ class Scheduler(

         if local_batch is not None:
             # TODO: handle the case when moe_dense_tp_size != 1
-            if
+            if not require_mlp_tp_gather:
                 local_batch.global_num_tokens = [num_tokens]
                 local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
             else:
@@ -1809,6 +1873,7 @@ class Scheduler(
                 local_batch.global_num_tokens_for_logprob = (
                     global_num_tokens_for_logprob
                 )
+            local_batch.is_extend_in_batch = any(is_extend_in_batch)
             local_batch.tbo_split_seq_index = tbo_split_seq_index
             local_batch.global_forward_mode = global_forward_mode

@@ -1816,6 +1881,7 @@ class Scheduler(
             if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph

+        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
         return local_batch, any(is_extend_in_batch)

     def get_idle_batch(self):
@@ -2176,23 +2242,40 @@ class Scheduler(
         return GetWeightsByNameReqOutput(parameter)

     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-
-
-
-
-
-
-
-
+        tags = recv_req.tags
+        import subprocess
+
+        if tags is None:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+            self.flush_cache()
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.stashed_model_static_state = _export_static_state(
+                self.tp_worker.worker.model_runner.model
+            )
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
+
         return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-
-
-
-
-
-
+        tags = recv_req.tags
+        if tags is None or len(tags) == 0:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            _import_static_state(
+                self.tp_worker.worker.model_runner.model,
+                self.stashed_model_static_state,
+            )
+            del self.stashed_model_static_state
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+
        return ResumeMemoryOccupationReqOutput()

     def slow_down(self, recv_req: SlowDownReqInput):
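
Note: memory release/resume is now tag-based, so weights and KV cache can be paused and restored independently (both are handled when no tags are given). A minimal offline-engine sketch; it assumes the engine-level release_memory_occupation/resume_memory_occupation calls forward the optional tags field shown on ReleaseMemoryOccupationReqInput and ResumeMemoryOccupationReqInput, which should be verified against sglang.srt.entrypoints.engine:

```python
import sglang as sgl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="Qwen/Qwen2.5-0.5B-Instruct",  # any model path
        enable_memory_saver=True,  # pause/resume requires the memory saver adapter
    )

    # Free both regions (same effect as passing no tags), e.g. while an
    # RL trainer temporarily needs the GPU.
    # Assumption: the Engine wrapper exposes a `tags` argument.
    llm.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE])

    # Restore weights first (so updated parameters can be loaded), then KV cache.
    llm.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
    llm.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])

    print(llm.generate("Hello", {"max_new_tokens": 8}))
    llm.shutdown()
```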
@@ -2421,8 +2504,10 @@ class Scheduler(
                 if self.profiler_decode_ct > self.profiler_target_decode_ct:
                     if self.profile_in_progress:
                         self.stop_profile(stage=ForwardMode.DECODE)
+            elif batch.forward_mode.is_idle():
+                pass
             else:
-                raise RuntimeError("unsupported profile stage")
+                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
         else:
             # Check profiler
             if (
sglang/srt/managers/template_manager.py
ADDED
@@ -0,0 +1,226 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Centralized template management for chat templates and completion templates.
+
+This module provides a unified interface for managing both chat conversation templates
+and code completion templates, eliminating global state and improving modularity.
+"""
+
+import json
+import logging
+import os
+from typing import Optional
+
+from sglang.srt.code_completion_parser import (
+    CompletionTemplate,
+    FimPosition,
+    completion_template_exists,
+    register_completion_template,
+)
+from sglang.srt.conversation import (
+    Conversation,
+    SeparatorStyle,
+    chat_template_exists,
+    get_conv_template_by_model_path,
+    register_conv_template,
+)
+from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+
+logger = logging.getLogger(__name__)
+
+
+class TemplateManager:
+    """
+    Centralized manager for chat and completion templates.
+
+    This class encapsulates all template-related state and operations,
+    eliminating the need for global variables and providing a clean
+    interface for template management.
+    """
+
+    def __init__(self):
+        self._chat_template_name: Optional[str] = None
+        self._completion_template_name: Optional[str] = None
+        self._jinja_template_content_format: Optional[str] = None
+
+    @property
+    def chat_template_name(self) -> Optional[str]:
+        """Get the current chat template name."""
+        return self._chat_template_name
+
+    @property
+    def completion_template_name(self) -> Optional[str]:
+        """Get the current completion template name."""
+        return self._completion_template_name
+
+    @property
+    def jinja_template_content_format(self) -> Optional[str]:
+        """Get the detected template content format ('string' or 'openai' or None)."""
+        return self._jinja_template_content_format
+
+    def load_chat_template(
+        self, tokenizer_manager, chat_template_arg: str, model_path: str
+    ) -> None:
+        """
+        Load a chat template from various sources.
+
+        Args:
+            tokenizer_manager: The tokenizer manager instance
+            chat_template_arg: Template name or file path
+            model_path: Path to the model
+        """
+        logger.info(f"Loading chat template: {chat_template_arg}")
+
+        if not chat_template_exists(chat_template_arg):
+            if not os.path.exists(chat_template_arg):
+                raise RuntimeError(
+                    f"Chat template {chat_template_arg} is not a built-in template name "
+                    "or a valid chat template file path."
+                )
+
+            if chat_template_arg.endswith(".jinja"):
+                self._load_jinja_template(tokenizer_manager, chat_template_arg)
+            else:
+                self._load_json_chat_template(chat_template_arg)
+        else:
+            self._chat_template_name = chat_template_arg
+
+    def guess_chat_template_from_model_path(self, model_path: str) -> None:
+        """
+        Infer chat template name from model path.
+
+        Args:
+            model_path: Path to the model
+        """
+        template_name = get_conv_template_by_model_path(model_path)
+        if template_name is not None:
+            logger.info(f"Inferred chat template from model path: {template_name}")
+            self._chat_template_name = template_name
+
+    def load_completion_template(self, completion_template_arg: str) -> None:
+        """
+        Load completion template for code completion.
+
+        Args:
+            completion_template_arg: Template name or file path
+        """
+        logger.info(f"Loading completion template: {completion_template_arg}")
+
+        if not completion_template_exists(completion_template_arg):
+            if not os.path.exists(completion_template_arg):
+                raise RuntimeError(
+                    f"Completion template {completion_template_arg} is not a built-in template name "
+                    "or a valid completion template file path."
+                )
+
+            self._load_json_completion_template(completion_template_arg)
+        else:
+            self._completion_template_name = completion_template_arg
+
+    def initialize_templates(
+        self,
+        tokenizer_manager,
+        model_path: str,
+        chat_template: Optional[str] = None,
+        completion_template: Optional[str] = None,
+    ) -> None:
+        """
+        Initialize all templates based on provided configuration.
+
+        Args:
+            tokenizer_manager: The tokenizer manager instance
+            model_path: Path to the model
+            chat_template: Optional chat template name/path
+            completion_template: Optional completion template name/path
+        """
+        # Load chat template
+        if chat_template:
+            self.load_chat_template(tokenizer_manager, chat_template, model_path)
+        else:
+            self.guess_chat_template_from_model_path(model_path)
+
+        # Load completion template
+        if completion_template:
+            self.load_completion_template(completion_template)
+
+    def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
+        """Load a Jinja template file."""
+        with open(template_path, "r") as f:
+            chat_template = "".join(f.readlines()).strip("\n")
+        tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
+        self._chat_template_name = None
+        # Detect content format from the loaded template
+        self._jinja_template_content_format = detect_jinja_template_content_format(
+            chat_template
+        )
+        logger.info(
+            f"Detected chat template content format: {self._jinja_template_content_format}"
+        )
+
+    def _load_json_chat_template(self, template_path: str) -> None:
+        """Load a JSON chat template file."""
+        assert template_path.endswith(
+            ".json"
+        ), "unrecognized format of chat template file"
+
+        with open(template_path, "r") as filep:
+            template = json.load(filep)
+            try:
+                sep_style = SeparatorStyle[template["sep_style"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown separator style: {template['sep_style']}"
+                ) from None
+
+            register_conv_template(
+                Conversation(
+                    name=template["name"],
+                    system_template=template["system"] + "\n{system_message}",
+                    system_message=template.get("system_message", ""),
+                    roles=(template["user"], template["assistant"]),
+                    sep_style=sep_style,
+                    sep=template.get("sep", "\n"),
+                    stop_str=template["stop_str"],
+                ),
+                override=True,
+            )
+        self._chat_template_name = template["name"]
+
+    def _load_json_completion_template(self, template_path: str) -> None:
+        """Load a JSON completion template file."""
+        assert template_path.endswith(
+            ".json"
+        ), "unrecognized format of completion template file"
+
+        with open(template_path, "r") as filep:
+            template = json.load(filep)
+            try:
+                fim_position = FimPosition[template["fim_position"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown fim position: {template['fim_position']}"
+                ) from None
+
+            register_completion_template(
+                CompletionTemplate(
+                    name=template["name"],
+                    fim_begin_token=template["fim_begin_token"],
+                    fim_middle_token=template["fim_middle_token"],
+                    fim_end_token=template["fim_end_token"],
+                    fim_position=fim_position,
+                ),
+                override=True,
+            )
+        self._completion_template_name = template["name"]