sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +18 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +76 -20
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +267 -170
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +245 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -45,13 +45,19 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UpdateWeightReqInput,
-    UpdateWeightReqOutput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -103,9 +109,12 @@ class TokenizerManager:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
 
         self.is_generation = self.model_config.is_generation
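Both ModelConfig call sites in this release (here and in TpModelWorker below) gain the same three keywords. A minimal sketch of the expanded constructor call; the keyword names come straight from the hunk above, but the model path and the values shown are placeholders, not sglang defaults:

# Sketch of the 0.4.0 ModelConfig surface; all values are illustrative.
from sglang.srt.configs.model_config import ModelConfig

model_config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model path
    trust_remote_code=False,
    revision=None,          # new in 0.4.0: pin a specific Hugging Face revision
    context_length=None,
    model_override_args="{}",
    is_embedding=False,
    dtype="auto",           # new in 0.4.0: forwarded from ServerArgs.dtype
    quantization=None,      # new in 0.4.0: forwarded from ServerArgs.quantization
)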
@@ -330,6 +339,12 @@ class TokenizerManager:
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+            if batch_size > 128:
+                logger.warning(
+                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
+                    "The performance might be better if you just duplicate the requests n times or use "
+                    "many threads to send them one by one with parallel sampling (n > 1)."
+                )
 
             # Tokenize all requests
             objs = [obj[i] for i in range(batch_size)]
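The warning added above points at a client-side workaround: rather than one large batch with parallel sampling (n > 1), duplicate each request and send the copies concurrently. A hedged sketch of that pattern against sglang's native /generate endpoint (the URL and sampling values are placeholders):

# Illustrative workaround: one request per sample instead of a single
# batch request with n > 1. Endpoint shape follows sglang's /generate API.
import concurrent.futures
import requests

URL = "http://localhost:30000/generate"  # assumed local server address

def sample_once(prompt: str) -> str:
    resp = requests.post(
        URL,
        json={
            "text": prompt,
            "sampling_params": {"temperature": 0.8, "max_new_tokens": 64},
        },
    )
    resp.raise_for_status()
    return resp.json()["text"]

# Many threads, each sending one copy of the prompt.
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
    samples = list(pool.map(sample_once, ["Hello, my name is"] * 8))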
@@ -405,8 +420,10 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)
 
-    async def update_weights(
-        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+    async def update_weights_from_disk(
+        self,
+        obj: UpdateWeightFromDiskReqInput,
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -451,6 +468,63 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."
 
+    async def init_weights_update_group(
+        self,
+        obj: InitWeightsUpdateGroupReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> bool:
+        if self.to_create_loop:
+            self.create_handle_loop()
+        self.send_to_scheduler.send_pyobj(obj)
+
+        self.init_weights_update_group_result = asyncio.Future()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for init parameter update group"
+        result = await self.init_weights_update_group_result
+        return result.success, result.message
+
+    async def update_weights_from_distributed(
+        self,
+        obj: UpdateWeightsFromDistributedReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                self.send_to_scheduler.send_pyobj(obj)
+                self.parameter_update_result = asyncio.Future()
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                result = await self.parameter_update_result
+                return result.success, result.message
+        else:
+            logger.error("Another parameter update is in progress in tokenizer manager")
+            return (
+                False,
+                "Another parameter update is in progress. Please try again later.",
+            )
+
+    async def get_weights_by_name(
+        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        self.send_to_scheduler.send_pyobj(obj)
+        self.get_weights_by_name_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.get_weights_by_name_result
+            return result.parameter
+        else:
+            self.get_weights_by_name_tmp = []
+            result = await self.get_weights_by_name_result
+            all_parameters = [r.parameter for r in result]
+            return all_parameters
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
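Together, the three new coroutines give an external trainer a complete weight-sync path: join a process group, broadcast a tensor into it, then spot-check a parameter. A hedged sketch of the call sequence, assuming a running TokenizerManager `tm`; the field names mirror the request objects used in this diff, while every concrete value is a placeholder:

# Illustrative driver for the new weight-sync coroutines; not sglang source.
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    UpdateWeightsFromDistributedReqInput,
)

async def sync_weights(tm):
    # 1. Join a process group the trainer also joins (trainer is rank 0 there).
    success, message = await tm.init_weights_update_group(
        InitWeightsUpdateGroupReqInput(
            master_address="127.0.0.1",  # placeholder
            master_port=29500,           # placeholder
            rank_offset=1,               # server ranks start after the trainer
            world_size=2,
            group_name="weight_sync",
            backend="nccl",
        )
    )
    assert success, message

    # 2. Receive one tensor broadcast by the trainer over that group.
    success, message = await tm.update_weights_from_distributed(
        UpdateWeightsFromDistributedReqInput(
            name="model.embed_tokens.weight",  # placeholder parameter name
            dtype="bfloat16",
            shape=[128256, 4096],              # placeholder shape
        )
    )
    assert success, message

    # 3. Spot-check that the server now holds the new values.
    return await tm.get_weights_by_name(
        GetWeightsByNameReqInput(name="model.embed_tokens.weight", truncate_size=10)
    )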
@@ -520,10 +594,77 @@ class TokenizerManager:
 
         while True:
             recv_obj: Union[
-                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+                BatchStrOut,
+                BatchEmbeddingOut,
+                BatchTokenIDOut,
+                UpdateWeightFromDiskReqOutput,
+                UpdateWeightsFromDistributedReqOutput,
+                GetWeightsByNameReqOutput,
+                InitWeightsUpdateGroupReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
-            if isinstance(recv_obj, UpdateWeightReqOutput):
+            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
+                for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
+                    recv_obj.meta_info[i]["id"] = rid
+                    if isinstance(recv_obj, BatchStrOut):
+                        out_dict = {
+                            "text": recv_obj.output_strs[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    elif isinstance(recv_obj, BatchTokenIDOut):
+                        out_dict = {
+                            "token_ids": recv_obj.output_ids[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    else:
+                        assert isinstance(recv_obj, BatchEmbeddingOut)
+                        out_dict = {
+                            "embedding": recv_obj.embeddings[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    state.out_list.append(out_dict)
+                    state.finished = recv_obj.finished_reason[i] is not None
+                    state.event.set()
+
+                    if self.enable_metrics:
+                        completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+
+                        if state.first_token_time is None:
+                            state.first_token_time = time.time()
+                            self.metrics_collector.observe_time_to_first_token(
+                                state.first_token_time - state.created_time
+                            )
+                        else:
+                            if completion_tokens >= 2:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.first_token_time)
+                                    / (completion_tokens - 1)
+                                )
+
+                        if state.finished:
+                            self.metrics_collector.inc_prompt_tokens(
+                                recv_obj.meta_info[i]["prompt_tokens"]
+                            )
+                            self.metrics_collector.inc_generation_tokens(
+                                completion_tokens
+                            )
+                            self.metrics_collector.observe_e2e_request_latency(
+                                time.time() - state.created_time
+                            )
+                            if completion_tokens >= 1:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.created_time)
+                                    / completion_tokens
+                                )
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
                     self.model_update_result.set_result(recv_obj)
                 else:  # self.server_args.dp_size > 1
@@ -531,70 +672,27 @@ class TokenizerManager:
                     # set the future when all results are received
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
-                continue
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
-                )
-                continue
-
-            assert isinstance(
-                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
-            ), f"Unexpected obj received: {type(recv_obj)}"
-
-            for i, rid in enumerate(recv_obj.rids):
-                state = self.rid_to_state.get(rid, None)
-                if state is None:
-                    continue
-
-                recv_obj.meta_info[i]["id"] = rid
-                if isinstance(recv_obj, BatchStrOut):
-                    out_dict = {
-                        "text": recv_obj.output_strs[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                elif isinstance(recv_obj, BatchTokenIDOut):
-                    out_dict = {
-                        "token_ids": recv_obj.output_ids[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
+            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for init parameter update group"
+                self.init_weights_update_group_result.set_result(recv_obj)
+            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.parameter_update_result.set_result(recv_obj)
+            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
+                if self.server_args.dp_size == 1:
+                    self.get_weights_by_name_result.set_result(recv_obj)
                 else:
-                    assert isinstance(recv_obj, BatchEmbeddingOut)
-                    out_dict = {
-                        "embedding": recv_obj.embeddings[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                state.out_list.append(out_dict)
-                state.finished = recv_obj.finished_reason[i] is not None
-                state.event.set()
-
-                if self.enable_metrics:
-                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
-
-                    if state.first_token_time is None:
-                        state.first_token_time = time.time()
-                        self.metrics_collector.observe_time_to_first_token(
-                            state.first_token_time - state.created_time
+                    self.get_weights_by_name_tmp.append(recv_obj)
+                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
+                        self.get_weights_by_name_result.set_result(
+                            self.get_weights_by_name_tmp
                         )
-                    else:
-                        if completion_tokens >= 2:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.first_token_time)
-                                / (completion_tokens - 1)
-                            )
-
-                    if state.finished:
-                        self.metrics_collector.inc_prompt_tokens(
-                            recv_obj.meta_info[i]["prompt_tokens"]
-                        )
-                        self.metrics_collector.inc_generation_tokens(completion_tokens)
-                        self.metrics_collector.observe_e2e_request_latency(
-                            time.time() - state.created_time
-                        )
-                        if completion_tokens >= 1:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.created_time) / completion_tokens
-                            )
+            else:
+                raise ValueError(f"Invalid object: {recv_obj=}")
 
     def convert_logprob_style(
         self,
sglang/srt/managers/tp_worker.py
CHANGED
@@ -19,7 +19,12 @@ from typing import Optional
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -47,9 +52,12 @@ class TpModelWorker:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -155,8 +163,33 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings
 
-    def update_weights(self, recv_req: UpdateWeightReqInput):
-        success, message = self.model_runner.update_weights(
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.model_runner.update_weights_from_disk(
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.init_weights_update_group(
+            recv_req.master_address,
+            recv_req.master_port,
+            recv_req.rank_offset,
+            recv_req.world_size,
+            recv_req.group_name,
+            recv_req.backend,
+        )
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.model_runner.update_weights_from_distributed(
+            recv_req.name, recv_req.dtype, recv_req.shape
+        )
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
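On the worker side these are thin pass-throughs into ModelRunner, which presumably assembles an ordinary torch.distributed process group from master_address, master_port, rank_offset, world_size, and backend. For orientation, a hedged sketch of what the trainer-side counterpart could look like; this is plain torch.distributed usage with placeholder values, not sglang source:

# Illustrative trainer-side counterpart to init_weights_update_group /
# update_weights_from_distributed. Addresses, sizes, and the pairing of one
# broadcast per updated tensor are assumptions for the sketch.
import torch
import torch.distributed as dist

def trainer_broadcast_weight(tensor: torch.Tensor):
    # The trainer takes rank 0; server TP ranks occupy 1..world_size-1
    # (that is what rank_offset expresses on the server side).
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",  # placeholder address/port
        rank=0,
        world_size=2,
    )
    # Each update_weights_from_distributed(name, dtype, shape) call on the
    # server would pair with one broadcast of the matching tensor here.
    dist.broadcast(tensor.cuda(), src=0)
    dist.destroy_process_group()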
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -23,7 +23,12 @@ from typing import Optional
 import psutil
 import torch
 
-from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
@@ -204,10 +209,23 @@ class TpModelWorkerClient:
         ) % self.future_token_ids_limit
         return None, future_next_token_ids
 
-    def update_weights(self, recv_req: UpdateWeightReqInput):
-        success, message = self.worker.update_weights(recv_req)
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message
 
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.worker.update_weights_from_distributed(recv_req)
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        return self.worker.get_weights_by_name(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)
 
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with torch.compile"""
     backup_ca_comm = None
 
    try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
 
@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (
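The _to_torch / patch_model machinery above works by swapping each CustomOp's _forward_method to a compile-friendly native implementation before torch.compile runs, then restoring the fast path afterwards (reverse=True). A toy sketch of that swap; the CustomOp below is a minimal stand-in, not sglang's class:

# Toy illustration of the forward-method swapping performed by _to_torch.
import contextlib
import torch

class CustomOp(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda  # default: fast path

    def forward_native(self, x):  # portable, traceable implementation
        return torch.nn.functional.relu(x)

    def forward_cuda(self, x):    # stands in for a custom CUDA kernel
        return torch.relu(x)

    def forward(self, x):
        return self._forward_method(x)

@contextlib.contextmanager
def compiled(op: CustomOp):
    # Swap to the native forward so torch.compile sees traceable ops,
    # then restore the original method on exit (mirrors reverse=True).
    op._forward_method = op.forward_native
    try:
        yield torch.compile(op.forward)
    finally:
        op._forward_method = op.forward_cuda

op = CustomOp()
with compiled(op) as fwd:
    print(fwd(torch.randn(4)))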
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -256,10 +256,15 @@ class ForwardBatch:
         ret.extend_prefix_lens = torch.tensor(
             batch.extend_prefix_lens, dtype=torch.int32
         ).to(device, non_blocking=True)
-        ret.extend_num_tokens = batch.extend_num_tokens
-        ret.positions, ret.extend_start_loc = compute_position_triton(
-            ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
-        )
+        if model_runner.server_args.attention_backend != "torch_native":
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
+        else:
+            ret.positions, ret.extend_start_loc = compute_position_torch(
+                ret.extend_prefix_lens, ret.extend_seq_lens
+            )
         ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
         ret.extend_seq_lens_cpu = batch.extend_seq_lens
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens