sglang 0.4.0.post1__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +2 -0
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/ep_moe/layer.py +4 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +58 -10
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +35 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +38 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +169 -134
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +24 -10
- sglang/srt/model_executor/model_runner.py +22 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -8
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +1 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +7 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +11 -6
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +54 -52
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py CHANGED

@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Dict, List, Optional,
+from typing import Any, Dict, List, Optional, Union
 
 import fastapi
 import uvloop
@@ -76,6 +76,7 @@ class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
+    obj: Any
 
     # For metrics
     created_time: float
@@ -283,7 +284,7 @@ class TokenizerManager:
     ):
         """Wait for the response of one request."""
         event = asyncio.Event()
-        state = ReqState([], False, event, created_time=created_time)
+        state = ReqState([], False, event, obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
 
         while True:
@@ -295,15 +296,7 @@ class TokenizerManager:
                     raise ValueError(f"Abort request {obj.rid}")
                 continue
 
-
-                out = self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
-                    obj.return_text_in_logprobs,
-                )
-            else:  # isinstance(obj, (EmbeddingReqInput,))
-                out = state.out_list[-1]
+            out = state.out_list[-1]
 
             state.out_list = []
             if state.finished:
@@ -315,7 +308,13 @@ class TokenizerManager:
                 break
 
             state.event.clear()
-
+
+            if obj.stream:
+                yield out
+            else:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
 
     async def _handle_batch_request(
         self,
@@ -573,7 +572,7 @@ class TokenizerManager:
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
-            await asyncio.sleep(
+            await asyncio.sleep(5)
 
         # drain requests
         while True:
@@ -609,29 +608,55 @@ class TokenizerManager:
             if state is None:
                 continue
 
-
+            meta_info = {
+                "id": rid,
+                "finish_reason": recv_obj.finished_reasons[i],
+                "prompt_tokens": recv_obj.prompt_tokens[i],
+            }
+
+            if getattr(state.obj, "return_logprob", False):
+                self.convert_logprob_style(
+                    meta_info,
+                    state.obj.top_logprobs_num,
+                    state.obj.return_text_in_logprobs,
+                    recv_obj,
+                    i,
+                )
+
+            if not isinstance(recv_obj, BatchEmbeddingOut):
+                meta_info.update(
+                    {
+                        "completion_tokens": recv_obj.completion_tokens[i],
+                        "cached_tokens": recv_obj.cached_tokens[i],
+                    }
+                )
+
             if isinstance(recv_obj, BatchStrOut):
                 out_dict = {
                     "text": recv_obj.output_strs[i],
-                    "meta_info":
+                    "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
                 out_dict = {
                     "token_ids": recv_obj.output_ids[i],
-                    "meta_info":
+                    "meta_info": meta_info,
                 }
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
-                    "meta_info":
+                    "meta_info": meta_info,
                 }
             state.out_list.append(out_dict)
-            state.finished = recv_obj.
+            state.finished = recv_obj.finished_reasons[i] is not None
             state.event.set()
 
             if self.enable_metrics:
-                completion_tokens =
+                completion_tokens = (
+                    recv_obj.completion_tokens[i]
+                    if recv_obj.completion_tokens
+                    else 0
+                )
 
                 if state.first_token_time is None:
                     state.first_token_time = time.time()
@@ -647,7 +672,7 @@ class TokenizerManager:
 
                 if state.finished:
                     self.metrics_collector.inc_prompt_tokens(
-                        recv_obj.
+                        recv_obj.prompt_tokens[i]
                     )
                     self.metrics_collector.inc_generation_tokens(
                         completion_tokens
@@ -696,57 +721,73 @@ class TokenizerManager:
 
     def convert_logprob_style(
         self,
-
-        return_logprob: bool,
+        meta_info: dict,
         top_logprobs_num: int,
         return_text_in_logprobs: bool,
+        recv_obj: BatchStrOut,
+        recv_obj_index: int,
     ):
-
-
-
+        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.input_token_logprobs_val[recv_obj_index],
+            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.output_token_logprobs_val[recv_obj_index],
+            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
+            recv_obj_index
+        ]
+
+        if top_logprobs_num > 0:
+            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.input_top_logprobs_val[recv_obj_index],
+                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )
-
-
+            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.output_top_logprobs_val[recv_obj_index],
+                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )
 
-        if top_logprobs_num > 0:
-            ret["meta_info"]["input_top_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["input_top_logprobs"],
-                    return_text_in_logprobs,
-                )
-            )
-            ret["meta_info"]["output_top_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
-                )
-            )
-        return ret
-
     def detokenize_logprob_tokens(
-        self,
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
     ):
-        # TODO(lianmin): This should run on DetokenizerManager
         if not decode_to_text:
-            return [
-
-
-
-
-
-
-            ]
+            return [
+                (logprob, token_id, None)
+                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
+            ]
+        else:
+            assert self.tokenizer is not None
+            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
+            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
 
-    def detokenize_top_logprobs_tokens(
+    def detokenize_top_logprobs_tokens(
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
+    ):
         # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
         # We should batch all top-k tokens in all positions.
-
-
-
-
+        ret = []
+        for i in range(len(token_logprobs_val)):
+            if token_logprobs_val[i]:
+                ret.append(
+                    self.detokenize_logprob_tokens(
+                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
+                    )
                 )
-
+            else:
+                ret.append(None)
+        return ret
 
 
 class SignalHandler:
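Note on the logprob refactor above: per-request logprobs now travel as parallel `*_val` / `*_idx` lists and are zipped into `(logprob, token_id, text-or-None)` tuples inside the new `meta_info` dict. The snippet below is an illustrative sketch only, not the package's code; `decoded_texts` stands in for the output of `tokenizer.batch_decode`.

```python
# Illustrative sketch of the new logprob tuple format; not sglang's actual helper.
from typing import List, Optional, Tuple


def zip_logprobs(
    logprobs_val: List[float],
    logprobs_idx: List[int],
    decoded_texts: Optional[List[str]] = None,  # hypothetical tokenizer.batch_decode output
) -> List[Tuple[float, int, Optional[str]]]:
    """Pair parallel value/index lists into (logprob, token_id, text) tuples."""
    if decoded_texts is None:
        return [(lp, tid, None) for lp, tid in zip(logprobs_val, logprobs_idx)]
    return list(zip(logprobs_val, logprobs_idx, decoded_texts))


print(zip_logprobs([-0.11, -2.30], [42, 7]))
# [(-0.11, 42, None), (-2.3, 7, None)]
print(zip_logprobs([-0.11, -2.30], [42, 7], [" the", " cat"]))
# [(-0.11, 42, ' the'), (-2.3, 7, ' cat')]
```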
sglang/srt/mem_cache/base_prefix_cache.py CHANGED

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable
+from typing import Callable, List, Tuple
 
 
 class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, **kwargs):
+    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
         pass
 
     @abstractmethod
sglang/srt/mem_cache/chunk_cache.py CHANGED

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}
 
-    def match_prefix(self, rid: int, key: List[int]):
+    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
         if rid not in self.entries:
             return [], None
 
sglang/srt/mem_cache/radix_cache.py CHANGED

@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 import torch
 
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
-    def match_prefix(self, key: List, **kwargs):
+    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+        """Find the matching prefix from the radix tree.
+        Args:
+            key: A list of token IDs to find a matching prefix.
+        Returns:
+            A tuple of a tensor of matching prefix token IDs and
+            the last node that contains the prefix values. Note that
+            this API can modify the internal state of the Radix tree.
+            The last node create a new child if the prefix is shorter
+            than the last node's value.
+        """
         if self.disable:
             return [], self.root_node
 
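The three mem_cache hunks above mainly tighten type hints: `match_prefix` is now annotated to return a pair (the matched prefix plus a second handle). The toy stand-in below illustrates that shape of contract only; the real caches return a cache node or entry as the second element rather than a plain length.

```python
# Toy illustration of the annotated match_prefix contract; not sglang's cache.
from typing import Dict, List, Tuple


class ToyPrefixCache:
    def __init__(self) -> None:
        self.entries: Dict[int, List[int]] = {}

    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
        """Return the longest cached prefix of `key` and its length."""
        cached = self.entries.get(rid, [])
        n = 0
        while n < min(len(cached), len(key)) and cached[n] == key[n]:
            n += 1
        return key[:n], n


cache = ToyPrefixCache()
cache.entries[1] = [101, 7, 9, 4]
print(cache.match_prefix(1, [101, 7, 9, 250]))  # ([101, 7, 9], 3)
```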
sglang/srt/model_executor/cuda_graph_runner.py CHANGED

@@ -20,6 +20,8 @@ from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
 import torch
+import tqdm
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
@@ -127,7 +129,7 @@ class CudaGraphRunner:
 
         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1,
+            self.capture_bs = list(range(1, 33)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
 
@@ -255,7 +257,12 @@
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-
+            capture_bs = (
+                tqdm.tqdm(self.capture_bs)
+                if get_tensor_model_parallel_rank() == 0
+                else self.capture_bs
+            )
+            for bs in capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -387,8 +394,14 @@
 
         # Extract logprobs
         if forward_batch.return_logprob:
-
-
+            logits_metadata = LogitsMetadata(
+                forward_mode=ForwardMode.DECODE,
+                top_logprobs_nums=forward_batch.top_logprobs_nums,
+            )
+            next_token_logprobs = (
+                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
+                    next_token_logits, logits_metadata
+                )
             )
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
@@ -396,13 +409,14 @@
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
-
-
-
-                )
-                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                (
+                    logits_output.output_top_logprobs_val,
+                    logits_output.output_top_logprobs_idx,
+                ) = LogitsProcessor.get_top_logprobs(
                    next_token_logprobs, logits_metadata
-                )[
+                )[
+                    2:4
+                ]
         else:
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
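The capture loop above now wraps the batch-size list in a `tqdm` progress bar on tensor-parallel rank 0 only, so multi-rank runs do not print interleaved bars. A minimal sketch of the pattern follows; `rank` is a placeholder for `get_tensor_model_parallel_rank()`, which requires an initialized process group.

```python
# Minimal sketch of the rank-gated progress bar; `rank` is a placeholder value.
import tqdm

capture_bs = list(range(1, 33)) + [64, 128]
rank = 0  # placeholder for get_tensor_model_parallel_rank()

iterable = tqdm.tqdm(capture_bs) if rank == 0 else capture_bs
for bs in iterable:
    # ... capture a CUDA graph for batch size `bs` ...
    pass
```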
sglang/srt/model_executor/model_runner.py CHANGED

@@ -111,17 +111,20 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
-            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
-
-
-
-
-
+            if self.model_config.hf_config.architectures == [
+                "MllamaForConditionalGeneration"
+            ]:
+                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+                server_args.chunked_prefill_size = -1
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                )
+                server_args.chunked_prefill_size = -1
                 server_args.disable_radix_cache = True
 
         # Global vars
@@ -154,6 +157,11 @@
         self.sampler = Sampler()
         self.load_model()
 
+        # Apply torchao quantization
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
@@ -162,10 +170,6 @@
         else:
             self.torch_tp_applied = False
 
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
-
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
@@ -242,20 +246,22 @@
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
-        # Prepare the
+        # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
+
+        # Load the model
         self.model = get_model(
             model_config=self.model_config,
             load_config=self.load_config,
             device_config=DeviceConfig(self.device),
         )
 
+        # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
@@ -270,8 +276,10 @@
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-    def update_weights_from_disk(
-
+    def update_weights_from_disk(
+        self, model_path: str, load_format: str
+    ) -> tuple[bool, str]:
+        """Update engine weights in-place from the disk."""
         from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
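For context, the first hunk above replaces the blanket `chunked_prefill_size = -1` for all multimodal models with per-architecture handling. The sketch below is a simplified, hypothetical rendering of that gating; the dict it returns stands in for the real `ServerArgs` object.

```python
# Simplified sketch of the per-architecture overrides; not the actual ModelRunner code.
import logging

logger = logging.getLogger(__name__)


def multimodal_overrides(architectures: list) -> dict:
    overrides = {}
    if architectures == ["MllamaForConditionalGeneration"]:
        logger.info("Automatically turn off --chunked-prefill-size for mllama.")
        overrides["chunked_prefill_size"] = -1
    if architectures == ["Qwen2VLForConditionalGeneration"]:
        logger.info("Turn off chunked prefill and disable radix cache for qwen2-vl.")
        overrides["chunked_prefill_size"] = -1
        overrides["disable_radix_cache"] = True
    return overrides


print(multimodal_overrides(["Qwen2VLForConditionalGeneration"]))
# {'chunked_prefill_size': -1, 'disable_radix_cache': True}
```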
sglang/srt/model_parallel.py CHANGED

@@ -2,18 +2,18 @@
 Common utilities for torch model parallelism.
 """
 
-from typing import Optional
+from typing import Optional, Sequence
 
 import torch
+import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 
 try:
-
+    import torch.distributed.tensor as dt
 except ImportError:
     # torch 2.4 or older
-
+    import torch.distributed._tensor as dt
 
-from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
@@ -21,6 +21,50 @@ from torch.distributed.tensor.parallel import (
 )
 
 
+def _shard_tensor(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[dt.Shard],
+) -> "dt.DTensor":
+    """
+    Locally shards a full tensor based on indicated sharding arrangement, and
+    returns a DTensor containing the local shard.
+
+    .. warning:: This is a private API that is subject to change. It skips the
+        communication otherwise required by `distribute_tensor`. It is only
+        applicable to cases where all ranks have the same `full_tensor`. For
+        example, in distributed inference all ranks load from the same
+        checkpoint. This API will not check for data equality between ranks, it
+        is thus user's responsibility to ensure the `full_tensor` is the same
+        across ranks.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
+            DTensor. Must have same dimension as the number of placements.
+        placements (Sequence[:class:`Shard`]): the placements that
+            describes how to place the local tensor on DeviceMesh.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
+    """
+    shape, offset = dt._utils.compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return dt.DTensor.from_local(local_tensor, device_mesh, placements)
+
+
 class ColwiseParallelSharded(ColwiseParallel):
     """
     A version of ColwiseParallel where the local weight has been already
@@ -34,7 +78,7 @@ class ColwiseParallelSharded(ColwiseParallel):
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
         for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
             dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
             module.register_parameter(name, dist_param)
 
@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
     AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
     """
 
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        module.register_parameter(
+            "weight",
+            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+        )
+        if getattr(module, "bias", None) is not None:
+            # The Linear module has bias
+            module.register_parameter(
+                "bias",
+                nn.Parameter(
+                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                ),
+            )
+
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         outputs = super(
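The `_shard_tensor` helper above avoids the scatter performed by `distribute_tensor` by having each rank slice its own shard out of an identical full tensor. Below is a CPU-only sketch of just the slicing step; the shape/offset pair is hard-coded in place of `compute_local_shape_and_global_offset`, and no `DTensor` is built.

```python
# CPU-only sketch of the local-slice idea behind _shard_tensor; values are hard-coded.
import torch

full_tensor = torch.arange(16).reshape(4, 4)

# Stand-ins for compute_local_shape_and_global_offset on rank 1 of a 2-rank
# mesh with Shard(1): this rank owns columns 2..3.
local_shape, global_offset = (4, 2), (0, 2)

slices = tuple(
    slice(off, off + size) for size, off in zip(local_shape, global_offset)
)
local_tensor = full_tensor[slices]
print(local_tensor)
# tensor([[ 2,  3],
#         [ 6,  7],
#         [10, 11],
#         [14, 15]])
```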
sglang/srt/models/gemma2.py CHANGED

@@ -355,6 +355,40 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
+    def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
+    def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
 
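The `get_module_name` mapping added above translates per-tensor checkpoint names (`q_proj`, `k_proj`, ...) to the fused runtime modules (`qkv_proj`, `gate_up_proj`), presumably so that code targeting the fused layers can resolve the original names. A tiny standalone illustration of the lookup:

```python
# Standalone illustration of the stacked-parameter name mapping; mirrors the
# dict in the hunk above, with a .get fallback for names that are not fused.
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}

for name in ("q_proj", "up_proj", "o_proj", "down_proj"):
    print(f"{name} -> {params_mapping.get(name, name)}")
# q_proj -> qkv_proj
# up_proj -> gate_up_proj
# o_proj -> o_proj
# down_proj -> down_proj
```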
sglang/srt/models/gemma2_reward.py CHANGED

@@ -32,7 +32,6 @@ class Gemma2ForSequenceClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.num_labels = config.num_labels
         self.model = Gemma2Model(config, quant_config=quant_config)