sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +11 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +5 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +8 -4
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +144 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +17 -3
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +5 -0
- sglang/srt/models/deepseek_v2.py +23 -17
- sglang/srt/models/glm4_moe.py +82 -19
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +80 -20
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +3 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -125,7 +126,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import validate_input_length
+from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -137,7 +138,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
-    DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
     configure_gc_logger,
@@ -203,6 +203,7 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
+        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -522,6 +523,15 @@ class Scheduler(
             ]
         )
 
+        self.balance_meta = dp_balance_meta
+        if (
+            server_args.enable_dp_attention
+            and server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+            self.recv_dp_balance_id_this_term = []
+
     def init_tokenizer(self):
         server_args = self.server_args
 
@@ -569,7 +579,23 @@ class Scheduler(
                 page_size=self.page_size,
             )
         else:
-            if self.enable_hierarchical_cache:
+            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
+                # lazy import to avoid JIT overhead
+                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
+
+                self.tree_cache = RadixCacheCpp(
+                    disable=False,
+                    use_hicache=self.enable_hierarchical_cache,
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_cpu_group,
+                    page_size=self.page_size,
+                    hicache_ratio=server_args.hicache_ratio,
+                    hicache_size=server_args.hicache_size,
+                    hicache_write_policy=server_args.hicache_write_policy,
+                    enable_kv_cache_events=self.enable_kv_cache_events,
+                )
+            elif self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
@@ -1033,6 +1059,12 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
         # Create a new request
         if (
             recv_req.session_params is None
@@ -1443,6 +1475,11 @@ class Scheduler(
 
         # Handle DP attention
         if need_dp_attn_preparation:
+            if (
+                self.server_args.load_balance_method == "minimum_tokens"
+                and self.forward_ct % 40 == 0
+            ):
+                self.handle_dp_balance_data(ret)
             ret = self.prepare_mlp_sync_batch(ret)
 
         return ret
@@ -1744,6 +1781,9 @@ class Scheduler(
         elif batch.forward_mode.is_dummy_first():
             self.set_next_batch_sampling_info_done(batch)
 
+        self.maybe_send_health_check_signal()
+
+    def maybe_send_health_check_signal(self):
         if self.return_health_check_ct:
             # Return some signal for the health check.
             # This is used to prevent the health check signal being blocked by long context prefill.
@@ -1762,12 +1802,94 @@ class Scheduler(
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=
-
+            enable_deepep_moe=MoeA2ABackend(
+                self.server_args.moe_a2a_backend
+            ).is_deepep(),
+            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )
 
+    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
+        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
+            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+            recv_list = self.recv_dp_balance_id_this_term
+            assert len(recv_list) <= 511, (
+                "The number of requests received this round is too large. "
+                "Please increase gather_tensor_size and onfly_info_size."
+            )
+            # The maximum size of the tensor used for gathering data from all workers.
+            gather_tensor_size = 512
+
+            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+            recv_tensor[0] = holding_tokens_list
+            recv_tensor[1] = len(
+                recv_list
+            )  # The first element is the length of the list.
+            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
+                recv_list, dtype=torch.int32
+            )
+
+            if self.tp_rank == 0:
+                gathered_list = [
+                    torch.zeros(gather_tensor_size, dtype=torch.int32)
+                    for _ in range(self.balance_meta.num_workers)
+                ]
+            else:
+                gathered_list = None
+
+            torch.distributed.gather(
+                recv_tensor, gathered_list, group=self.tp_cpu_group
+            )
+
+            gathered_id_list_per_worker = None
+            if self.tp_rank == 0:
+                gathered_id_list_per_worker = []
+                holding_tokens_list = []
+                for tensor in gathered_list:
+                    holding_tokens_list.append(tensor[0].item())
+                    list_length = tensor[1].item()
+                    gathered_id_list_per_worker.append(
+                        tensor[2 : list_length + 2].tolist()
+                    )
+
+            return gathered_id_list_per_worker, holding_tokens_list
+
+        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
+            meta = self.balance_meta
+
+            with meta.mutex:
+                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+                assert len(new_recv_rid_lists) == len(
+                    onfly_list
+                ), "num_worker not equal"
+                # 1.Check if the rid received by each worker this round is present in onfly.
+                # If it is, remove the corresponding onfly item.
+                worker_id = 0
+                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                    for new_recv_rid in new_recv_rids:
+                        assert (
+                            new_recv_rid in on_fly_reqs
+                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                        del on_fly_reqs[new_recv_rid]
+                    worker_id += 1
+                # 2. Atomically write local_tokens and onfly into shm under the mutex
+                meta.set_shared_onfly_info(onfly_list)
+                meta.set_shared_local_tokens(local_tokens)
+
+        holding_tokens = self.get_load()
+
+        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
+            holding_tokens
+        )
+
+        self.recv_dp_balance_id_this_term.clear()
+        if self.tp_rank == 0:  # only first worker write info
+            write_shared_dp_balance_info(
+                new_recv_dp_balance_id_list, holding_token_list
+            )
+
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
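`gather_dp_balance_info` above packs each worker's state into a fixed-size int32 tensor before the `torch.distributed.gather` call: slot 0 carries the worker's held-token count, slot 1 the number of request ids received this term, and slots 2 onward the ids themselves. A self-contained sketch of that packing and the rank-0 unpacking, with the collective call left out (the helper names are illustrative; only the slot layout comes from the diff):

```python
import torch

GATHER_TENSOR_SIZE = 512  # same fixed capacity as in the diff (at most 511 request ids)


def pack_balance_info(holding_tokens: int, recv_ids: list) -> torch.Tensor:
    # Layout: | holding_tokens | len(recv_ids) | recv_ids ... |
    assert len(recv_ids) <= GATHER_TENSOR_SIZE - 2
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    if recv_ids:
        t[2 : len(recv_ids) + 2] = torch.tensor(recv_ids, dtype=torch.int32)
    return t


def unpack_balance_info(gathered):
    # Rank-0 side: recover per-worker id lists and held-token counts.
    id_lists, holding_tokens = [], []
    for t in gathered:
        holding_tokens.append(int(t[0]))
        n = int(t[1])
        id_lists.append(t[2 : n + 2].tolist())
    return id_lists, holding_tokens


# Two workers, standing in for torch.distributed.gather on the TP CPU group:
gathered = [pack_balance_info(100, [7, 8]), pack_balance_info(40, [9])]
print(unpack_balance_info(gathered))  # ([[7, 8], [9]], [100, 40])
```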
@@ -2344,11 +2466,19 @@ class IdleSleeper:
 
     def __init__(self, sockets):
         self.poller = zmq.Poller()
+        self.last_empty_time = time.time()
         for s in sockets:
             self.poller.register(s, zmq.POLLIN)
 
     def maybe_sleep(self):
         self.poller.poll(1000)
+        if (
+            global_config.torch_empty_cache_interval > 0
+            and time.time() - self.last_empty_time
+            > global_config.torch_empty_cache_interval
+        ):
+            self.last_empty_time = time.time()
+            torch.cuda.empty_cache()
 
 
 def is_health_check_generate_req(recv_req):
@@ -2368,6 +2498,7 @@ def run_scheduler_process(
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
+    balance_meta: Optional[DPBalanceMeta] = None,
 ):
     # Generate the prefix
     prefix = ""
@@ -2401,7 +2532,14 @@ def run_scheduler_process(
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(
-            server_args,
+            server_args,
+            port_args,
+            gpu_id,
+            tp_rank,
+            moe_ep_rank,
+            pp_rank,
+            dp_rank,
+            dp_balance_meta=balance_meta,
         )
         pipe_writer.send(
             {
sglang/srt/managers/template_manager.py
CHANGED
@@ -84,26 +84,27 @@ class TemplateManager:
         if chat_template_arg:
             self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
         else:
-            #
-            hf_template = self._resolve_hf_chat_template(tokenizer_manager)
-            if hf_template:
-                self._jinja_template_content_format = (
-                    detect_jinja_template_content_format(hf_template)
-                )
-                logger.info(
-                    f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
-                )
-                return
-
-            # Fallback to SGLang template guessing
+            # Guess chat template from model path
             self.guess_chat_template_from_model_path(model_path)
 
-            #
+            # If no pre-defined template was found, fallback to HuggingFace template
             if self._chat_template_name is None:
-
-
-
-
+                # Try HuggingFace template first
+                hf_template = self._resolve_hf_chat_template(tokenizer_manager)
+                if hf_template:
+                    # override the chat template
+                    tokenizer_manager.tokenizer.chat_template = hf_template
+                    self._jinja_template_content_format = (
+                        detect_jinja_template_content_format(hf_template)
+                    )
+                    logger.info(
+                        f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
+                    )
+                    return
+
+                # Default to string content format if no template was found
+                self._jinja_template_content_format = "string"
+                logger.info("No chat template found, defaulting to 'string' content format")
 
     def _load_explicit_chat_template(
         self, tokenizer_manager, chat_template_arg: str
@@ -257,13 +258,15 @@ class TemplateManager:
 
         Returns the chat template string if found, None otherwise.
         """
-        tokenizer = tokenizer_manager.tokenizer
-
-        # Try to get AutoTokenizer chat template
         try:
-
+            if processor := tokenizer_manager.processor:
+                if hasattr(processor, "chat_template") and processor.chat_template:
+                    return processor.chat_template
+            if tokenizer := tokenizer_manager.tokenizer:
+                if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+                    return tokenizer.chat_template
         except Exception as e:
-            logger.debug(f"Error getting chat template
+            logger.debug(f"Error getting chat template: {e}")
 
         logger.debug("No HuggingFace chat template found")
         return None
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -29,6 +29,7 @@ import uuid
 from collections import deque
 from contextlib import nullcontext
 from datetime import datetime
+from enum import Enum
 from http import HTTPStatus
 from typing import (
     Any,
@@ -70,7 +71,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -116,6 +116,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -202,13 +203,29 @@ class TokenizerManager:
 
         if self.model_config.is_multimodal:
             import_processors()
-
-
-
-
-
-
-
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e
             transport_mode = _determine_tensor_transport_mode(self.server_args)
 
             # We want to parallelize the image pre-processing so we create an executor for it
@@ -225,10 +242,10 @@ class TokenizerManager:
             self.tokenizer = get_tokenizer_from_processor(self.processor)
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
         else:
-            self.mm_processor = None
+            self.mm_processor = self.processor = None
 
             if server_args.skip_tokenizer_init:
-                self.tokenizer =
+                self.tokenizer = None
             else:
                 self.tokenizer = get_tokenizer(
                     server_args.tokenizer_path,
@@ -255,6 +272,7 @@ class TokenizerManager:
         self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
+        self.server_status = ServerStatus.Starting
 
         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -1069,38 +1087,56 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
 
-
-
-
-
-
-        logger.info(
-            "Start load Lora adapter. Lora name=%s, path=%s",
-            obj.lora_name,
-            obj.lora_path,
-        )
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
 
-
-        #
-
-
-
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start load Lora adapter. Lora name=%s, path=%s",
+                obj.lora_name,
+                obj.lora_path,
             )
 
-
-
-
+            async with self.lora_update_lock:
+                if (
+                    self.server_args.max_loaded_loras is not None
+                    and self.lora_registry.num_registered_loras
+                    >= self.server_args.max_loaded_loras
+                ):
+                    raise ValueError(
+                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
+                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
+                        "Please unload some LoRA adapters before loading new ones."
+                    )
 
-
-
-
+                # Generate new uniquely identifiable LoRARef object.
+                new_adapter = LoRARef(
+                    lora_name=obj.lora_name,
+                    lora_path=obj.lora_path,
+                )
 
-
+                # Trigger the actual loading operation at the backend processes.
+                obj.lora_id = new_adapter.lora_id
+                result = (await self.update_lora_adapter_communicator(obj))[0]
+
+                # Register the LoRA adapter only after loading is successful.
+                if result.success:
+                    await self.lora_registry.register(new_adapter)
+
+                return result
+        except ValueError as e:
+            return LoadLoRAAdapterReqOutput(
+                success=False,
+                error_message=str(e),
+            )
 
     async def unload_lora_adapter(
         self,
@@ -1108,37 +1144,41 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
 
-
-
-
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
 
-
-
-
-
-
-
-
-
-
+            assert (
+                obj.lora_name is not None
+            ), "lora_name must be provided to unload LoRA adapter"
+
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start unload Lora adapter. Lora name=%s",
+                obj.lora_name,
+            )
 
-
-
-
-
-
+            async with self.lora_update_lock:
+                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+                # from being started.
+                lora_id = await self.lora_registry.unregister(obj.lora_name)
+                obj.lora_id = lora_id
 
-
-
-
-
+                # Initiate the actual unloading operation at the backend processes only after all
+                # ongoing requests using this LoRA adapter are finished.
+                await self.lora_registry.wait_for_unload(lora_id)
+                result = (await self.update_lora_adapter_communicator(obj))[0]
 
-
+                return result
+        except ValueError as e:
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
 
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1767,6 +1807,8 @@ class TokenizerManager:
         asyncio.create_task(asyncio.to_thread(background_task))
 
     def _handle_abort_req(self, recv_obj):
+        if is_health_check_generate_req(recv_obj):
+            return
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
         if recv_obj.finished_reason:
@@ -1901,6 +1943,16 @@ class TokenizerManager:
         return scores
 
 
+class ServerStatus(Enum):
+    Up = "Up"
+    Starting = "Starting"
+    UnHealthy = "UnHealthy"
+    Crashed = "Crashed"
+
+    def is_healthy(self) -> bool:
+        return self == ServerStatus.Up
+
+
 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
     is_cross_node = server_args.dist_init_addr
 
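One behaviour from the TokenizerManager hunks worth calling out: multimodal processor loading now retries with `use_fast=True` when transformers reports that no slow processor exists. A hedged sketch of that retry shape around a generic loader (`load_processor` is a hypothetical stand-in, not the sglang `get_processor` helper):

```python
def load_with_fast_fallback(load_processor, path: str, prefer_fast: bool):
    """load_processor: any callable mirroring get_processor(path, use_fast=...)."""
    try:
        return load_processor(path, use_fast=prefer_fast)
    except ValueError as e:
        if "does not have a slow version" in str(e):
            # Retry with the fast implementation instead of failing outright.
            return load_processor(path, use_fast=True)
        raise
```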
sglang/srt/managers/utils.py
CHANGED
@@ -1,6 +1,7 @@
 import logging
+import multiprocessing as mp
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional
 
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
 
@@ -38,3 +39,46 @@ def validate_input_length(
         return error_msg
 
     return None
+
+
+class DPBalanceMeta:
+    """
+    This class will be use in scheduler and dp controller
+    """
+
+    def __init__(self, num_workers: int):
+        self.num_workers = num_workers
+        self._manager = mp.Manager()
+        self.mutex = self._manager.Lock()
+
+        init_local_tokens = [0] * self.num_workers
+        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
+
+        self.shared_state = self._manager.Namespace()
+        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
+        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
+
+    def destructor(self):
+        # we must destructor this class manually
+        self._manager.shutdown()
+
+    def get_shared_onfly(self) -> List[Dict[int, int]]:
+        return [dict(d) for d in self.shared_state.onfly_info]
+
+    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
+        self.shared_state.onfly_info = data
+
+    def get_shared_local_tokens(self) -> List[int]:
+        return list(self.shared_state.local_tokens)
+
+    def set_shared_local_tokens(self, data: List[int]):
+        self.shared_state.local_tokens = data
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_manager"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._manager = None
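A short usage sketch for the new `DPBalanceMeta`, assuming the 0.4.10.post2 package is importable; it mirrors how the scheduler reads and rewrites the shared state under the mutex:

```python
from sglang.srt.managers.utils import DPBalanceMeta

if __name__ == "__main__":  # mp.Manager() spawns a helper process
    meta = DPBalanceMeta(num_workers=2)
    try:
        with meta.mutex:
            onfly = meta.get_shared_onfly()        # [{}, {}] initially
            onfly[0][123] = 16                     # request id 123 holds 16 tokens on worker 0
            meta.set_shared_onfly_info(onfly)
            meta.set_shared_local_tokens([16, 0])
        print(meta.get_shared_local_tokens())      # [16, 0]
    finally:
        meta.destructor()                          # shut down the mp.Manager explicitly
```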