sglang-0.3.6.post3-py3-none-any.whl → sglang-0.4.0.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -45,13 +45,19 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-
-
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -103,9 +109,12 @@ class TokenizerManager:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )

         self.is_generation = self.model_config.is_generation
@@ -330,6 +339,12 @@ class TokenizerManager:
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+            if batch_size > 128:
+                logger.warning(
+                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
+                    "The performance might be better if you just duplicate the requests n times or use "
+                    "many threads to send them one by one with parallel sampling (n > 1)."
+                )

             # Tokenize all requests
             objs = [obj[i] for i in range(batch_size)]
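Note: the new warning above points at two client-side alternatives for large parallel-sampling batches. A rough sketch of both options, assuming the usual sglang `/generate` HTTP endpoint on the default port (adjust the URL and payload fields to your deployment):

```python
import requests
from concurrent.futures import ThreadPoolExecutor

URL = "http://localhost:30000/generate"  # assumed default server address

def send_generate(payload):
    # POST one generate request and return the parsed response.
    return requests.post(URL, json=payload).json()

prompts = ["Write a haiku about GPUs."] * 64
n = 4  # parallel samples per prompt

# Option 1: duplicate each request n times and send plain n=1 requests.
dup_payloads = [
    {"text": p, "sampling_params": {"n": 1, "max_new_tokens": 64}}
    for p in prompts
    for _ in range(n)
]

# Option 2: keep n > 1 per request, but send requests individually from many
# threads instead of as one oversized batch.
par_payloads = [{"text": p, "sampling_params": {"n": n, "max_new_tokens": 64}} for p in prompts]
with ThreadPoolExecutor(max_workers=16) as pool:
    results = list(pool.map(send_generate, par_payloads))
```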
@@ -405,8 +420,10 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)

-    async def
-        self,
+    async def update_weights_from_disk(
+        self,
+        obj: UpdateWeightFromDiskReqInput,
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -451,6 +468,63 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."

+    async def init_weights_update_group(
+        self,
+        obj: InitWeightsUpdateGroupReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> bool:
+        if self.to_create_loop:
+            self.create_handle_loop()
+        self.send_to_scheduler.send_pyobj(obj)
+
+        self.init_weights_update_group_result = asyncio.Future()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for init parameter update group"
+        result = await self.init_weights_update_group_result
+        return result.success, result.message
+
+    async def update_weights_from_distributed(
+        self,
+        obj: UpdateWeightsFromDistributedReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                self.send_to_scheduler.send_pyobj(obj)
+                self.parameter_update_result = asyncio.Future()
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be for update weights from distributed"
+                result = await self.parameter_update_result
+                return result.success, result.message
+        else:
+            logger.error("Another parameter update is in progress in tokenizer manager")
+            return (
+                False,
+                "Another parameter update is in progress. Please try again later.",
+            )
+
+    async def get_weights_by_name(
+        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        self.send_to_scheduler.send_pyobj(obj)
+        self.get_weights_by_name_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.get_weights_by_name_result
+            return result.parameter
+        else:
+            self.get_weights_by_name_tmp = []
+            result = await self.get_weights_by_name_result
+            all_parameters = [r.parameter for r in result]
+            return all_parameters
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
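Note: all of the new weight-update entry points above reuse the same request/response pattern: push a request object to the scheduler, park an `asyncio.Future`, and let the background receive loop resolve it when the matching `*ReqOutput` comes back. A stripped-down sketch of that pattern (simplified names, not the actual class):

```python
import asyncio

class RpcClient:
    def __init__(self, send_to_scheduler):
        self.send_to_scheduler = send_to_scheduler  # e.g. a ZMQ socket
        self.pending = None

    async def call(self, req):
        # 1. Ship the request object to the scheduler process.
        self.send_to_scheduler.send_pyobj(req)
        # 2. Park a future; the receive loop resolves it with the matching output.
        self.pending = asyncio.Future()
        result = await self.pending
        return result.success, result.message

    def on_recv(self, recv_obj):
        # Called from the loop that drains recv_from_detokenizer.
        self.pending.set_result(recv_obj)
```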
@@ -520,10 +594,77 @@ class TokenizerManager:

         while True:
             recv_obj: Union[
-                BatchStrOut,
+                BatchStrOut,
+                BatchEmbeddingOut,
+                BatchTokenIDOut,
+                UpdateWeightFromDiskReqOutput,
+                UpdateWeightsFromDistributedReqOutput,
+                GetWeightsByNameReqOutput,
+                InitWeightsUpdateGroupReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()

-            if isinstance(recv_obj,
+            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
+                for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
+                    recv_obj.meta_info[i]["id"] = rid
+                    if isinstance(recv_obj, BatchStrOut):
+                        out_dict = {
+                            "text": recv_obj.output_strs[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    elif isinstance(recv_obj, BatchTokenIDOut):
+                        out_dict = {
+                            "token_ids": recv_obj.output_ids[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    else:
+                        assert isinstance(recv_obj, BatchEmbeddingOut)
+                        out_dict = {
+                            "embedding": recv_obj.embeddings[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    state.out_list.append(out_dict)
+                    state.finished = recv_obj.finished_reason[i] is not None
+                    state.event.set()
+
+                    if self.enable_metrics:
+                        completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+
+                        if state.first_token_time is None:
+                            state.first_token_time = time.time()
+                            self.metrics_collector.observe_time_to_first_token(
+                                state.first_token_time - state.created_time
+                            )
+                        else:
+                            if completion_tokens >= 2:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.first_token_time)
+                                    / (completion_tokens - 1)
+                                )
+
+                        if state.finished:
+                            self.metrics_collector.inc_prompt_tokens(
+                                recv_obj.meta_info[i]["prompt_tokens"]
+                            )
+                            self.metrics_collector.inc_generation_tokens(
+                                completion_tokens
+                            )
+                            self.metrics_collector.observe_e2e_request_latency(
+                                time.time() - state.created_time
+                            )
+                            if completion_tokens >= 1:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.created_time)
+                                    / completion_tokens
+                                )
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
                     self.model_update_result.set_result(recv_obj)
                 else: # self.server_args.dp_size > 1
@@ -531,70 +672,27 @@
                     # set future if the all results are recevied
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
-
-
-
-
-                )
-
-
-
-
-
-
-
-
-                if state is None:
-                    continue
-
-                recv_obj.meta_info[i]["id"] = rid
-                if isinstance(recv_obj, BatchStrOut):
-                    out_dict = {
-                        "text": recv_obj.output_strs[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                elif isinstance(recv_obj, BatchTokenIDOut):
-                    out_dict = {
-                        "token_ids": recv_obj.output_ids[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
+            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for init parameter update group"
+                self.init_weights_update_group_result.set_result(recv_obj)
+            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.parameter_update_result.set_result(recv_obj)
+            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
+                if self.server_args.dp_size == 1:
+                    self.get_weights_by_name_result.set_result(recv_obj)
                 else:
-
-
-
-
-                    }
-                state.out_list.append(out_dict)
-                state.finished = recv_obj.finished_reason[i] is not None
-                state.event.set()
-
-                if self.enable_metrics:
-                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
-
-                    if state.first_token_time is None:
-                        state.first_token_time = time.time()
-                        self.metrics_collector.observe_time_to_first_token(
-                            state.first_token_time - state.created_time
+                    self.get_weights_by_name_tmp.append(recv_obj)
+                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
+                        self.get_weights_by_name_result.set_result(
+                            self.get_weights_by_name_tmp
                         )
-
-
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.first_token_time)
-                                / (completion_tokens - 1)
-                            )
-
-                        if state.finished:
-                            self.metrics_collector.inc_prompt_tokens(
-                                recv_obj.meta_info[i]["prompt_tokens"]
-                            )
-                            self.metrics_collector.inc_generation_tokens(completion_tokens)
-                            self.metrics_collector.observe_e2e_request_latency(
-                                time.time() - state.created_time
-                            )
-                            if completion_tokens >= 1:
-                                self.metrics_collector.observe_time_per_output_token(
-                                    (time.time() - state.created_time) / completion_tokens
-                                )
+            else:
+                raise ValueError(f"Invalid object: {recv_obj=}")

     def convert_logprob_style(
         self,
sglang/srt/managers/tp_worker.py
CHANGED
@@ -19,7 +19,12 @@ from typing import Optional

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -47,9 +52,12 @@ class TpModelWorker:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -155,8 +163,33 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings

-    def
-        success, message = self.model_runner.
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.model_runner.update_weights_from_disk(
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.init_weights_update_group(
+            recv_req.master_address,
+            recv_req.master_port,
+            recv_req.rank_offset,
+            recv_req.world_size,
+            recv_req.group_name,
+            recv_req.backend,
+        )
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.model_runner.update_weights_from_distributed(
+            recv_req.name, recv_req.dtype, recv_req.shape
+        )
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -23,16 +23,22 @@ from typing import Optional
 import psutil
 import torch

-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,
@@ -68,12 +74,13 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()

     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -92,7 +99,7 @@ class TpModelWorkerClient:

     def forward_thread_func(self):
         try:
-            with torch.
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()
@@ -117,7 +124,7 @@ class TpModelWorkerClient:

         # Create event
         self.launch_done = threading.Event()
-        copy_done = torch.
+        copy_done = torch.get_device_module(self.device).Event()

         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids
@@ -185,7 +192,7 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-
+        self.scheduler_stream.synchronize()

         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
@@ -204,10 +211,23 @@ class TpModelWorkerClient:
         ) % self.future_token_ids_limit
         return None, future_next_token_ids

-    def
-        success, message = self.worker.
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message

+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.worker.update_weights_from_distributed(recv_req)
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        return self.worker.get_weights_by_name(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
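Note: two substitutions recur in this file: `torch.cuda.*` calls become `torch.get_device_module(self.device).*` so the same code path works on non-CUDA devices, and `torch.compile` now takes its backend from the new `get_compiler_backend()` helper in `sglang.srt.utils`. The helper's body is not shown in this diff; the sketch below is an assumption about its shape, not the actual implementation:

```python
import torch

def get_compiler_backend() -> str:
    # Assumption only: choose a torch.compile backend per device type; the real
    # helper in sglang.srt.utils may pick different names (e.g. a Gaudi backend on HPU).
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return "hpu_backend"
    return "inductor"  # torch.compile's default backend on CUDA/CPU

# torch.get_device_module maps a device (or device string) to the matching torch
# submodule, e.g. torch.get_device_module("cuda") is torch.cuda, so
# torch.get_device_module("cuda").Stream() behaves like torch.cuda.Stream().
if torch.cuda.is_available():
    stream = torch.get_device_module("cuda").Stream()
```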
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch

 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend

 logger = logging.getLogger(__name__)

@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
         return select_index.to(self.device, non_blocking=True)

     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):

 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner


-def _to_torch(model: torch.nn.Module, reverse: bool
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to only use torch.compile when bs =1
+                        sub._forward_method = fused_moe_forward_native
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)


 @contextmanager
 def patch_model(
-    model: torch.nn.Module,
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None

     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward),
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm

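Note: `patch_model` is a context manager: on entry it optionally rewrites `CustomOp` forwards to native torch implementations (now decided per batch size) and yields either a compiled forward or the plain one; on exit it restores the original state. A condensed sketch of how the capture loop uses it, mirroring the `@@ -237` hunk further down (surrounding setup omitted; `capture_one_batch_size` is a stand-in name for the per-batch-size graph-capture step):

```python
for bs in capture_bs:
    with patch_model(
        model_runner.model,
        bs in compile_bs,   # enable torch.compile only for selected batch sizes
        bs,                 # new argument: lets _to_torch choose the MoE fallback per batch size
        model_runner.tp_group,
    ) as forward:
        capture_one_batch_size(bs, forward)
```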
@@ -122,6 +130,20 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
+        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
+            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+            # is very samll. We add more values here to make sure we capture the maximum bs.
+            self.capture_bs = list(
+                sorted(
+                    set(
+                        self.capture_bs
+                        + [model_runner.req_to_token_pool.size - 1]
+                        + [model_runner.req_to_token_pool.size]
+                    )
+                )
+            )
+
         self.capture_bs = [
             bs
             for bs in self.capture_bs
@@ -237,6 +259,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -256,10 +256,15 @@ class ForwardBatch:
         ret.extend_prefix_lens = torch.tensor(
             batch.extend_prefix_lens, dtype=torch.int32
         ).to(device, non_blocking=True)
-
-
-        ret.
-
+        if model_runner.server_args.attention_backend != "torch_native":
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
+        else:
+            ret.positions, ret.extend_start_loc = compute_position_torch(
+                ret.extend_prefix_lens, ret.extend_seq_lens
+            )
         ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
         ret.extend_seq_lens_cpu = batch.extend_seq_lens
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens