sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +60 -23
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +52 -167
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +130 -43
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +49 -11
- sglang/srt/model_executor/forward_batch_info.py +59 -27
- sglang/srt/model_executor/model_runner.py +210 -61
- sglang/srt/models/chatglm.py +4 -12
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +5 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +15 -7
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +16 -2
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +5 -1
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +13 -3
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +117 -37
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
- sglang/srt/server.py +84 -56
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +23 -31
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0

sglang/srt/layers/logits_processor.py
CHANGED
@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
 
 
 @dataclasses.dataclass
-class
+class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -55,6 +55,9 @@ class LogitsMetadata:
     extend_start_loc: Optional[torch.Tensor] = None
     top_logprobs_nums: Optional[List[int]] = None
 
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None
+
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
         return cls(
@@ -63,22 +66,30 @@
             extend_start_loc=input_metadata.extend_start_loc,
             return_logprob=input_metadata.return_logprob,
             top_logprobs_nums=input_metadata.top_logprobs_nums,
+            extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
+            logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
         )
 
 
 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )
 
     def _get_normalized_prompt_logprobs(
-        self,
+        self,
+        input_token_logprobs: torch.Tensor,
+        cum_start_len0: torch.Tensor,
+        cum_start_len1: torch.Tensor,
+        logits_metadata: LogitsMetadata,
     ):
         logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
 
-        start = logits_metadata.extend_start_loc.clone()
-        end = start + logits_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone() - cum_start_len0
+        end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
         start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
@@ -91,7 +102,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     @staticmethod
-    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             output_top_logprobs = []
             max_k = max(logits_metadata.top_logprobs_nums)
@@ -105,7 +116,7 @@ class LogitsProcessor(nn.Module):
             # TODO: vectorize the code below
             input_top_logprobs, output_top_logprobs = [], []
             pt = 0
-            extend_seq_lens_cpu = logits_metadata.
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu
 
             max_k = max(logits_metadata.top_logprobs_nums)
             ret = all_logprobs.topk(max_k, dim=1)
@@ -113,26 +124,30 @@
             indices = ret.indices.tolist()
 
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                start_len = logits_metadata.logprob_start_lens_cpu[i]
+                pruned_len = extend_seq_len - start_len
+
                 if extend_seq_len == 0:
                     input_top_logprobs.append([])
                     output_top_logprobs.append([])
                     continue
+
                 k = logits_metadata.top_logprobs_nums[i]
                 input_top_logprobs.append(
                     [
                         list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(
+                        for j in range(pruned_len - 1)
                     ]
                 )
                 output_top_logprobs.append(
                     list(
                         zip(
-                            values[pt +
-                            indices[pt +
+                            values[pt + pruned_len - 1][:k],
+                            indices[pt + pruned_len - 1][:k],
                         )
                     )
                 )
-                pt +=
+                pt += pruned_len
 
             return input_top_logprobs, output_top_logprobs
 
@@ -159,18 +174,18 @@ class LogitsProcessor(nn.Module):
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.
+        if self.do_tensor_parallel_all_gather:
            last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
 
         if hasattr(self.config, "final_logit_softcapping"):
             last_logits.div_(self.config.final_logit_softcapping)
-
+            torch.tanh(last_logits, out=last_logits)
             last_logits.mul_(self.config.final_logit_softcapping)
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -194,7 +209,7 @@ class LogitsProcessor(nn.Module):
                 else:
                     output_top_logprobs = None
 
-                return
+                return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
@@ -203,14 +218,30 @@ class LogitsProcessor(nn.Module):
                    output_top_logprobs=output_top_logprobs,
                )
            else:
-
-
+                pt, states, pruned_input_ids = 0, [], []
+                for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
+                    start_len = logits_metadata.logprob_start_lens_cpu[i]
+                    states.append(hidden_states[pt + start_len : pt + extend_len])
+                    pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                    pt += extend_len
+
+                states = torch.cat(states, dim=0)
+                pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
+
+                cum_start_len1 = torch.tensor(
+                    logits_metadata.logprob_start_lens_cpu, device="cuda"
+                ).cumsum(0)
+                cum_start_len0 = torch.zeros_like(cum_start_len1)
+                cum_start_len0[1:] = cum_start_len1[:-1]
+
+                all_logits = torch.matmul(states, weight.T)
+                if self.do_tensor_parallel_all_gather:
                    all_logits = tensor_model_parallel_all_gather(all_logits)
                all_logits = all_logits[:, : self.config.vocab_size].float()
 
                if hasattr(self.config, "final_logit_softcapping"):
                    all_logits.div_(self.config.final_logit_softcapping)
-
+                    torch.tanh(all_logits, out=all_logits)
                    all_logits.mul_(self.config.final_logit_softcapping)
 
                all_logprobs = all_logits
@@ -228,20 +259,26 @@ class LogitsProcessor(nn.Module):
                else:
                    input_top_logprobs = output_top_logprobs = None
 
-                last_logprobs = all_logprobs[last_index]
+                last_logprobs = all_logprobs[last_index - cum_start_len1]
 
                # Compute the logprobs and normalized logprobs for the prefill tokens.
                # Note that we pad a zero at the end of each sequence for easy computation.
                input_token_logprobs = all_logprobs[
                    torch.arange(all_logprobs.shape[0], device="cuda"),
-                    torch.cat([
+                    torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
                ]
 
                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    input_token_logprobs,
+                    input_token_logprobs,
+                    cum_start_len0,
+                    cum_start_len1,
+                    logits_metadata,
                )
 
-
+                # Remove the last token logprob for the prefill tokens.
+                input_token_logprobs = input_token_logprobs[:-1]
+
+                return LogitsProcessorOutput(
                    next_token_logits=last_logits,
                    next_token_logprobs=last_logprobs,
                    normalized_prompt_logprobs=normalized_prompt_logprobs,

sglang/srt/layers/radix_attention.py
CHANGED
@@ -203,7 +203,6 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-
-
-
-        v_cache[input_metadata.out_cache_loc] = cache_v
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
+        )

sglang/srt/layers/sampler.py
ADDED
@@ -0,0 +1,154 @@
+import dataclasses
+import logging
+from typing import Union
+
+import torch
+from flashinfer.sampling import (
+    min_p_sampling_from_probs,
+    top_k_renorm_prob,
+    top_k_top_p_sampling_from_probs,
+    top_p_renorm_prob,
+)
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
+# TODO: move this dict to another place
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class SampleOutput:
+    success: torch.Tensor
+    probs: torch.Tensor
+    batch_next_token_ids: torch.Tensor
+
+
+class Sampler(CustomOp):
+    def __init__(self):
+        super().__init__()
+
+    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        return logits
+
+    def _get_probs(
+        self,
+        logits: torch.Tensor,
+        sampling_info: SamplingBatchInfo,
+        is_torch_compile: bool = False,
+    ):
+        # Post process logits
+        logits = logits.contiguous()
+        logits.div_(sampling_info.temperatures)
+        if is_torch_compile:
+            # FIXME: Temporary workaround for unknown bugs in torch.compile
+            logits.add_(0)
+
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+
+        logits = self._apply_penalties(logits, sampling_info)
+
+        return torch.softmax(logits, dim=-1)
+
+    def forward_cuda(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)
+
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            if sampling_info.need_min_p_sampling:
+                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                batch_next_token_ids, success = min_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.min_ps
+                )
+            else:
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+                )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+            )
+
+        return SampleOutput(success, probs, batch_next_token_ids)
+
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )
+
+        return SampleOutput(success, probs, batch_next_token_ids)
+
+
+def top_k_top_p_min_p_sampling_from_probs_torch(
+    probs: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+    min_ps: torch.Tensor,
+):
+    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    min_p_thresholds = probs_sort[:, 0] * min_ps
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        # FIXME: torch.multiomial does not support num_samples = 1
+        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
+            :, :1
+        ]
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success

sglang/srt/managers/controller_multi.py
CHANGED
@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers.
 import dataclasses
 import logging
 import multiprocessing
-import os
 from enum import Enum, auto
 
 import numpy as np
@@ -36,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -194,10 +193,7 @@ def start_controller_process(
 ):
     """Start a controller process."""
 
-
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)
 
     try:
         controller = ControllerMulti(server_args, port_args, model_overide_args)
@@ -212,6 +208,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
     finally:
-        for w in controller.workers:
-            os.kill(w.proc.pid, 9)
         kill_parent_process()

sglang/srt/managers/controller_single.py
CHANGED
@@ -17,7 +17,6 @@ limitations under the License.
 
 import logging
 import multiprocessing
-import os
 from typing import List
 
 import zmq
@@ -28,7 +27,7 @@ from sglang.srt.managers.tp_worker import (
     launch_tp_servers,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -53,7 +52,7 @@ class ControllerSingle:
         self.dp_worker_id = dp_worker_id
         self.mp_queue = mp_queue
 
-        # Init communication
+        # Init inter-process communication
         context = zmq.Context(2)
 
         if not self.is_dp_worker:
@@ -134,11 +133,11 @@ def start_controller_process(
     queue: multiprocessing.connection.Connection = None,
 ):
     """Start a controller process."""
-
-
-
-
-    )
+    if is_data_parallel_worker:
+        logger_prefix = f" DP{dp_worker_id} TP0"
+    else:
+        logger_prefix = " TP0"
+    configure_logger(server_args, prefix=logger_prefix)
 
     if not is_data_parallel_worker:
         tp_size_local = server_args.tp_size // server_args.nnodes
@@ -167,6 +166,4 @@
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        for t in controller.tp_procs:
-            os.kill(t.pid, 9)
         kill_parent_process()

sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -17,7 +17,6 @@ limitations under the License.
 
 import asyncio
 import dataclasses
-import inspect
 from typing import List
 
 import uvloop
@@ -29,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 @dataclasses.dataclass
 class DecodeStatus:
+    """Store the status of incremental decoding."""
+
     vid: int
     decoded_text: str
     decode_ids: List[int]
@@ -47,11 +49,14 @@ class DecodeStatus:
 
 
 class DetokenizerManager:
+    """DetokenizerManager is a process that detokenizes the token ids."""
+
     def __init__(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_router = context.socket(zmq.PULL)
         self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
@@ -71,10 +76,13 @@ class DetokenizerManager:
         self.decode_status = {}
 
     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
-            recv_obj
+            recv_obj = await self.recv_from_router.recv_pyobj()
 
             if isinstance(recv_obj, BatchEmbeddingOut):
+                # If it is embedding model, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         rids=recv_obj.rids,
@@ -84,15 +92,18 @@ class DetokenizerManager:
                     )
                 )
                 continue
+            elif isinstance(recv_obj, UpdateWeightReqOutput):
+                # If it is a weight update request, no detokenization is needed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+            elif self.tokenizer is None:
+                # If the tokenizer is skipped, no detokenization is needed
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
 
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
-            if self.tokenizer is None:
-                # Send BatchTokenIDOut if no tokenizer init'ed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
@@ -126,8 +137,7 @@ class DetokenizerManager:
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
 
-            #
-            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            # Incremental decoding
             output_strs = []
             for i in range(bs):
                 s = self.decode_status[recv_obj.rids[i]]
@@ -144,6 +154,7 @@ class DetokenizerManager:
 
                 output_strs.append(s.decoded_text + new_text)
 
+                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
                     if pos != -1:
sglang/srt/managers/io_struct.py
CHANGED
@@ -22,10 +22,8 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-import torch
-
 from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 
 
 @dataclass
@@ -43,9 +41,9 @@ class GenerateReqInput:
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    #
+    # If return logprobs, the start location in the prompt for returning logprobs.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    #
+    # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
@@ -77,7 +75,7 @@ class GenerateReqInput:
             if self.return_logprob is None:
                 self.return_logprob = False
             if self.logprob_start_len is None:
-                self.logprob_start_len =
+                self.logprob_start_len = -1
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
@@ -143,7 +141,7 @@ class GenerateReqInput:
                 self.return_logprob = [self.return_logprob] * num
 
             if self.logprob_start_len is None:
-                self.logprob_start_len = [
+                self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
 
@@ -155,16 +153,27 @@ class GenerateReqInput:
 
 @dataclass
 class TokenizedGenerateReqInput:
+    # The request id
     rid: str
+    # The input text
     input_text: str
+    # The input token ids
     input_ids: List[int]
+    # The pixel values for input images
     pixel_values: List[float]
+    # The hash of input images
     image_hash: int
+    # The image size
     image_size: List[int]
+    # The sampling parameters
     sampling_params: SamplingParams
+    # Whether to return the logprobs
     return_logprob: bool
+    # If return logprobs, the start location in the prompt for returning logprobs.
     logprob_start_len: int
+    # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: int
+    # Whether to stream output
     stream: bool
 
 
@@ -215,15 +224,21 @@ class EmbeddingReqInput:
 
 @dataclass
 class TokenizedEmbeddingReqInput:
+    # The request id
     rid: str
+    # The input text
     input_text: str
+    # The input token ids
     input_ids: List[int]
+    # Dummy sampling params for compatibility
     sampling_params: SamplingParams
 
 
 @dataclass
 class BatchTokenIDOut:
+    # The request id
     rids: List[str]
+    # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
     decode_ids: List[int]
@@ -236,17 +251,25 @@ class BatchTokenIDOut:
 
 @dataclass
 class BatchStrOut:
+    # The request id
     rids: List[str]
+    # The output decoded strings
     output_strs: List[str]
+    # The meta info
     meta_info: List[Dict]
+    # The finish reason
     finished_reason: List[BaseFinishReason]
 
 
 @dataclass
 class BatchEmbeddingOut:
+    # The request id
     rids: List[str]
+    # The output embedding
     embeddings: List[List[float]]
+    # The meta info
     meta_info: List[Dict]
+    # The finish reason
     finished_reason: List[BaseFinishReason]
 
 
@@ -256,10 +279,20 @@ class FlushCacheReq:
 
 
 @dataclass
-class
-
+class UpdateWeightReqInput:
+    # The model path with the new weights
+    model_path: str
+    # The format to load the weights
+    load_format: Optional[str] = None
 
 
 @dataclass
-class
-
+class UpdateWeightReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class AbortReq:
+    # The request id
+    rid: str