sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -45,20 +45,24 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
-
-
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-
-
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_zmq_socket,
+from sglang.srt.utils import get_zmq_socket, kill_process_tree

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -105,9 +109,12 @@ class TokenizerManager:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )

         self.is_generation = self.model_config.is_generation
@@ -218,7 +225,8 @@ class TokenizerManager:
         input_ids = obj.input_ids

         if self.is_generation:
-
+            # TODO: also support getting embeddings for multimodal models
+            image_inputs: Dict = await self.image_processor.process_images_async(
                 obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
@@ -331,6 +339,12 @@ class TokenizerManager:
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+            if batch_size > 128:
+                logger.warning(
+                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
+                    "The performance might be better if you just duplicate the requests n times or use "
+                    "many threads to send them one by one with parallel sampling (n > 1)."
+                )

             # Tokenize all requests
             objs = [obj[i] for i in range(batch_size)]
@@ -406,27 +420,10 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)

-    async def
-
-
-
-        req = GetMemPoolSizeReq()
-
-        self.send_to_scheduler.send_pyobj(req)
-        self.mem_pool_size = asyncio.Future()
-
-        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
-        if self.server_args.dp_size == 1:
-            res = await self.mem_pool_size
-            return res.size
-        else: # self.server_args.dp_size > 1
-            self.mem_pool_size_tmp = []
-            res = await self.mem_pool_size
-            ret = [r.size for r in res]
-            return ret
-
-    async def update_weights(
-        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+    async def update_weights_from_disk(
+        self,
+        obj: UpdateWeightFromDiskReqInput,
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -471,6 +468,63 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."

+    async def init_weights_update_group(
+        self,
+        obj: InitWeightsUpdateGroupReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> bool:
+        if self.to_create_loop:
+            self.create_handle_loop()
+        self.send_to_scheduler.send_pyobj(obj)
+
+        self.init_weights_update_group_result = asyncio.Future()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for init parameter update group"
+        result = await self.init_weights_update_group_result
+        return result.success, result.message
+
+    async def update_weights_from_distributed(
+        self,
+        obj: UpdateWeightsFromDistributedReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                self.send_to_scheduler.send_pyobj(obj)
+                self.parameter_update_result = asyncio.Future()
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be for update weights from distributed"
+                result = await self.parameter_update_result
+                return result.success, result.message
+        else:
+            logger.error("Another parameter update is in progress in tokenizer manager")
+            return (
+                False,
+                "Another parameter update is in progress. Please try again later.",
+            )
+
+    async def get_weights_by_name(
+        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        self.send_to_scheduler.send_pyobj(obj)
+        self.get_weights_by_name_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.get_weights_by_name_result
+            return result.parameter
+        else:
+            self.get_weights_by_name_tmp = []
+            result = await self.get_weights_by_name_result
+            all_parameters = [r.parameter for r in result]
+            return all_parameters
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
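
The new TokenizerManager entry points above all follow the same pattern: push a request object to the scheduler, park an asyncio.Future, and await the matching output object. Below is a minimal, hypothetical usage sketch; the request dataclasses and their field names are taken from this diff, while the way a TokenizerManager instance (`tm`) is obtained and every concrete field value are assumptions for illustration only.

```python
# Hypothetical sketch of driving the new weight-update API added in this diff.
# Field names come from the diff; all concrete values are illustrative only.
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
)


async def sync_weights(tm):  # `tm` is an initialized TokenizerManager (assumed)
    # Reload weights from a checkpoint directory on disk.
    success, message = await tm.update_weights_from_disk(
        UpdateWeightFromDiskReqInput(model_path="/ckpts/step_1000", load_format="auto")
    )

    # Join a process group shared with an external trainer, then receive one tensor.
    success, message = await tm.init_weights_update_group(
        InitWeightsUpdateGroupReqInput(
            master_address="10.0.0.1",
            master_port=29500,
            rank_offset=1,
            world_size=2,
            group_name="weight_sync",
            backend="nccl",
        )
    )
    success, message = await tm.update_weights_from_distributed(
        UpdateWeightsFromDistributedReqInput(
            name="model.embed_tokens.weight", dtype="bfloat16", shape=[151936, 4096]
        )
    )

    # Fetch a (truncated) copy of a parameter for inspection.
    return await tm.get_weights_by_name(
        GetWeightsByNameReqInput(name="model.embed_tokens.weight", truncate_size=100)
    )
```
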
@@ -532,7 +586,7 @@ class TokenizerManager:
             else:
                 break

-
+        kill_process_tree(os.getpid(), include_parent=True)
         sys.exit(0)

     async def handle_loop(self):
@@ -540,10 +594,77 @@ class TokenizerManager:

         while True:
             recv_obj: Union[
-                BatchStrOut,
+                BatchStrOut,
+                BatchEmbeddingOut,
+                BatchTokenIDOut,
+                UpdateWeightFromDiskReqOutput,
+                UpdateWeightsFromDistributedReqOutput,
+                GetWeightsByNameReqOutput,
+                InitWeightsUpdateGroupReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()

-            if isinstance(recv_obj,
+            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
+                for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
+                    recv_obj.meta_info[i]["id"] = rid
+                    if isinstance(recv_obj, BatchStrOut):
+                        out_dict = {
+                            "text": recv_obj.output_strs[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    elif isinstance(recv_obj, BatchTokenIDOut):
+                        out_dict = {
+                            "token_ids": recv_obj.output_ids[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    else:
+                        assert isinstance(recv_obj, BatchEmbeddingOut)
+                        out_dict = {
+                            "embedding": recv_obj.embeddings[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    state.out_list.append(out_dict)
+                    state.finished = recv_obj.finished_reason[i] is not None
+                    state.event.set()
+
+                    if self.enable_metrics:
+                        completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+
+                        if state.first_token_time is None:
+                            state.first_token_time = time.time()
+                            self.metrics_collector.observe_time_to_first_token(
+                                state.first_token_time - state.created_time
+                            )
+                        else:
+                            if completion_tokens >= 2:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.first_token_time)
+                                    / (completion_tokens - 1)
+                                )
+
+                        if state.finished:
+                            self.metrics_collector.inc_prompt_tokens(
+                                recv_obj.meta_info[i]["prompt_tokens"]
+                            )
+                            self.metrics_collector.inc_generation_tokens(
+                                completion_tokens
+                            )
+                            self.metrics_collector.observe_e2e_request_latency(
+                                time.time() - state.created_time
+                            )
+                            if completion_tokens >= 1:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.created_time)
+                                    / completion_tokens
+                                )
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
                     self.model_update_result.set_result(recv_obj)
                 else: # self.server_args.dp_size > 1
@@ -551,79 +672,27 @@ class TokenizerManager:
                     # set future if the all results are recevied
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
-
-
+            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for init parameter update group"
+                self.init_weights_update_group_result.set_result(recv_obj)
+            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.parameter_update_result.set_result(recv_obj)
+            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 if self.server_args.dp_size == 1:
-                    self.
-                else: # self.sever_args.dp_size > 1
-                    self.mem_pool_size_tmp.append(recv_obj)
-                    # set future if the all results are received
-                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
-                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
-                    continue
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
-                )
-                continue
-
-            assert isinstance(
-                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
-            ), f"Unexpected obj received: {type(recv_obj)}"
-
-            for i, rid in enumerate(recv_obj.rids):
-                state = self.rid_to_state.get(rid, None)
-                if state is None:
-                    continue
-
-                recv_obj.meta_info[i]["id"] = rid
-                if isinstance(recv_obj, BatchStrOut):
-                    out_dict = {
-                        "text": recv_obj.output_strs[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                elif isinstance(recv_obj, BatchTokenIDOut):
-                    out_dict = {
-                        "token_ids": recv_obj.output_ids[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
+                    self.get_weights_by_name_result.set_result(recv_obj)
                 else:
-
-
-
-
-                    }
-                state.out_list.append(out_dict)
-                state.finished = recv_obj.finished_reason[i] is not None
-                state.event.set()
-
-                if self.enable_metrics:
-                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
-
-                    if state.first_token_time is None:
-                        state.first_token_time = time.time()
-                        self.metrics_collector.observe_time_to_first_token(
-                            state.first_token_time - state.created_time
+                    self.get_weights_by_name_tmp.append(recv_obj)
+                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
+                        self.get_weights_by_name_result.set_result(
+                            self.get_weights_by_name_tmp
                         )
-
-
-                        self.metrics_collector.observe_time_per_output_token(
-                            (time.time() - state.first_token_time)
-                            / (completion_tokens - 1)
-                        )
-
-                    if state.finished:
-                        self.metrics_collector.inc_prompt_tokens(
-                            recv_obj.meta_info[i]["prompt_tokens"]
-                        )
-                        self.metrics_collector.inc_generation_tokens(completion_tokens)
-                        self.metrics_collector.observe_e2e_request_latency(
-                            time.time() - state.created_time
-                        )
-                        if completion_tokens >= 1:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.created_time) / completion_tokens
-                            )
+                else:
+                    raise ValueError(f"Invalid object: {recv_obj=}")

     def convert_logprob_style(
         self,
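
The metrics block in the rewritten handle_loop amortizes decode latency over completion tokens: the first streamed chunk sets time-to-first-token, and every later chunk reports (now - first_token_time) / (completion_tokens - 1) as the per-output-token time. A standalone sketch of that accounting, with no sglang dependencies:

```python
# Standalone sketch of the latency accounting applied per streamed chunk above.
import time


def record_chunk(state: dict, completion_tokens: int, created_time: float) -> dict:
    """Return the latency observations a metrics collector would record for this chunk."""
    now = time.time()
    observations = {}
    if state.get("first_token_time") is None:
        state["first_token_time"] = now
        observations["time_to_first_token"] = now - created_time
    elif completion_tokens >= 2:
        # Decode time averaged over every token after the first one.
        observations["time_per_output_token"] = (now - state["first_token_time"]) / (
            completion_tokens - 1
        )
    return observations
```
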
sglang/srt/managers/tp_worker.py
CHANGED
@@ -19,7 +19,12 @@ from typing import Optional

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -47,9 +52,12 @@ class TpModelWorker:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -155,8 +163,33 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings

-    def
-        success, message = self.model_runner.
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.model_runner.update_weights_from_disk(
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.init_weights_update_group(
+            recv_req.master_address,
+            recv_req.master_port,
+            recv_req.rank_offset,
+            recv_req.world_size,
+            recv_req.group_name,
+            recv_req.backend,
+        )
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.model_runner.update_weights_from_distributed(
+            recv_req.name, recv_req.dtype, recv_req.shape
+        )
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
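
TpModelWorker simply forwards these requests to ModelRunner, whose new methods live in sglang/srt/model_executor/model_runner.py (listed in the file summary but not reproduced in this excerpt). As a rough illustration only, and not the actual sglang implementation, the forwarded fields typically map onto a torch.distributed process group like this:

```python
# Rough illustration (not the actual sglang implementation) of how the fields
# forwarded above could map onto a torch.distributed process group. The engine
# usually takes rank_offset + tp_rank so that an external trainer can hold rank 0.
import torch.distributed as dist


def init_update_group(
    master_address: str,
    master_port: int,
    rank_offset: int,
    world_size: int,
    backend: str,
    tp_rank: int,
):
    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://{master_address}:{master_port}",
        rank=rank_offset + tp_rank,
        world_size=world_size,
    )
```
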
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -15,16 +15,24 @@

 import dataclasses
 import logging
+import signal
 import threading
 from queue import Queue
 from typing import Optional

+import psutil
 import torch

-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -70,6 +78,7 @@ class TpModelWorkerClient:
             target=self.forward_thread_func,
         )
         self.forward_thread.start()
+        self.parent_process = psutil.Process().parent()

     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -87,8 +96,13 @@
         )

     def forward_thread_func(self):
-
-        self.
+        try:
+            with torch.cuda.stream(self.forward_stream):
+                self.forward_thread_func_()
+        except Exception:
+            traceback = get_exception_traceback()
+            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
+            self.parent_process.send_signal(signal.SIGQUIT)

     @torch.no_grad()
     def forward_thread_func_(self):
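
The try/except added to forward_thread_func means a crash in the overlapped forward thread no longer leaves the server silently hung: the traceback is logged and the parent process is signalled. The same pattern in isolation, as a minimal generic sketch:

```python
# Minimal generic sketch of the crash-propagation pattern used above: a worker
# thread that, on any unhandled exception, logs the traceback and signals its
# parent process instead of dying silently.
import logging
import signal
import threading
import traceback

import psutil

logger = logging.getLogger(__name__)


def run_guarded(work_fn):
    parent = psutil.Process().parent()

    def _loop():
        try:
            work_fn()
        except Exception:
            logger.error("worker thread hit an exception: %s", traceback.format_exc())
            if parent is not None:
                parent.send_signal(signal.SIGQUIT)

    thread = threading.Thread(target=_loop, daemon=True)
    thread.start()
    return thread
```
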
@@ -195,10 +209,23 @@
         ) % self.future_token_ids_limit
         return None, future_next_token_ids

-    def
-        success, message = self.worker.
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message

+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.worker.update_weights_from_distributed(recv_req)
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        return self.worker.get_weights_by_name(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner


-def _to_torch(model: torch.nn.Module, reverse: bool
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
         else:
             # NOTE: Temporarily workaround MoE
             if "FusedMoE" in sub.__class__.__name__:
-
+                if batch_size == 1:
+                    # The performance of torch.compile on this layer is not always good when bs > 1,
+                    # so we decide to skip it for now.
+                    sub._forward_method = fused_moe_forward_native
             else:
                 sub._forward_method = sub.forward_native
             setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)


 @contextmanager
 def patch_model(
-    model: torch.nn.Module,
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None

     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward),
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm

@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (
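
patch_model now receives the batch size being captured so that _to_torch can keep FusedMoE layers on their native forward path when bs > 1, and the compiled callable is built with mode="max-autotune-no-cudagraphs" and dynamic=False. A simplified, generic sketch of the same context-manager idea is shown below; the real helper additionally swaps CustomOp forward methods and patches collectives, which is omitted here.

```python
# Simplified sketch of the compile-under-a-context-manager pattern shown above.
# It only toggles torch.compile; the CustomOp/collective patching that the real
# patch_model performs is not reproduced.
from contextlib import contextmanager

import torch


@contextmanager
def maybe_compiled(model: torch.nn.Module, enable_compile: bool):
    if enable_compile:
        yield torch.compile(
            torch.no_grad()(model.forward),
            mode="max-autotune-no-cudagraphs",
            dynamic=False,
        )
    else:
        yield model.forward
```
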
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -256,10 +256,15 @@ class ForwardBatch:
         ret.extend_prefix_lens = torch.tensor(
             batch.extend_prefix_lens, dtype=torch.int32
         ).to(device, non_blocking=True)
-
-
-        ret.
-
+        if model_runner.server_args.attention_backend != "torch_native":
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
+        else:
+            ret.positions, ret.extend_start_loc = compute_position_torch(
+                ret.extend_prefix_lens, ret.extend_seq_lens
+            )
         ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
         ret.extend_seq_lens_cpu = batch.extend_seq_lens
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
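
With the new torch_native attention backend, positions and extend_start_loc are computed with plain torch ops instead of the Triton kernel. The helper below is only a guess at what such a pure-torch computation looks like (compute_position_torch itself is defined elsewhere in this release and may differ): each request contributes positions prefix_len .. prefix_len + seq_len - 1, and extend_start_loc is the exclusive cumulative sum of the extend sequence lengths.

```python
# Hedged pure-torch sketch of the position computation for the torch_native path.
# The real compute_position_torch in sglang may differ; this is illustration only.
import torch


def compute_position_torch_sketch(
    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
):
    per_req = [
        torch.arange(int(p), int(p) + int(s), device=extend_seq_lens.device)
        for p, s in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
    ]
    positions = (
        torch.cat(per_req)
        if per_req
        else extend_seq_lens.new_empty(0, dtype=torch.int64)
    )
    # Exclusive prefix sum: where each request's extend tokens start in the flat buffer.
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    return positions, extend_start_loc
```
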