sglang 0.2.14.post2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +2 -0
- sglang/bench_latency.py +39 -28
- sglang/lang/backend/runtime_endpoint.py +8 -4
- sglang/lang/interpreter.py +3 -0
- sglang/lang/ir.py +5 -0
- sglang/launch_server_llavavid.py +12 -12
- sglang/srt/configs/__init__.py +5 -0
- sglang/srt/configs/exaone.py +195 -0
- sglang/srt/constrained/fsm_cache.py +1 -1
- sglang/srt/conversation.py +24 -2
- sglang/srt/hf_transformers_utils.py +12 -12
- sglang/srt/layers/extend_attention.py +13 -8
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +94 -17
- sglang/srt/managers/controller_multi.py +5 -5
- sglang/srt/managers/controller_single.py +5 -5
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +26 -11
- sglang/srt/managers/tokenizer_manager.py +9 -9
- sglang/srt/managers/tp_worker.py +38 -26
- sglang/srt/model_config.py +3 -3
- sglang/srt/model_executor/cuda_graph_runner.py +26 -9
- sglang/srt/model_executor/forward_batch_info.py +68 -23
- sglang/srt/model_executor/model_runner.py +15 -22
- sglang/srt/models/chatglm.py +9 -15
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +57 -25
- sglang/srt/models/exaone.py +368 -0
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +5 -1
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +5 -1
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/{llama2.py → llama.py} +25 -45
- sglang/srt/models/llama_classification.py +34 -41
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +8 -11
- sglang/srt/models/llavavid.py +5 -6
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -2
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +6 -2
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +16 -1
- sglang/srt/openai_api/protocol.py +5 -5
- sglang/srt/sampling/sampling_batch_info.py +75 -6
- sglang/srt/server.py +6 -6
- sglang/srt/utils.py +0 -3
- sglang/test/runners.py +1 -1
- sglang/test/test_programs.py +68 -0
- sglang/test/test_utils.py +4 -0
- sglang/utils.py +39 -0
- sglang/version.py +1 -1
- {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/METADATA +9 -8
- sglang-0.3.0.dist-info/RECORD +118 -0
- {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/WHEEL +1 -1
- sglang-0.2.14.post2.dist-info/RECORD +0 -115
- {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py
CHANGED

```diff
@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad


 @dataclasses.dataclass
-class
+class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):

         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
            else:
                output_top_logprobs = None

-            return
+            return LogitsProcessorOutput(
                next_token_logits=last_logits,
                next_token_logprobs=last_logprobs,
                normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
            # Remove the last token logprob for the prefill tokens.
            input_token_logprobs = input_token_logprobs[:-1]

-            return
+            return LogitsProcessorOutput(
                next_token_logits=last_logits,
                next_token_logprobs=last_logprobs,
                normalized_prompt_logprobs=normalized_prompt_logprobs,
```
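Both the rename and the explicit `return LogitsProcessorOutput(...)` calls keep every code path returning the same dataclass, with the logprob fields left as `None` when logprobs are not requested. A minimal sketch of that return pattern (hypothetical `Output`/`compute` names, abbreviated fields, not sglang's actual class):

```python
# Sketch of the dataclass-return pattern above; field set is abbreviated.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Output:
    next_token_logits: torch.Tensor
    next_token_logprobs: Optional[torch.Tensor] = None


def compute(last_logits: torch.Tensor, return_logprob: bool) -> Output:
    if not return_logprob:
        # Callers always get the same type; optional fields stay None.
        return Output(next_token_logits=last_logits)
    return Output(
        next_token_logits=last_logits,
        next_token_logprobs=torch.log_softmax(last_logits, dim=-1),
    )


out = compute(torch.randn(2, 8), return_logprob=True)
print(out.next_token_logprobs.shape)  # torch.Size([2, 8])
```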
sglang/srt/layers/sampler.py
CHANGED
```diff
@@ -1,4 +1,6 @@
+import dataclasses
 import logging
+from typing import Tuple, Union

 import torch
 from flashinfer.sampling import (
@@ -7,8 +9,11 @@ from flashinfer.sampling import (
     top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
 )
+from torch.library import custom_op as torch_custom_op
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -16,37 +21,76 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 logger = logging.getLogger(__name__)


+@dataclasses.dataclass
+class SampleOutput:
+    success: torch.Tensor
+    probs: torch.Tensor
+    batch_next_token_ids: torch.Tensor
+
+
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
+        # FIXME: torch.multinomial has too many bugs
+        self.forward_native = self.forward_cuda
+        self.is_torch_compile = False
+
+    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        return logits

-    def
+    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
+        if self.is_torch_compile:
+            # FIXME: Temporary workaround for unknown bugs in torch.compile
+            logits.add_(0)
+
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)

         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))

-        logits =
+        logits = self._apply_penalties(logits, sampling_info)

-
+        return torch.softmax(logits, dim=-1)
+
+    def forward_cuda(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)

         if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            if sampling_info.
+            if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success =
+                batch_next_token_ids, success = flashinfer_top_k_top_p(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
         else:
```
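The new `_apply_penalties`/`_get_probs` helpers post-process logits before any sampling path runs. A standalone illustration of the penalty and temperature transform shown above (tensor values here are made up for the example):

```python
# Illustration of the repetition-penalty + temperature post-processing:
# positive logits are shrunk, negative logits are pushed further down,
# then everything is temperature-scaled and converted to probabilities.
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
scaling_penalties = torch.tensor([[1.2, 1.2, 1.0]])  # >1 discourages repeated tokens
temperatures = torch.tensor([[0.7]])

logits = torch.where(
    logits > 0, logits / scaling_penalties, logits * scaling_penalties
)
logits = logits / temperatures
probs = torch.softmax(logits, dim=-1)
print(probs)
```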
```diff
@@ -55,18 +99,48 @@ class Sampler(CustomOp):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )

-
-
-
-
-
-
-
+        return SampleOutput(success, probs, batch_next_token_ids)
+
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )
+
+        return SampleOutput(success, probs, batch_next_token_ids)

-        return batch_next_token_ids

-
-
+@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
+def flashinfer_top_k_top_p(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
+    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
+
+
+@flashinfer_top_k_top_p.register_fake
+def _(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    bs = probs.shape[0]
+    return (
+        torch.ones(bs, dtype=torch.bool, device=probs.device),
+        torch.zeros(bs, dtype=torch.int32, device=probs.device),
+    )


 def top_k_top_p_min_p_sampling_from_probs_torch(
```
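Wrapping the flashinfer call in a `torch.library` custom op with a `register_fake` companion lets the compiler trace the op from shapes and dtypes alone instead of running the real kernel. A hedged sketch of that pattern, assuming a recent PyTorch (roughly 2.4+); the op name and kernel body below are invented for illustration and are not sglang's:

```python
# Sketch: register a custom op plus a fake (meta) implementation so that
# torch.compile can trace it without executing the real kernel.
from typing import Tuple

import torch
from torch.library import custom_op


@custom_op("demo_lib::argmax_sample", mutates_args=())
def argmax_sample(probs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Real implementation (stand-in for an external sampling kernel).
    ids = probs.argmax(dim=-1).to(torch.int32)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return ids, success


@argmax_sample.register_fake
def _(probs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Shape/dtype-only version used while tracing or compiling.
    bs = probs.shape[0]
    return (
        torch.empty(bs, dtype=torch.int32, device=probs.device),
        torch.empty(bs, dtype=torch.bool, device=probs.device),
    )


ids, ok = argmax_sample(torch.rand(4, 16))
print(ids.shape, ok.all().item())
```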
```diff
@@ -87,7 +161,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-
+        # FIXME: torch.multiomial does not support num_samples = 1
+        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
+            :, :1
+        ]
     except RuntimeError as e:
         logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
```
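The hunk above works around the issue its FIXME describes by drawing two samples and keeping only the first column, rather than calling `torch.multinomial` with `num_samples=1` directly. A quick standalone illustration:

```python
# Draw two samples per row, keep the first one -> shape [batch, 1].
import torch

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.25, 0.25]])

sampled = torch.multinomial(probs, num_samples=2, replacement=True)[:, :1]
print(sampled.shape)  # torch.Size([2, 1]) -- one token id per row
```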
sglang/srt/managers/controller_multi.py
CHANGED

```diff
@@ -71,12 +71,12 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-
+        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.
+        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +114,7 @@ class ControllerMulti:
             self.server_args,
             self.port_args,
             pipe_controller_writer,
-            self.
+            self.model_override_args,
             True,
             gpu_ids,
             dp_worker_id,
@@ -189,14 +189,14 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-
+    model_override_args: dict,
 ):
     """Start a controller process."""

     configure_logger(server_args)

     try:
-        controller = ControllerMulti(server_args, port_args,
+        controller = ControllerMulti(server_args, port_args, model_override_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
```
sglang/srt/managers/controller_single.py
CHANGED

```diff
@@ -40,7 +40,7 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-
+        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +76,7 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-
+                model_override_args,
             )

         # Launch tp rank 0
@@ -85,7 +85,7 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-
+            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

@@ -126,7 +126,7 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-
+    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
@@ -149,7 +149,7 @@ def start_controller_process(
         controller = ControllerSingle(
             server_args,
             port_args,
-
+            model_override_args,
             gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,
```
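Across both controller variants the change is the same: a plain dict of model config overrides is accepted at the top and handed unchanged across every process boundary. A hedged sketch of that plumbing pattern (names and the example override dict below are illustrative only, not sglang's API):

```python
# Sketch: pass a config-override dict through multiprocessing spawns so every
# worker sees the same overrides without reloading or re-deriving them.
import multiprocessing


def worker_main(rank: int, model_override_args: dict):
    # Each worker receives the exact dict the launcher was given.
    print(f"rank {rank} sees overrides: {model_override_args}")


def launch(model_override_args: dict, num_workers: int = 2):
    procs = []
    for rank in range(num_workers):
        p = multiprocessing.Process(
            target=worker_main, args=(rank, model_override_args)
        )
        p.start()
        procs.append(p)
    for p in procs:
        p.join()


if __name__ == "__main__":
    launch({"rope_scaling": {"type": "linear", "factor": 2.0}})
```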
sglang/srt/managers/io_struct.py
CHANGED
```diff
@@ -18,8 +18,9 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """

+import copy
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -249,6 +250,10 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

+    def __post_init__(self):
+        # deepcopy meta_info to avoid modification in place
+        self.meta_info = copy.deepcopy(self.meta_info)
+

 @dataclass
 class BatchStrOut:
```
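A small sketch of why the `__post_init__` deepcopy matters: without it, the dataclass would hold references to the caller's dicts, and later in-place edits by the producer would leak into an already-queued output object. The trimmed `BatchOut` class below is hypothetical, not the real `BatchTokenIDOut`:

```python
# Snapshot mutable metadata at construction time so later mutation by the
# producer does not alter the object that was already handed off.
import copy
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class BatchOut:
    meta_info: List[Dict]

    def __post_init__(self):
        self.meta_info = copy.deepcopy(self.meta_info)


shared = [{"completion_tokens": 1}]
out = BatchOut(shared)
shared[0]["completion_tokens"] = 99  # producer keeps mutating its own dict

print(out.meta_info[0]["completion_tokens"])  # 1 -- the snapshot is preserved
```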
sglang/srt/managers/schedule_batch.py
CHANGED

```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +19,7 @@ limitations under the License.

 import logging
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch

@@ -29,6 +31,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
@@ -172,19 +178,22 @@ class Req:
     def adjust_max_prefix_ids(self):
         self.fill_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.fill_ids)
-
+
+        # FIXME: To work around some bugs in logprob computation, we need to ensure each
+        # request has at least one token. Later, we can relax this requirement and use `input_len`.
+        max_prefix_len = input_len - 1

         if self.sampling_params.max_new_tokens > 0:
             # Need at least one token to compute logits
             max_prefix_len = min(max_prefix_len, input_len - 1)

         if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
             if self.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 max_prefix_len = min(max_prefix_len, input_len - 2)
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

+        max_prefix_len = max(max_prefix_len, 0)
         return self.fill_ids[:max_prefix_len]

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
@@ -678,11 +687,17 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

-    def
-
-
-
-
-
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids

-        return batch_next_token_ids
+        return sample_output.batch_next_token_ids
```
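The `check_sample_results` fallback replaces failed samples with a greedy (top_k=1) choice. A standalone demo of the same masking-and-argmax logic with dummy tensors:

```python
# Wherever the sampler reports failure, substitute the argmax of the
# NaN-cleaned probabilities for the sampled token id.
import torch

success = torch.tensor([True, False])
probs = torch.tensor([[0.1, 0.9], [float("nan"), float("nan")]])
sampled = torch.tensor([1, 0])

probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
next_token_ids = torch.where(success, sampled, argmax_ids)
print(next_token_ids)  # tensor([1, 0])
```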
sglang/srt/managers/tokenizer_manager.py
CHANGED

```diff
@@ -77,7 +77,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-
+        model_override_args: dict = None,
     ):
         self.server_args = server_args

@@ -86,8 +86,8 @@ class TokenizerManager:
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

-        self.
-        self.
+        self.send_to_controller = context.socket(zmq.PUSH)
+        self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

         # Read model args
         self.model_path = server_args.model_path
@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-
+            model_override_args=model_override_args,
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
@@ -271,7 +271,7 @@ class TokenizerManager:
                 input_ids,
                 sampling_params,
             )
-            self.
+            self.send_to_controller.send_pyobj(tokenized_obj)

             # Recv results
             event = asyncio.Event()
@@ -367,7 +367,7 @@ class TokenizerManager:
                 input_ids,
                 sampling_params,
             )
-            self.
+            self.send_to_controller.send_pyobj(tokenized_obj)

             event = asyncio.Event()
             state = ReqState([], False, event)
@@ -500,14 +500,14 @@ class TokenizerManager:

     def flush_cache(self):
         req = FlushCacheReq()
-        self.
+        self.send_to_controller.send_pyobj(req)

     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
         req = AbortReq(rid)
-        self.
+        self.send_to_controller.send_pyobj(req)

     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -524,7 +524,7 @@ class TokenizerManager:
         # wait for the previous generation requests to finish
         while len(self.rid_to_state) > 0:
             await asyncio.sleep(0)
-        self.
+        self.send_to_controller.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         result = await self.model_update_result
         if result.success:
```
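Behind `send_to_controller` is the usual pyzmq PUSH/PULL pairing: the tokenizer side pickles request objects with `send_pyobj`, the controller side receives them with `recv_pyobj`. A hedged sketch (port number and payload below are arbitrary; requires pyzmq):

```python
# Minimal PUSH/PULL demo mirroring the socket setup in the hunk above.
import zmq

ctx = zmq.Context()

pull = ctx.socket(zmq.PULL)          # controller side
pull.bind("tcp://127.0.0.1:5757")

push = ctx.socket(zmq.PUSH)          # tokenizer side
push.connect("tcp://127.0.0.1:5757")

push.send_pyobj({"rid": "req-0", "input_ids": [1, 2, 3]})
print(pull.recv_pyobj())             # {'rid': 'req-0', 'input_ids': [1, 2, 3]}
```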
sglang/srt/managers/tp_worker.py
CHANGED
```diff
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -76,7 +76,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-
+        model_override_args: dict,
     ):
         suppress_other_loggers()

@@ -93,7 +93,7 @@ class ModelTpServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-
+            model_override_args=model_override_args,
         )

         self.model_runner = ModelRunner(
@@ -504,21 +504,29 @@
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-
-
+                sample_output, logits_output = self.model_runner.forward(
+                    batch, ForwardMode.EXTEND
+                )
+                next_token_ids = batch.check_sample_results(sample_output)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )

                 # Move logprobs to cpu
-                if
-
-
-
-
-
-
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs[
+                            torch.arange(
+                                len(next_token_ids), device=next_token_ids.device
+                            ),
+                            next_token_ids,
+                        ].tolist()
+                    )
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
                     )

                 next_token_ids = next_token_ids.tolist()
@@ -557,12 +565,14 @@
                     self.req_to_token_pool.free(req.req_pool_idx)

                 if req.return_logprob:
-                    self.add_logprob_return_values(
+                    self.add_logprob_return_values(
+                        i, req, pt, next_token_ids, logits_output
+                    )
                     pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-
-            embeddings =
+            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = logits_output.embeddings.tolist()

         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -590,7 +600,7 @@
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output:
+        output: LogitsProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -672,15 +682,17 @@
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-
-
+        sample_output, logits_output = self.model_runner.forward(
+            batch, ForwardMode.DECODE
+        )
+        next_token_ids = batch.check_sample_results(sample_output)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )

         # Move logprobs to cpu
-        if
-        next_token_logprobs =
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
@@ -706,7 +718,7 @@
                     (next_token_logprobs[i], next_token_id)
                 )
             if req.top_logprobs_num > 0:
-                req.output_top_logprobs.append(
+                req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

         self.handle_finished_requests(batch)

@@ -864,7 +876,7 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-
+    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -875,7 +887,7 @@ def run_tp_server(
         tp_rank,
         server_args,
         nccl_port,
-
+        model_override_args,
     )
     tp_cpu_group = model_server.model_runner.tp_group.cpu_group

@@ -892,14 +904,14 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-
+    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port,
+            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
         )
         proc.start()
         procs.append(proc)
```
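The "move logprobs to cpu" blocks above use advanced indexing to pick out, for each sequence, the logprob of the token that was actually sampled. A quick illustration with dummy tensors:

```python
# Pick each row's logprob at its sampled token id, then convert to a list.
import torch

logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)   # [#seq, vocab]
next_token_ids = torch.tensor([4, 0, 2])

chosen = logprobs[torch.arange(3), next_token_ids]         # shape [3]
print(chosen.tolist())
```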
sglang/srt/model_config.py
CHANGED
```diff
@@ -33,17 +33,17 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
-
+        model_override_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.
+        self.model_override_args = model_override_args
         self.hf_config = get_config(
             self.path,
             trust_remote_code,
             revision,
-
+            model_override_args=model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
```
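Downstream, the point of `model_override_args` is to overwrite selected fields of the loaded HF config before the model is built. A hedged sketch with a dummy config object (this is an illustration of the idea, not sglang's `get_config` implementation):

```python
# Apply a dict of overrides on top of an already-loaded config object.
from types import SimpleNamespace


def apply_overrides(hf_config, model_override_args=None):
    for key, value in (model_override_args or {}).items():
        setattr(hf_config, key, value)
    return hf_config


config = SimpleNamespace(num_hidden_layers=32, max_position_embeddings=4096)
config = apply_overrides(config, {"max_position_embeddings": 8192})
print(config.max_position_embeddings)  # 8192
```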