sglang 0.2.14.post2__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +2 -0
- sglang/bench_latency.py +39 -28
- sglang/lang/interpreter.py +3 -0
- sglang/lang/ir.py +5 -0
- sglang/launch_server_llavavid.py +12 -12
- sglang/srt/configs/__init__.py +5 -0
- sglang/srt/configs/exaone.py +195 -0
- sglang/srt/constrained/fsm_cache.py +1 -1
- sglang/srt/conversation.py +24 -2
- sglang/srt/hf_transformers_utils.py +11 -11
- sglang/srt/layers/extend_attention.py +13 -8
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +69 -16
- sglang/srt/managers/controller_multi.py +5 -5
- sglang/srt/managers/controller_single.py +5 -5
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +20 -8
- sglang/srt/managers/tokenizer_manager.py +2 -2
- sglang/srt/managers/tp_worker.py +38 -26
- sglang/srt/model_config.py +3 -3
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +68 -23
- sglang/srt/model_executor/model_runner.py +14 -12
- sglang/srt/models/chatglm.py +4 -12
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +57 -25
- sglang/srt/models/exaone.py +399 -0
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +5 -1
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +5 -1
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +6 -2
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +6 -2
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +16 -1
- sglang/srt/openai_api/protocol.py +5 -5
- sglang/srt/sampling/sampling_batch_info.py +79 -6
- sglang/srt/server.py +6 -6
- sglang/srt/utils.py +0 -3
- sglang/test/runners.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/METADATA +7 -7
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/RECORD +55 -52
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py
CHANGED
@@ -1,4 +1,6 @@
+import dataclasses
 import logging
+from typing import Union
 
 import torch
 from flashinfer.sampling import (
@@ -9,6 +11,8 @@ from flashinfer.sampling import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -16,30 +20,71 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 logger = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class SampleOutput:
+    success: torch.Tensor
+    probs: torch.Tensor
+    batch_next_token_ids: torch.Tensor
+
+
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
 
-    def
+    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        return logits
+
+    def _get_probs(
+        self,
+        logits: torch.Tensor,
+        sampling_info: SamplingBatchInfo,
+        is_torch_compile: bool = False,
+    ):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
+        if is_torch_compile:
+            # FIXME: Temporary workaround for unknown bugs in torch.compile
+            logits.add_(0)
+
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)
 
         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
 
-        logits =
+        logits = self._apply_penalties(logits, sampling_info)
 
-
+        return torch.softmax(logits, dim=-1)
+
+    def forward_cuda(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)
 
         if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
            )
-            if sampling_info.
+            if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -55,18 +100,23 @@ class Sampler(CustomOp):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
 
-
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
-            )
+        return SampleOutput(success, probs, batch_next_token_ids)
 
-
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )
 
-
-        raise NotImplementedError("Native forward is not implemented yet.")
+        return SampleOutput(success, probs, batch_next_token_ids)
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-
+        # FIXME: torch.multiomial does not support num_samples = 1
+        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
+            :, :1
+        ]
     except RuntimeError as e:
         logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
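The Sampler no longer applies the greedy fallback itself; it reports per-request success through the new SampleOutput dataclass and leaves recovery to the caller (ScheduleBatch.check_sample_results, shown further below). A minimal, self-contained sketch of that contract, using a local stand-in for the dataclass rather than the sglang import:

    import dataclasses
    import torch

    @dataclasses.dataclass
    class SampleOutput:  # mirrors the three fields added in sglang.srt.layers.sampler
        success: torch.Tensor
        probs: torch.Tensor
        batch_next_token_ids: torch.Tensor

    def recover_failed_samples(out: SampleOutput) -> torch.Tensor:
        # Greedy (top_k=1) fallback for any row whose sampling kernel reported failure.
        if not torch.all(out.success):
            probs = out.probs.masked_fill(torch.isnan(out.probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
            out.batch_next_token_ids = torch.where(
                out.success, out.batch_next_token_ids, argmax_ids
            )
        return out.batch_next_token_ids

    out = SampleOutput(
        success=torch.tensor([True, False]),
        probs=torch.tensor([[0.1, 0.9], [float("nan"), float("nan")]]),
        batch_next_token_ids=torch.tensor([1, 7]),
    )
    print(recover_failed_samples(out))  # tensor([1, 0]): the failed row falls back to argmax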
sglang/srt/managers/controller_multi.py
CHANGED
@@ -71,12 +71,12 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-
+        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.
+        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +114,7 @@ class ControllerMulti:
             self.server_args,
             self.port_args,
             pipe_controller_writer,
-            self.
+            self.model_override_args,
             True,
             gpu_ids,
             dp_worker_id,
@@ -189,14 +189,14 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-
+    model_override_args: dict,
 ):
     """Start a controller process."""
 
     configure_logger(server_args)
 
     try:
-        controller = ControllerMulti(server_args, port_args,
+        controller = ControllerMulti(server_args, port_args, model_override_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
sglang/srt/managers/controller_single.py
CHANGED
@@ -40,7 +40,7 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-
+        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +76,7 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-
+                model_override_args,
            )
 
         # Launch tp rank 0
@@ -85,7 +85,7 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-
+            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
@@ -126,7 +126,7 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-
+    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
@@ -149,7 +149,7 @@ def start_controller_process(
         controller = ControllerSingle(
             server_args,
             port_args,
-
+            model_override_args,
             gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,
sglang/srt/managers/io_struct.py
CHANGED
@@ -18,8 +18,9 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
+import copy
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -249,6 +250,10 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
 
+    def __post_init__(self):
+        # deepcopy meta_info to avoid modification in place
+        self.meta_info = copy.deepcopy(self.meta_info)
+
 
 @dataclass
 class BatchStrOut:
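The __post_init__ deepcopy guards against aliasing: without it, BatchTokenIDOut holds references to the producer's meta_info dicts, so later in-place updates could, presumably, leak into batches that were already handed off. A generic illustration of the failure mode this prevents (local stand-in class, not the sglang type):

    import copy
    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass
    class Out:
        meta_info: List[Dict]

        def __post_init__(self):
            # deepcopy so the producer can keep mutating its own dicts safely
            self.meta_info = copy.deepcopy(self.meta_info)

    shared = [{"completion_tokens": 3}]
    out = Out(shared)
    shared[0]["completion_tokens"] = 4            # producer updates its bookkeeping afterwards
    print(out.meta_info[0]["completion_tokens"])  # 3 -- the emitted batch is unaffected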
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +19,7 @@ limitations under the License.
 
 import logging
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 
@@ -29,6 +31,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -678,11 +684,17 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-    def
-
-
-
-
-
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
 
-        return batch_next_token_ids
+        return sample_output.batch_next_token_ids
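The sampler module imports schedule_batch at runtime (for global_server_args_dict), so importing SampleOutput back unconditionally would create a cycle; the hunk above sidesteps this with from __future__ import annotations plus a TYPE_CHECKING-only import. The general pattern, sketched outside sglang:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by type checkers; never executed, so no circular import at runtime.
        from sglang.srt.layers.sampler import SampleOutput

    def check_sample_results(sample_output: SampleOutput) -> None:
        # With postponed annotation evaluation, the hint above stays a plain string at runtime.
        ...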
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -77,7 +77,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-
+        model_override_args: dict = None,
     ):
         self.server_args = server_args
 
@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-
+            model_override_args=model_override_args,
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
sglang/srt/managers/tp_worker.py
CHANGED
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -76,7 +76,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-
+        model_override_args: dict,
     ):
         suppress_other_loggers()
 
@@ -93,7 +93,7 @@ class ModelTpServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-
+            model_override_args=model_override_args,
         )
 
         self.model_runner = ModelRunner(
@@ -504,21 +504,29 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-
-
+                sample_output, logits_output = self.model_runner.forward(
+                    batch, ForwardMode.EXTEND
+                )
+                next_token_ids = batch.check_sample_results(sample_output)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )
 
                 # Move logprobs to cpu
-                if
-
-
-
-
-
-
-
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs[
+                            torch.arange(
+                                len(next_token_ids), device=next_token_ids.device
+                            ),
+                            next_token_ids,
+                        ].tolist()
+                    )
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
                     )
 
                 next_token_ids = next_token_ids.tolist()
@@ -557,12 +565,14 @@ class ModelTpServer:
                     self.req_to_token_pool.free(req.req_pool_idx)
 
                 if req.return_logprob:
-                    self.add_logprob_return_values(
+                    self.add_logprob_return_values(
+                        i, req, pt, next_token_ids, logits_output
+                    )
                     pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-
-            embeddings =
+            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = logits_output.embeddings.tolist()
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -590,7 +600,7 @@ class ModelTpServer:
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output:
+        output: LogitsProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -672,15 +682,17 @@ class ModelTpServer:
         batch.prepare_for_decode()
 
         # Forward and sample the next tokens
-
-
+        sample_output, logits_output = self.model_runner.forward(
+            batch, ForwardMode.DECODE
+        )
+        next_token_ids = batch.check_sample_results(sample_output)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )
 
         # Move logprobs to cpu
-        if
-            next_token_logprobs =
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
@@ -706,7 +718,7 @@ class ModelTpServer:
                     (next_token_logprobs[i], next_token_id)
                 )
             if req.top_logprobs_num > 0:
-                req.output_top_logprobs.append(
+                req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 
@@ -864,7 +876,7 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-
+    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -875,7 +887,7 @@ def run_tp_server(
            tp_rank,
            server_args,
            nccl_port,
-
+            model_override_args,
        )
        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
 
@@ -892,14 +904,14 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-
+    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
         )
         proc.start()
         procs.append(proc)
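The worker now reads per-token logprobs off logits_output directly. The torch.arange(...), next_token_ids indexing used in both the extend and decode paths above is plain advanced indexing: one logprob per request, at the token that was actually sampled. A stand-alone sketch:

    import torch

    batch_size, vocab = 4, 10
    next_token_logprobs = torch.log_softmax(torch.randn(batch_size, vocab), dim=-1)
    next_token_ids = torch.tensor([1, 4, 0, 9])

    # Row i contributes next_token_logprobs[i, next_token_ids[i]].
    chosen = next_token_logprobs[
        torch.arange(len(next_token_ids), device=next_token_ids.device),
        next_token_ids,
    ]
    print(chosen.tolist())  # one logprob per request, ready to attach to its meta_info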
sglang/srt/model_config.py
CHANGED
@@ -33,17 +33,17 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
-
+        model_override_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.
+        self.model_override_args = model_override_args
         self.hf_config = get_config(
             self.path,
             trust_remote_code,
             revision,
-
+            model_override_args=model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
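model_override_args is threaded from the entry points (controllers, tokenizer manager, TP workers, ModelConfig) down to get_config, where, presumably, its key/value pairs are applied on top of the loaded Hugging Face config. A hedged sketch of that pattern with a hypothetical helper (the real plumbing lives in sglang.srt.hf_transformers_utils.get_config):

    from typing import Optional

    from transformers import PretrainedConfig

    def apply_model_override_args(
        config: PretrainedConfig, model_override_args: Optional[dict]
    ) -> PretrainedConfig:
        # Overwrite selected fields, e.g. {"num_hidden_layers": 2} for a quick smoke test.
        for key, value in (model_override_args or {}).items():
            setattr(config, key, value)
        return config

    cfg = PretrainedConfig(num_hidden_layers=32, architectures=["LlamaForCausalLM"])
    apply_model_override_args(cfg, {"num_hidden_layers": 2})
    print(cfg.num_hidden_layers)  # 2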
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -26,16 +26,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 
@@ -144,6 +146,10 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]
 
+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
+
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
         if use_torch_compile:
@@ -235,6 +241,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -299,27 +306,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )
 
+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-
+        sample_output, logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
-
-                next_token_logits=
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )
 
         # Extract logprobs
         if batch.return_logprob:
-
-
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
            )
            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
            if return_top_logprob:
@@ -327,8 +342,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-
-
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]
 
-        return
+        return sample_output, logits_output