sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Dict, List, Optional,
+from typing import Any, Dict, List, Optional, Union

 import fastapi
 import uvloop
@@ -76,6 +76,7 @@ class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
+    obj: Any

     # For metrics
     created_time: float
@@ -283,7 +284,7 @@ class TokenizerManager:
     ):
         """Wait for the response of one request."""
         event = asyncio.Event()
-        state = ReqState([], False, event, created_time=created_time)
+        state = ReqState([], False, event, obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state

         while True:
@@ -295,15 +296,7 @@ class TokenizerManager:
                     raise ValueError(f"Abort request {obj.rid}")
                 continue

-
-                out = self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
-                    obj.return_text_in_logprobs,
-                )
-            else:  # isinstance(obj, (EmbeddingReqInput,))
-                out = state.out_list[-1]
+            out = state.out_list[-1]

             state.out_list = []
             if state.finished:
@@ -315,7 +308,13 @@ class TokenizerManager:
                     break

             state.event.clear()
-
+
+            if obj.stream:
+                yield out
+            else:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")

     async def _handle_batch_request(
         self,
@@ -573,7 +572,7 @@ class TokenizerManager:

     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
-            await asyncio.sleep(
+            await asyncio.sleep(5)

         # drain requests
         while True:
@@ -609,29 +608,55 @@ class TokenizerManager:
                 if state is None:
                     continue

-
+                meta_info = {
+                    "id": rid,
+                    "finish_reason": recv_obj.finished_reasons[i],
+                    "prompt_tokens": recv_obj.prompt_tokens[i],
+                }
+
+                if getattr(state.obj, "return_logprob", False):
+                    self.convert_logprob_style(
+                        meta_info,
+                        state.obj.top_logprobs_num,
+                        state.obj.return_text_in_logprobs,
+                        recv_obj,
+                        i,
+                    )
+
+                if not isinstance(recv_obj, BatchEmbeddingOut):
+                    meta_info.update(
+                        {
+                            "completion_tokens": recv_obj.completion_tokens[i],
+                            "cached_tokens": recv_obj.cached_tokens[i],
+                        }
+                    )
+
                 if isinstance(recv_obj, BatchStrOut):
                     out_dict = {
                         "text": recv_obj.output_strs[i],
-                        "meta_info":
+                        "meta_info": meta_info,
                     }
                 elif isinstance(recv_obj, BatchTokenIDOut):
                     out_dict = {
                         "token_ids": recv_obj.output_ids[i],
-                        "meta_info":
+                        "meta_info": meta_info,
                     }
                 else:
                     assert isinstance(recv_obj, BatchEmbeddingOut)
                     out_dict = {
                         "embedding": recv_obj.embeddings[i],
-                        "meta_info":
+                        "meta_info": meta_info,
                     }
                 state.out_list.append(out_dict)
-                state.finished = recv_obj.
+                state.finished = recv_obj.finished_reasons[i] is not None
                 state.event.set()

                 if self.enable_metrics:
-                    completion_tokens =
+                    completion_tokens = (
+                        recv_obj.completion_tokens[i]
+                        if recv_obj.completion_tokens
+                        else 0
+                    )

                     if state.first_token_time is None:
                         state.first_token_time = time.time()
@@ -647,7 +672,7 @@ class TokenizerManager:

                     if state.finished:
                         self.metrics_collector.inc_prompt_tokens(
-                            recv_obj.
+                            recv_obj.prompt_tokens[i]
                         )
                         self.metrics_collector.inc_generation_tokens(
                             completion_tokens
@@ -696,57 +721,73 @@ class TokenizerManager:

     def convert_logprob_style(
         self,
-
-        return_logprob: bool,
+        meta_info: dict,
         top_logprobs_num: int,
         return_text_in_logprobs: bool,
+        recv_obj: BatchStrOut,
+        recv_obj_index: int,
     ):
-
-
-
+        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.input_token_logprobs_val[recv_obj_index],
+            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.output_token_logprobs_val[recv_obj_index],
+            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
+            recv_obj_index
+        ]
+
+        if top_logprobs_num > 0:
+            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.input_top_logprobs_val[recv_obj_index],
+                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )
-
-
+            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.output_top_logprobs_val[recv_obj_index],
+                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )

-        if top_logprobs_num > 0:
-            ret["meta_info"]["input_top_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["input_top_logprobs"],
-                    return_text_in_logprobs,
-                )
-            )
-            ret["meta_info"]["output_top_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
-                )
-            )
-        return ret
-
     def detokenize_logprob_tokens(
-        self,
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
     ):
-        # TODO(lianmin): This should run on DetokenizerManager
         if not decode_to_text:
-            return [
-
-
-
-
-
-
-            ]
+            return [
+                (logprob, token_id, None)
+                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
+            ]
+        else:
+            assert self.tokenizer is not None
+            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
+            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

-    def detokenize_top_logprobs_tokens(
+    def detokenize_top_logprobs_tokens(
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
+    ):
         # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
         # We should batch all top-k tokens in all positions.
-
-
-
-
+        ret = []
+        for i in range(len(token_logprobs_val)):
+            if token_logprobs_val[i]:
+                ret.append(
+                    self.detokenize_logprob_tokens(
+                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
+                    )
                 )
-
+            else:
+                ret.append(None)
+        return ret


 class SignalHandler:
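The logprob hunks above move per-request metadata into a `meta_info` dict and have `detokenize_logprob_tokens` return `(logprob, token_id, text)` triples. A minimal sketch of that shape, using made-up values (the commented tokenizer call mirrors the `batch_decode` usage in the hunk):

```python
# Illustrative sketch of the triple format built by the new detokenize_logprob_tokens.
token_logprobs_val = [-0.05, -1.20]   # made-up logprobs
token_logprobs_idx = [15496, 2159]    # made-up token IDs

# decode_to_text=False: the text slot stays None
without_text = [
    (logprob, token_id, None)
    for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
]
print(without_text)  # [(-0.05, 15496, None), (-1.2, 2159, None)]

# decode_to_text=True (requires a loaded HF tokenizer, so only sketched here):
# token_texts = tokenizer.batch_decode(token_logprobs_idx)
# with_text = list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
```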
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,
@@ -73,12 +74,13 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()

     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -97,7 +99,7 @@ class TpModelWorkerClient:

     def forward_thread_func(self):
         try:
-            with torch.
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()
@@ -122,7 +124,7 @@ class TpModelWorkerClient:

         # Create event
         self.launch_done = threading.Event()
-        copy_done = torch.
+        copy_done = torch.get_device_module(self.device).Event()

         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids
@@ -190,7 +192,7 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-
+        self.scheduler_stream.synchronize()

         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
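The stream and event changes above replace hard-coded `torch.cuda` calls with `torch.get_device_module(self.device)`, which resolves to the device-specific module (`torch.cuda` on CUDA devices). A minimal standalone sketch of the pattern, assuming a CUDA build of PyTorch recent enough to provide `torch.get_device_module`:

```python
import torch

# Sketch only: assumes a CUDA device is present.
assert torch.cuda.is_available()
device = "cuda"

dev = torch.get_device_module(device)   # resolves to torch.cuda here
forward_stream = dev.Stream()           # instead of torch.cuda.Stream()
copy_done = dev.Event()                 # instead of torch.cuda.Event()

with dev.stream(forward_stream):        # instead of torch.cuda.stream(...)
    x = torch.ones(8, device=device) * 2
    copy_done.record(forward_stream)

copy_done.synchronize()
print(x.sum().item())  # 16.0
```

The same indirection is what keeps `self.scheduler_stream.synchronize()` in the last hunk device-agnostic.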
sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable
+from typing import Callable, List, Tuple


 class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def match_prefix(self, **kwargs):
+    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
         pass

     @abstractmethod
sglang/srt/mem_cache/chunk_cache.py

@@ -2,7 +2,7 @@ from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}

-    def match_prefix(self, rid: int, key: List[int]):
+    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
         if rid not in self.entries:
             return [], None

sglang/srt/mem_cache/memory_pool.py

@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch

 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend

 logger = logging.getLogger(__name__)

@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
         return select_index.to(self.device, non_blocking=True)

     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):

 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
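The `free()` hunk above adds an early return for empty index tensors before the existing `torch.concat` call. A tiny self-contained sketch of that guard (the function and variable names below are illustrative, not the pool's real structure):

```python
import torch

# Sketch of the guard added to BaseTokenToKVPool.free: an empty index tensor is now a no-op.
def free(free_slots: torch.Tensor, free_index: torch.Tensor) -> torch.Tensor:
    if free_index.numel() == 0:  # the new early return
        return free_slots
    return torch.concat((free_slots, free_index.cpu()))

slots = torch.tensor([], dtype=torch.int64)
slots = free(slots, torch.tensor([], dtype=torch.int64))  # unchanged
slots = free(slots, torch.tensor([3, 7]))
print(slots.tolist())  # [3, 7]
```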
sglang/srt/mem_cache/radix_cache.py

@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

 import torch

@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0

-    def match_prefix(self, key: List, **kwargs):
+    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+        """Find the matching prefix from the radix tree.
+        Args:
+            key: A list of token IDs to find a matching prefix.
+        Returns:
+            A tuple of a tensor of matching prefix token IDs and
+            the last node that contains the prefix values. Note that
+            this API can modify the internal state of the Radix tree.
+            The last node create a new child if the prefix is shorter
+            than the last node's value.
+        """
         if self.disable:
             return [], self.root_node

sglang/srt/model_executor/cuda_graph_runner.py

@@ -20,6 +20,8 @@ from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable

 import torch
+import tqdm
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

@@ -47,7 +49,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
         if "FusedMoE" in sub.__class__.__name__:
             if batch_size == 1:
                 # The performance of torch.compile on this layer is not always good when bs > 1,
-                # so we decide to
+                # so we decide to only use torch.compile when bs =1
                 sub._forward_method = fused_moe_forward_native
             else:
                 sub._forward_method = sub.forward_native
@@ -127,9 +129,23 @@ class CudaGraphRunner:

         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1,
+            self.capture_bs = list(range(1, 33)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
+        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
+            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+            # is very samll. We add more values here to make sure we capture the maximum bs.
+            self.capture_bs = list(
+                sorted(
+                    set(
+                        self.capture_bs
+                        + [model_runner.req_to_token_pool.size - 1]
+                        + [model_runner.req_to_token_pool.size]
+                    )
+                )
+            )
+
         self.capture_bs = [
             bs
             for bs in self.capture_bs
@@ -241,7 +257,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-
+            capture_bs = (
+                tqdm.tqdm(self.capture_bs)
+                if get_tensor_model_parallel_rank() == 0
+                else self.capture_bs
+            )
+            for bs in capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -373,8 +394,14 @@ class CudaGraphRunner:

         # Extract logprobs
         if forward_batch.return_logprob:
-
-
+            logits_metadata = LogitsMetadata(
+                forward_mode=ForwardMode.DECODE,
+                top_logprobs_nums=forward_batch.top_logprobs_nums,
+            )
+            next_token_logprobs = (
+                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
+                    next_token_logits, logits_metadata
+                )
             )
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
@@ -382,13 +409,14 @@ class CudaGraphRunner:
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
-
-
-
-                )
-                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                (
+                    logits_output.output_top_logprobs_val,
+                    logits_output.output_top_logprobs_idx,
+                ) = LogitsProcessor.get_top_logprobs(
                     next_token_logprobs, logits_metadata
-                )[
+                )[
+                    2:4
+                ]
         else:
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
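One of the hunks above pads `capture_bs` so that the request pool size itself is always among the captured batch sizes. A standalone sketch of that adjustment with made-up numbers; `req_to_token_pool_size` stands in for `model_runner.req_to_token_pool.size`, and the final filter condition is assumed because the hunk is cut off at that point:

```python
# Default capture sizes from the hunk above.
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
req_to_token_pool_size = 12  # e.g. a small --max-running-requests setting

if max(capture_bs) > req_to_token_pool_size:
    # Ensure the largest runnable batch size (and size - 1) are captured.
    capture_bs = sorted(
        set(capture_bs + [req_to_token_pool_size - 1, req_to_token_pool_size])
    )

# Assumed filter: keep only sizes that fit in the request pool.
capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
print(capture_bs)  # [1, 2, 4, 8, 11, 12]
```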
sglang/srt/model_executor/model_runner.py

@@ -27,7 +27,6 @@ from vllm.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-from vllm.distributed.parallel_state import in_the_same_node_as

 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
@@ -38,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -111,15 +111,20 @@ class ModelRunner:
         )

         if self.is_multimodal:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
+            if self.model_config.hf_config.architectures == [
+                "MllamaForConditionalGeneration"
+            ]:
+                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+                server_args.chunked_prefill_size = -1
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                )
+                server_args.chunked_prefill_size = -1
                 server_args.disable_radix_cache = True

         # Global vars
@@ -139,6 +144,7 @@ class ModelRunner:
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_ep_moe": server_args.enable_ep_moe,
             }
         )

@@ -151,6 +157,11 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

+        # Apply torchao quantization
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
@@ -235,20 +246,22 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")

-        # Prepare the
+        # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
+
+        # Load the model
         self.model = get_model(
             model_config=self.model_config,
             load_config=self.load_config,
             device_config=DeviceConfig(self.device),
         )
+
+        # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
@@ -263,8 +276,10 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-    def update_weights_from_disk(
-
+    def update_weights_from_disk(
+        self, model_path: str, load_format: str
+    ) -> tuple[bool, str]:
+        """Update engine weights in-place from the disk."""
         from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,