sglang-0.2.12-py3-none-any.whl → sglang-0.2.14-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to the public registry, and is provided for informational purposes only.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py:

@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
 
 
 @dataclasses.dataclass
-class
+class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -55,6 +55,9 @@ class LogitsMetadata:
     extend_start_loc: Optional[torch.Tensor] = None
     top_logprobs_nums: Optional[List[int]] = None
 
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None
+
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
         return cls(
@@ -63,22 +66,30 @@ class LogitsProcessor(nn.Module):
             extend_start_loc=input_metadata.extend_start_loc,
             return_logprob=input_metadata.return_logprob,
             top_logprobs_nums=input_metadata.top_logprobs_nums,
+            extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
+            logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
         )
 
 
 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )
 
     def _get_normalized_prompt_logprobs(
-        self,
+        self,
+        input_token_logprobs: torch.Tensor,
+        cum_start_len0: torch.Tensor,
+        cum_start_len1: torch.Tensor,
+        logits_metadata: LogitsMetadata,
     ):
         logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
 
-        start = logits_metadata.extend_start_loc.clone()
-        end = start + logits_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone() - cum_start_len0
+        end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
         start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
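The reworked `_get_normalized_prompt_logprobs` above now subtracts cumulative start offsets (`cum_start_len0`, `cum_start_len1`) so that prompt logprobs can begin at `logprob_start_lens_cpu[i]` rather than at token 0, but the core mechanism is unchanged: per-sequence sums over one flattened logprob tensor computed from a single cumulative sum. A minimal sketch of that trick with made-up values (plain PyTorch, not sglang code):

```python
import torch

# Toy flattened logprobs for two sequences of lengths 2 and 4.
token_logprobs = torch.tensor([-0.5, -1.0, -0.2, -0.3, -0.7, -0.1])
seq_lens = torch.tensor([2, 4])
start = torch.tensor([0, 2])        # index of each sequence's first token
end = start + seq_lens - 1          # index of each sequence's last token

cumsum = torch.cumsum(token_logprobs, dim=0)
seq_sums = cumsum[end] - cumsum[start] + token_logprobs[start]
print(seq_sums)                     # tensor([-1.5000, -1.3000])
print(seq_sums / seq_lens)          # per-token average, i.e. a normalized logprob
```

The production method additionally clamps `start` and `end` into range, as visible in the hunk above.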
@@ -91,7 +102,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     @staticmethod
-    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             output_top_logprobs = []
             max_k = max(logits_metadata.top_logprobs_nums)
@@ -105,7 +116,7 @@ class LogitsProcessor(nn.Module):
             # TODO: vectorize the code below
             input_top_logprobs, output_top_logprobs = [], []
             pt = 0
-            extend_seq_lens_cpu = logits_metadata.
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu
 
             max_k = max(logits_metadata.top_logprobs_nums)
             ret = all_logprobs.topk(max_k, dim=1)
@@ -113,26 +124,30 @@ class LogitsProcessor(nn.Module):
             indices = ret.indices.tolist()
 
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                start_len = logits_metadata.logprob_start_lens_cpu[i]
+                pruned_len = extend_seq_len - start_len
+
                 if extend_seq_len == 0:
                     input_top_logprobs.append([])
                     output_top_logprobs.append([])
                     continue
+
                 k = logits_metadata.top_logprobs_nums[i]
                 input_top_logprobs.append(
                     [
                         list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(
+                        for j in range(pruned_len - 1)
                     ]
                 )
                 output_top_logprobs.append(
                     list(
                         zip(
-                            values[pt +
-                            indices[pt +
+                            values[pt + pruned_len - 1][:k],
+                            indices[pt + pruned_len - 1][:k],
                         )
                     )
                 )
-                pt +=
+                pt += pruned_len
 
             return input_top_logprobs, output_top_logprobs
 
@@ -159,18 +174,18 @@ class LogitsProcessor(nn.Module):
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.
+        if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
 
         if hasattr(self.config, "final_logit_softcapping"):
-            last_logits
-
-            last_logits
+            last_logits.div_(self.config.final_logit_softcapping)
+            torch.tanh(last_logits, out=last_logits)
+            last_logits.mul_(self.config.final_logit_softcapping)
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -194,7 +209,7 @@ class LogitsProcessor(nn.Module):
                 else:
                     output_top_logprobs = None
 
-                return
+                return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
@@ -203,15 +218,31 @@ class LogitsProcessor(nn.Module):
                     output_top_logprobs=output_top_logprobs,
                 )
             else:
-
-
+                pt, states, pruned_input_ids = 0, [], []
+                for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
+                    start_len = logits_metadata.logprob_start_lens_cpu[i]
+                    states.append(hidden_states[pt + start_len : pt + extend_len])
+                    pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                    pt += extend_len
+
+                states = torch.cat(states, dim=0)
+                pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
+
+                cum_start_len1 = torch.tensor(
+                    logits_metadata.logprob_start_lens_cpu, device="cuda"
+                ).cumsum(0)
+                cum_start_len0 = torch.zeros_like(cum_start_len1)
+                cum_start_len0[1:] = cum_start_len1[:-1]
+
+                all_logits = torch.matmul(states, weight.T)
+                if self.do_tensor_parallel_all_gather:
                     all_logits = tensor_model_parallel_all_gather(all_logits)
                 all_logits = all_logits[:, : self.config.vocab_size].float()
 
                 if hasattr(self.config, "final_logit_softcapping"):
-                    all_logits
-
-                    all_logits
+                    all_logits.div_(self.config.final_logit_softcapping)
+                    torch.tanh(all_logits, out=all_logits)
+                    all_logits.mul_(self.config.final_logit_softcapping)
 
                 all_logprobs = all_logits
                 del all_logits, hidden_states
@@ -228,20 +259,26 @@ class LogitsProcessor(nn.Module):
                 else:
                     input_top_logprobs = output_top_logprobs = None
 
-                last_logprobs = all_logprobs[last_index]
+                last_logprobs = all_logprobs[last_index - cum_start_len1]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
                 input_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
-                    torch.cat([
+                    torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
 
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    input_token_logprobs,
+                    input_token_logprobs,
+                    cum_start_len0,
+                    cum_start_len1,
+                    logits_metadata,
                 )
 
-
+                # Remove the last token logprob for the prefill tokens.
+                input_token_logprobs = input_token_logprobs[:-1]
+
+                return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=normalized_prompt_logprobs,
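Both the decode and extend branches above now apply final-logit softcapping in place (the `final_logit_softcapping` config field used by models such as Gemma 2). As a small standalone illustration with made-up values (plain PyTorch, not sglang code), the `div_` / `tanh` / `mul_` sequence is equivalent to `cap * tanh(logits / cap)`, which smoothly bounds every logit to the interval (-cap, cap):

```python
import torch

def soft_cap_(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # In-place: logits <- cap * tanh(logits / cap)
    logits.div_(cap)
    torch.tanh(logits, out=logits)
    logits.mul_(cap)
    return logits

x = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap_(x, cap=30.0))  # all values squashed into (-30, 30)
```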
sglang/srt/layers/radix_attention.py:

@@ -15,6 +15,8 @@ limitations under the License.
 
 """Radix attention."""
 
+from typing import Optional
+
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
@@ -34,6 +36,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        sliding_window_size: Optional[int] = None,
         logit_cap: int = -1,
         v_head_dim: int = -1,
     ):
@@ -46,6 +49,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.sliding_window_size = sliding_window_size if sliding_window_size else -1
 
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
@@ -113,14 +117,25 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
+        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
+        if self.sliding_window_size != -1:
+            prefill_wrapper_paged = prefill_wrapper_paged[0]
+        else:
+            if isinstance(prefill_wrapper_paged, list):
+                prefill_wrapper_paged = prefill_wrapper_paged[1]
+
         if not input_metadata.flashinfer_use_ragged:
-
+            if k is not None:
+                assert v is not None
+                self.store_kv_cache(k, v, input_metadata)
 
-            o =
+            o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                 causal=True,
                 sm_scale=self.scaling,
+                window_left=self.sliding_window_size,
                 logits_soft_cap=self.logit_cap,
             )
         else:
@@ -138,14 +153,12 @@ class RadixAttention(nn.Module):
             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-                o2, s2 = (
-
-
-
-
-
-                    logits_soft_cap=self.logit_cap,
-                )
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                    causal=False,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
                 )
 
                 o, _ = merge_state(o1, s1, o2, s2)
@@ -158,9 +171,18 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-
+        decode_wrapper = input_metadata.flashinfer_decode_wrapper
+        if self.sliding_window_size != -1:
+            decode_wrapper = decode_wrapper[0]
+        else:
+            if isinstance(decode_wrapper, list):
+                decode_wrapper = decode_wrapper[1]
 
-
+        if k is not None:
+            assert v is not None
+            self.store_kv_cache(k, v, input_metadata)
+
+        o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
@@ -170,8 +192,10 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k
-
+        if k is not None:
+            assert v is not None
+            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
@@ -179,7 +203,6 @@ class RadixAttention(nn.Module):
            return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-
-
-
-        v_cache[input_metadata.out_cache_loc] = cache_v
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
+        )
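`RadixAttention` now accepts a `sliding_window_size`, stores `-1` when no window is used, forwards it to the paged flashinfer kernels as `window_left`, and picks the first wrapper in the list for sliding-window layers and the second for full-attention layers. The sketch below is only a conceptual illustration (plain PyTorch, not sglang or flashinfer code) of what a left sliding window adds on top of the usual causal mask:

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """True where a query position may attend to a key position."""
    q = torch.arange(seq_len).view(-1, 1)  # query positions
    k = torch.arange(seq_len).view(1, -1)  # key positions
    causal = k <= q                        # no attending to the future
    if window < 0:                         # -1 means "no sliding window"
        return causal
    return causal & ((q - k) <= window)    # and only `window` tokens back

print(sliding_window_causal_mask(seq_len=6, window=2).int())
```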
sglang/srt/layers/sampler.py (new file):

@@ -0,0 +1,154 @@
+import dataclasses
+import logging
+from typing import Union
+
+import torch
+from flashinfer.sampling import (
+    min_p_sampling_from_probs,
+    top_k_renorm_prob,
+    top_k_top_p_sampling_from_probs,
+    top_p_renorm_prob,
+)
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
+# TODO: move this dict to another place
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class SampleOutput:
+    success: torch.Tensor
+    probs: torch.Tensor
+    batch_next_token_ids: torch.Tensor
+
+
+class Sampler(CustomOp):
+    def __init__(self):
+        super().__init__()
+
+    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        return logits
+
+    def _get_probs(
+        self,
+        logits: torch.Tensor,
+        sampling_info: SamplingBatchInfo,
+        is_torch_compile: bool = False,
+    ):
+        # Post process logits
+        logits = logits.contiguous()
+        logits.div_(sampling_info.temperatures)
+        if is_torch_compile:
+            # FIXME: Temporary workaround for unknown bugs in torch.compile
+            logits.add_(0)
+
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+
+        logits = self._apply_penalties(logits, sampling_info)
+
+        return torch.softmax(logits, dim=-1)
+
+    def forward_cuda(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)
+
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            if sampling_info.need_min_p_sampling:
+                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                batch_next_token_ids, success = min_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.min_ps
+                )
+            else:
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+                )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+            )
+
+        return SampleOutput(success, probs, batch_next_token_ids)
+
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )
+
+        return SampleOutput(success, probs, batch_next_token_ids)
+
+
+def top_k_top_p_min_p_sampling_from_probs_torch(
+    probs: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+    min_ps: torch.Tensor,
+):
+    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    min_p_thresholds = probs_sort[:, 0] * min_ps
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        # FIXME: torch.multiomial does not support num_samples = 1
+        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
+            :, :1
+        ]
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
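The new `sglang/srt/layers/sampler.py` keeps the flashinfer sampling kernels on the fast path and otherwise falls back to `top_k_top_p_min_p_sampling_from_probs_torch`. A minimal usage sketch of that pure-PyTorch fallback; it assumes sglang 0.2.14 is installed along with the flashinfer and vllm dependencies the module imports at load time, and the parameter values are made up:

```python
import torch

from sglang.srt.layers.sampler import top_k_top_p_min_p_sampling_from_probs_torch

batch_size, vocab_size = 2, 8
probs = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)

# One sampling configuration per request in the batch.
top_ks = torch.tensor([5, 3])       # keep at most k candidates
top_ps = torch.tensor([0.9, 0.8])   # nucleus probability mass
min_ps = torch.tensor([0.05, 0.1])  # fraction of the top probability

token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
    probs, top_ks, top_ps, min_ps
)
print(token_ids, success)  # one sampled token id per request, plus a success flag
```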
sglang/srt/managers/controller_multi.py:

@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers.
 import dataclasses
 import logging
 import multiprocessing
-import os
 from enum import Enum, auto
 
 import numpy as np
@@ -36,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -194,10 +193,7 @@ def start_controller_process(
 ):
     """Start a controller process."""
 
-
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)
 
     try:
         controller = ControllerMulti(server_args, port_args, model_overide_args)
@@ -212,6 +208,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
     finally:
-        for w in controller.workers:
-            os.kill(w.proc.pid, 9)
         kill_parent_process()
sglang/srt/managers/controller_single.py:

@@ -17,7 +17,6 @@ limitations under the License.
 
 import logging
 import multiprocessing
-import os
 from typing import List
 
 import zmq
@@ -28,7 +27,7 @@ from sglang.srt.managers.tp_worker import (
     launch_tp_servers,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import kill_parent_process
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -53,7 +52,7 @@ class ControllerSingle:
         self.dp_worker_id = dp_worker_id
         self.mp_queue = mp_queue
 
-        # Init communication
+        # Init inter-process communication
         context = zmq.Context(2)
 
         if not self.is_dp_worker:
@@ -134,11 +133,11 @@ def start_controller_process(
     queue: multiprocessing.connection.Connection = None,
 ):
     """Start a controller process."""
-
-
-
-
-    )
+    if is_data_parallel_worker:
+        logger_prefix = f" DP{dp_worker_id} TP0"
+    else:
+        logger_prefix = " TP0"
+    configure_logger(server_args, prefix=logger_prefix)
 
     if not is_data_parallel_worker:
         tp_size_local = server_args.tp_size // server_args.nnodes
@@ -167,6 +166,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        for t in controller.tp_procs:
-            os.kill(t.pid, 9)
         kill_parent_process()
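`ControllerSingle`, like the other manager processes, wires itself to its peers over ZeroMQ sockets, which is what the renamed `# Init inter-process communication` comment refers to. A self-contained PUSH/PULL sketch of that pattern (the port and payload here are made up; the real code takes its ports from `PortArgs`):

```python
import zmq

ctx = zmq.Context(2)

# One side binds a PULL socket and waits for work...
receiver = ctx.socket(zmq.PULL)
receiver.bind("tcp://127.0.0.1:5757")

# ...the other side connects a PUSH socket and sends pickled Python objects.
sender = ctx.socket(zmq.PUSH)
sender.connect("tcp://127.0.0.1:5757")

sender.send_pyobj({"rid": "req-0", "text": "hello"})
print(receiver.recv_pyobj())
```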
sglang/srt/managers/detokenizer_manager.py:

@@ -17,7 +17,6 @@ limitations under the License.
 
 import asyncio
 import dataclasses
-import inspect
 from typing import List
 
 import uvloop
@@ -29,6 +28,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 @dataclasses.dataclass
 class DecodeStatus:
+    """Store the status of incremental decoding."""
+
     vid: int
     decoded_text: str
     decode_ids: List[int]
@@ -47,11 +49,14 @@ class DecodeStatus:
 
 
 class DetokenizerManager:
+    """DetokenizerManager is a process that detokenizes the token ids."""
+
     def __init__(
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_router = context.socket(zmq.PULL)
         self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
@@ -71,10 +76,13 @@ class DetokenizerManager:
         self.decode_status = {}
 
     async def handle_loop(self):
+        """The event loop that handles requests"""
+
         while True:
-            recv_obj
+            recv_obj = await self.recv_from_router.recv_pyobj()
 
             if isinstance(recv_obj, BatchEmbeddingOut):
+                # If it is embedding model, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(
                     BatchEmbeddingOut(
                         rids=recv_obj.rids,
@@ -84,15 +92,18 @@ class DetokenizerManager:
                     )
                 )
                 continue
+            elif isinstance(recv_obj, UpdateWeightReqOutput):
+                # If it is a weight update request, no detokenization is needed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+            elif self.tokenizer is None:
+                # If the tokenizer is skipped, no detokenization is needed
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
 
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
-            if self.tokenizer is None:
-                # Send BatchTokenIDOut if no tokenizer init'ed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
@@ -126,8 +137,7 @@ class DetokenizerManager:
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
 
-            #
-            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            # Incremental decoding
             output_strs = []
             for i in range(bs):
                 s = self.decode_status[recv_obj.rids[i]]
@@ -144,6 +154,7 @@ class DetokenizerManager:
 
                 output_strs.append(s.decoded_text + new_text)
 
+                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
                     if pos != -1: