sglang 0.4.1.post7__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl
This diff compares the contents of two package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in those registries.
- sglang/bench_offline_throughput.py +17 -11
- sglang/bench_one_batch.py +14 -6
- sglang/bench_serving.py +47 -44
- sglang/lang/chat_template.py +31 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
- sglang/srt/entrypoints/engine.py +5 -2
- sglang/srt/entrypoints/http_server.py +24 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
- sglang/srt/layers/attention/vision.py +243 -40
- sglang/srt/layers/dp_attention.py +3 -1
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +24 -9
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -12
- sglang/srt/layers/moe/fused_moe_native.py +17 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
- sglang/srt/layers/parameter.py +16 -7
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +11 -1
- sglang/srt/layers/rotary_embedding.py +34 -13
- sglang/srt/layers/sampler.py +33 -10
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/image_processor.py +77 -38
- sglang/srt/managers/io_struct.py +36 -5
- sglang/srt/managers/schedule_batch.py +31 -25
- sglang/srt/managers/scheduler.py +78 -38
- sglang/srt/managers/tokenizer_manager.py +4 -0
- sglang/srt/mem_cache/base_prefix_cache.py +4 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +30 -1
- sglang/srt/model_executor/cuda_graph_runner.py +23 -25
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +7 -4
- sglang/srt/model_loader/loader.py +75 -0
- sglang/srt/model_loader/weight_utils.py +91 -5
- sglang/srt/models/commandr.py +14 -2
- sglang/srt/models/dbrx.py +9 -1
- sglang/srt/models/deepseek_v2.py +3 -3
- sglang/srt/models/gemma2.py +9 -1
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/minicpm3.py +3 -3
- sglang/srt/models/minicpmv.py +129 -76
- sglang/srt/models/mllama.py +16 -56
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_vl.py +18 -8
- sglang/srt/models/torch_native_llama.py +17 -4
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +5 -4
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +4 -14
- sglang/srt/server.py +2 -2
- sglang/srt/server_args.py +26 -1
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +62 -67
- sglang/test/test_programs.py +1 -0
- sglang/test/test_utils.py +81 -22
- sglang/utils.py +42 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +8 -8
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +78 -67
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -149,6 +149,7 @@ class Scheduler:
             if not self.spec_algorithm.is_none()
             else 1
         )
+        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache

         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -281,6 +282,7 @@ class Scheduler:
         # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
@@ -408,6 +410,11 @@ class Scheduler:
             },
         )

+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -480,7 +487,7 @@ class Scheduler:
     @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        result_queue = deque()
+        self.result_queue = deque()

         while True:
             recv_reqs = self.recv_requests()
@@ -491,7 +498,7 @@ class Scheduler:

             if batch:
                 result = self.run_batch(batch)
-                result_queue.append((batch.copy(), result))
+                self.result_queue.append((batch.copy(), result))

                 if self.last_batch is None:
                     # Create a dummy first batch to start the pipeline for overlap schedule.
@@ -505,7 +512,7 @@ class Scheduler:

             if self.last_batch:
                 # Process the results of the last batch
-                tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch, tmp_result = self.result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
@@ -636,7 +643,7 @@ class Scheduler:
            self.waiting_queue.append(req)
            return

-        # Handle
+        # Handle multimodal inputs
         if recv_req.image_inputs is not None:
             image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
@@ -660,24 +667,23 @@ class Scheduler:
             self.waiting_queue.append(req)
             return

-        # Copy more attributes
-        req.logprob_start_len = recv_req.logprob_start_len
-
-        if req.logprob_start_len == -1:
-            # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
-
         # Validate prompts length
         error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
-
         if error_msg:
             self.waiting_queue.append(req)
             return

+        # Copy more attributes
+        if recv_req.logprob_start_len == -1:
+            # By default, only return the logprobs for output tokens
+            req.logprob_start_len = len(req.origin_input_ids) - 1
+        else:
+            req.logprob_start_len = recv_req.logprob_start_len
+
         req.sampling_params.max_new_tokens = min(
             (
                 req.sampling_params.max_new_tokens
@@ -725,15 +731,26 @@ class Scheduler:
         req.tokenizer = self.tokenizer

         # Validate prompts length
-        validate_input_length(
+        error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
+        if error_msg:
+            self.waiting_queue.append(req)
+            return

+        # Copy more attributes
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         self.waiting_queue.append(req)

-    def log_prefill_stats(
+    def log_prefill_stats(
+        self,
+        adder: PrefillAdder,
+        can_run_list: List[Req],
+        running_bs: ScheduleBatch,
+        has_being_chunked: bool,
+    ):
         self.tree_cache_metrics["total"] += (
             adder.log_input_tokens + adder.log_hit_tokens
         ) / 10**9
@@ -815,10 +832,16 @@ class Scheduler:
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-
+        protected_size = self.tree_cache.protected_size()
+        memory_leak = available_size != (
+            self.max_total_num_tokens
+            if not self.enable_hierarchical_cache
+            else self.max_total_num_tokens - protected_size
+        )
+        if memory_leak:
             msg = (
                 "KV cache pool leak detected!"
-                f"{available_size=}, {self.max_total_num_tokens=}\n"
+                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
@@ -933,7 +956,14 @@ class Scheduler:
             res = adder.add_one_req(req)
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
-                    self.
+                    if self.enable_hierarchical_cache:
+                        # Set batch_is_full after making sure there are requests that can be served
+                        self.batch_is_full = len(adder.can_run_list) > 0 or (
+                            self.running_batch is not None
+                            and not self.running_batch.is_empty()
+                        )
+                    else:
+                        self.batch_is_full = True
                 break
             if self.server_args.prefill_only_one_req:
                 break
@@ -1023,7 +1053,7 @@ class Scheduler:
         )

         # Check for jump-forward
-        if not self.disable_jump_forward:
+        if not self.disable_jump_forward and batch.has_grammar:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -1044,26 +1074,23 @@ class Scheduler:
         self.forward_ct += 1

         if self.is_generation:
-            if
-
-
-
-
-                )
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    model_worker_batch,
-                    num_accepted_tokens,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                self.spec_num_total_accepted_tokens += (
-                    num_accepted_tokens + batch.batch_size()
-                )
-                self.spec_num_total_forward_ct += batch.batch_size()
-                self.num_generated_tokens += num_accepted_tokens
+            if self.spec_algorithm.is_none():
+                model_worker_batch = batch.get_model_worker_batch()
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
             else:
-
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids

             ret = GenerationBatchResult(
@@ -1072,7 +1099,6 @@ class Scheduler:
                 bid=model_worker_batch.bid,
             )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(
@@ -1371,6 +1397,7 @@ class Scheduler:
             prompt_tokens = []
             completion_tokens = []
             cached_tokens = []
+            spec_verify_ct = []

             if return_logprob:
                 input_token_logprobs_val = []
@@ -1424,6 +1451,9 @@ class Scheduler:
                 completion_tokens.append(len(req.output_ids))
                 cached_tokens.append(req.cached_tokens)

+                if not self.spec_algorithm.is_none():
+                    spec_verify_ct.append(req.spec_verify_ct)
+
                 if return_logprob:
                     input_token_logprobs_val.append(req.input_token_logprobs_val)
                     input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1451,6 +1481,7 @@ class Scheduler:
                     prompt_tokens,
                     completion_tokens,
                     cached_tokens,
+                    spec_verify_ct,
                     input_token_logprobs_val,
                     input_token_logprobs_idx,
                     output_token_logprobs_val,
@@ -1564,6 +1595,15 @@ class Scheduler:
                 self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
+
+            if not self.spec_algorithm.is_none():
+                self.draft_worker.model_runner.req_to_token_pool.clear()
+                self.draft_worker.model_runner.token_to_kv_pool.clear()
+
+            self.num_generated_tokens = 0
+            self.forward_ct_decode = 0
+            self.spec_num_total_accepted_tokens = 0
+            self.spec_num_total_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True

sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -785,6 +785,9 @@ class TokenizerManager:
                 i,
             )

+            if self.server_args.speculative_algorithm:
+                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
             if not isinstance(recv_obj, BatchEmbeddingOut):
                 meta_info.update(
                     {
@@ -809,6 +812,7 @@ class TokenizerManager:
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
                 }
+
             state.out_list.append(out_dict)
             state.finished = recv_obj.finished_reasons[i] is not None
             state.event.set()

sglang/srt/mem_cache/radix_cache.py
CHANGED
@@ -34,7 +34,10 @@ if TYPE_CHECKING:


 class TreeNode:
-
+
+    counter = 0
+
+    def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
         self.parent = None
         self.key = None
@@ -42,6 +45,23 @@ class TreeNode:
         self.lock_ref = 0
         self.last_access_time = time.time()

+        self.hit_count = 0
+        # indicating the node is loading KV cache from host
+        self.loading = False
+        # store the host indices of KV cache
+        self.host_value = None
+
+        self.id = TreeNode.counter if id is None else id
+        TreeNode.counter += 1
+
+    @property
+    def evicted(self):
+        return self.value is None
+
+    @property
+    def backuped(self):
+        return self.host_value is not None
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time

@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.value = []
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
+        self.protected_size_ = 0

     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
         """Find the matching prefix from the radix tree.
@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
         while node != self.root_node:
             if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
+                self.protected_size_ += len(node.value)
                 delta -= len(node.value)
             node.lock_ref += 1
             node = node.parent
@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
         while node != self.root_node:
             if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
+                self.protected_size_ -= len(node.value)
                 delta += len(node.value)
             node.lock_ref -= 1
             node = node.parent
@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
     def evictable_size(self):
         return self.evictable_size_

+    def protected_size(self):
+        # protected size refers to the size of the cache that is locked
+        return self.protected_size_
+
     ##### Internal Helper Functions #####

     def _match_prefix_helper(
@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
             self.evictable_size_ -= len(node.key)

     def _total_size_helper(self, node: TreeNode):
+        if node.evicted:
+            return 0
         x = len(node.value)
         for child in node.children.values():
             x += self._total_size_helper(child)

sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -24,7 +24,7 @@ import tqdm
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner


-def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    if
+                    if num_tokens == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
                         # so we decide to only use torch.compile when bs =1
                         sub._forward_method = fused_moe_forward_native
@@ -55,22 +55,22 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                     sub._forward_method = sub.forward_native
             setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse,
+            _to_torch(sub, reverse, num_tokens)


 @contextmanager
 def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
-
-    tp_group:
+    num_tokens: int,
+    tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None

     try:
         if enable_compile:
-            _to_torch(model, reverse=False,
+            _to_torch(model, reverse=False, num_tokens=num_tokens)
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -85,7 +85,7 @@ def patch_model(
         yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True,
+            _to_torch(model, reverse=True, num_tokens=num_tokens)
             tp_group.ca_comm = backup_ca_comm


@@ -149,9 +149,18 @@ class CudaGraphRunner:
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]

+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
-
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
                 self.num_tokens_per_bs = (
@@ -163,16 +172,6 @@ class CudaGraphRunner:
                     self.model_runner.server_args.speculative_num_draft_tokens
                 )

-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -180,7 +179,6 @@ class CudaGraphRunner:
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0

@@ -189,14 +187,14 @@ class CudaGraphRunner:

         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

             # Speculative_inference
             if model_runner.spec_algorithm.is_eagle():
@@ -285,8 +283,8 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
-                bs,
-                self.model_runner.tp_group,
+                num_tokens=bs * self.num_tokens_per_bs,
+                tp_group=self.model_runner.tp_group,
             ) as forward:
                 (
                     graph,

sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -282,6 +282,9 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
@@ -336,11 +339,6 @@ class ForwardBatch:
         if model_runner.model_is_mrope:
             ret.compute_mrope_positions(model_runner, batch)

-        # Init attention information
-        ret.req_to_token_pool = model_runner.req_to_token_pool
-        ret.token_to_kv_pool = model_runner.token_to_kv_pool
-        ret.attn_backend = model_runner.attn_backend
-
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)
@@ -417,6 +415,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc


-@
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

sglang/srt/model_executor/model_runner.py
CHANGED
@@ -185,9 +185,12 @@ class ModelRunner:
         self.load_model()

         # Apply torchao quantization
-
-
-
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )

         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
@@ -215,7 +218,7 @@ class ModelRunner:

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-
+
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"

sglang/srt/model_loader/loader.py
CHANGED
@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
         return model.eval()


+class LayeredModelLoader(DefaultModelLoader):
+    """Model loader that loads weights layer by layer so that one can quantize a
+    layer before loading another to make the peak memory envelope smaller."""
+
+    def __init__(self, load_config: LoadConfig):
+        # Back to the default load format
+        load_config.load_format = LoadFormat.AUTO
+        super().__init__(load_config)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        torchao_config = global_server_args_dict.get("torchao_config")
+        target_device = torch.device(device_config.device)
+
+        with set_default_torch_dtype(model_config.dtype):
+            # Create model on meta device
+            with torch.device("meta"):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                )
+
+            # Check model's layered load support
+            if not hasattr(model, "load_weights_to_module"):
+                raise ValueError(
+                    "LayeredModelLoader requires the model to have a "
+                    "`load_weights_to_module` method. "
+                    f"{model_config.model_path} does not support it."
+                )
+
+            # Get all weights from disk
+            weights = self._get_all_weights(model_config, model)
+
+            # Helper function to recursively fill the weights of a module
+            def fill_module(module, fqn: List[str], weights):
+                """
+                fqn: list of strings representing the fully qualified name of `module`.
+                """
+                # Layer by layer
+                for name, submod in module.named_children():
+                    fill_module(submod, fqn + [name], weights)
+
+                # First materialize on target device
+                module.to_empty(device=target_device, recurse=False)
+                fqn_path = ".".join(fqn)
+                # Fill weights
+                model.load_weights_to_module(
+                    fqn_path,
+                    weights,
+                )
+                # Quantize weights if applicable
+                if torchao_config and "proj" in fqn_path:
+                    # Note: `None` here is needed to indicate no filter, see
+                    # `apply_torchao_config_to_model` for details.
+                    apply_torchao_config_to_model(module, torchao_config, None)
+
+            # Start calling on root module
+            fill_module(model, [], weights)
+
+        if torchao_config:
+            model.torchao_applied = True
+
+        return model.eval()
+
+
 class DummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values."""

@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)

+    if load_config.load_format == LoadFormat.LAYERED:
+        return LayeredModelLoader(load_config)
+
     return DefaultModelLoader(load_config)