sglang 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- sglang/lang/backend/runtime_endpoint.py +4 -4
- sglang/lang/interpreter.py +4 -4
- sglang/srt/constrained/fsm_cache.py +21 -1
- sglang/srt/hf_transformers_utils.py +3 -1
- sglang/srt/layers/logits_processor.py +70 -61
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/token_attention.py +1 -1
- sglang/srt/managers/controller/cuda_graph_runner.py +26 -17
- sglang/srt/managers/controller/infer_batch.py +54 -13
- sglang/srt/managers/controller/model_runner.py +22 -7
- sglang/srt/managers/controller/tp_worker.py +47 -41
- sglang/srt/managers/io_struct.py +2 -2
- sglang/srt/managers/tokenizer_manager.py +62 -43
- sglang/srt/model_config.py +5 -0
- sglang/srt/models/deepseek_v2.py +517 -0
- sglang/srt/models/llama_classification.py +3 -3
- sglang/srt/openai_api/adapter.py +33 -33
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/sampling_params.py +5 -4
- sglang/srt/server.py +2 -15
- sglang/srt/server_args.py +28 -7
- sglang/test/test_programs.py +5 -1
- sglang/version.py +1 -1
- {sglang-0.2.5.dist-info → sglang-0.2.6.dist-info}/METADATA +9 -7
- {sglang-0.2.5.dist-info → sglang-0.2.6.dist-info}/RECORD +28 -27
- {sglang-0.2.5.dist-info → sglang-0.2.6.dist-info}/LICENSE +0 -0
- {sglang-0.2.5.dist-info → sglang-0.2.6.dist-info}/WHEEL +0 -0
- {sglang-0.2.5.dist-info → sglang-0.2.6.dist-info}/top_level.txt +0 -0
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -253,14 +253,14 @@ class RuntimeEndpoint(BaseBackend):
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
         decision = choices[np.argmax(normalized_prompt_logprobs)]
-        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
-        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
+        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
+        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

         return (
             decision,
             normalized_prompt_logprobs,
-            prefill_token_logprobs,
-            decode_token_logprobs,
+            input_token_logprobs,
+            output_token_logprobs,
         )

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
sglang/lang/interpreter.py
CHANGED
@@ -541,16 +541,16 @@ class StreamExecutor:
         (
             decision,
             normalized_prompt_logprobs,
-            prefill_token_logprobs,
-            decode_token_logprobs,
+            input_token_logprobs,
+            output_token_logprobs,
         ) = self.backend.select(self, expr.choices, expr.temperature)
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
             self.meta_info[name] = {
                 "normalized_prompt_logprobs": normalized_prompt_logprobs,
-                "prefill_token_logprobs": prefill_token_logprobs,
-                "decode_token_logprobs": decode_token_logprobs,
+                "input_token_logprobs": input_token_logprobs,
+                "output_token_logprobs": output_token_logprobs,
             }
             self.variable_event[name].set()
             self.text_ += decision
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -21,7 +21,27 @@ class FSMCache(BaseCache):
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path, **tokenizer_args_dict
            )
-            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            try:
+                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            except AttributeError:
+                # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+                origin_pad_token_id = tokenizer.pad_token_id
+
+                def fset(self, value):
+                    self._value = value
+
+                type(tokenizer).pad_token_id = property(
+                    fget=type(tokenizer).pad_token_id.fget, fset=fset
+                )
+                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+                self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+                self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+                self.outlines_tokenizer.pad_token = (
+                    self.outlines_tokenizer.tokenizer.pad_token
+                )
+                self.outlines_tokenizer.vocabulary = (
+                    self.outlines_tokenizer.tokenizer.get_vocab()
+                )
        else:
            self.outlines_tokenizer = TransformerTokenizer(
                tokenizer_path, **tokenizer_args_dict
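The fsm_cache.py workaround above hinges on one Python detail: the affected tokenizers expose pad_token_id as a read-only class property, so assigning to it (presumably inside outlines' TransformerTokenizer) raises AttributeError. Re-installing the property on the tokenizer's class with a setter attached makes the assignment legal. A minimal, self-contained sketch of the same technique; the Tok class and _fset helper below are illustrative only, not sglang code:

    # Illustrative sketch: give a read-only class property a setter.
    class Tok:
        @property
        def pad_token_id(self):
            return getattr(self, "_pad_token_id", 0)

    def _fset(self, value):
        self._pad_token_id = value

    tok = Tok()
    # tok.pad_token_id = 5   # would raise AttributeError: property has no setter
    Tok.pad_token_id = property(fget=Tok.pad_token_id.fget, fset=_fset)
    tok.pad_token_id = 5     # now allowed
    print(tok.pad_token_id)  # 5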
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -73,7 +73,9 @@ def get_context_length(config):
     rope_scaling = getattr(config, "rope_scaling", None)
     if rope_scaling:
         rope_scaling_factor = config.rope_scaling["factor"]
-        if
+        if "original_max_position_embeddings" in rope_scaling:
+            rope_scaling_factor = 1
+        if config.rope_scaling.get("rope_type", None) == "llama3":
             rope_scaling_factor = 1
     else:
         rope_scaling_factor = 1
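The hf_transformers_utils.py change above keeps rope-scaled models from over-reporting their context window: when the config already carries original_max_position_embeddings, or uses llama3-style rope scaling, the scaling factor is forced back to 1. A hedged sketch of the resulting arithmetic, assuming the unchanged remainder of get_context_length multiplies the factor by max_position_embeddings; the helper name and numbers are made up for illustration:

    def effective_context_len(max_position_embeddings, rope_scaling):
        # Mirror of the patched logic, not the sglang function itself.
        factor = 1
        if rope_scaling:
            factor = rope_scaling["factor"]
            if "original_max_position_embeddings" in rope_scaling:
                factor = 1
            if rope_scaling.get("rope_type") == "llama3":
                factor = 1
        return int(factor * max_position_embeddings)

    print(effective_context_len(8192, {"factor": 8.0, "rope_type": "llama3"}))  # 8192
    print(effective_context_len(4096, {"factor": 4.0}))                         # 16384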
sglang/srt/layers/logits_processor.py
CHANGED
@@ -1,7 +1,7 @@
 """Logits processing."""

 import dataclasses
-from typing import List, Union
+from typing import List, Optional, Union

 import torch
 from torch import nn
@@ -22,23 +22,23 @@ class LogitProcessorOutput:

     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor
-    # The logprobs of prefill tokens. shape: [#token, vocab_size]
-    prefill_token_logprobs: torch.Tensor
+    # The logprobs of input tokens. shape: [#token, vocab_size]
+    input_token_logprobs: torch.Tensor

-    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    prefill_top_logprobs: List
-    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    decode_top_logprobs: List
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    input_top_logprobs: List
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    output_top_logprobs: List


 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    return_logprob: bool
+    return_logprob: bool = False

-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    top_logprobs_nums: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    top_logprobs_nums: Optional[List[int]] = None

     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
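With return_logprob now defaulting to False and the extend_* fields made Optional, a decode-time LogitsMetadata can be built from just two fields. This is exactly how CudaGraphRunner.replay uses it later in this diff; the top-k counts below are example values:

    from sglang.srt.layers.logits_processor import LogitsMetadata
    from sglang.srt.managers.controller.infer_batch import ForwardMode

    logits_metadata = LogitsMetadata(
        forward_mode=ForwardMode.DECODE,
        top_logprobs_nums=[5, 0, 2],  # example per-request top-k counts
    )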
@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
+        self, input_token_logprobs, logits_metadata: LogitsMetadata
     ):
-        logprobs_cumsum = torch.cumsum(
-            prefill_token_logprobs, dim=0, dtype=torch.float32
-        )
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)

         start = logits_metadata.extend_start_loc.clone()
         end = start + logits_metadata.extend_seq_lens - 2
-        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
-        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
-            logprobs_cumsum[end]
-            - logprobs_cumsum[start]
-            + prefill_token_logprobs[start]
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
             (logits_metadata.extend_seq_lens - 1).clamp(min=1)
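The cumulative-sum trick in _get_normalized_prompt_logprobs sums each packed prompt's token logprobs (excluding the final padded position) in one vectorized pass: cumsum[end] - cumsum[start] + x[start] is the inclusive sum of x[start..end]. A small numeric sketch with made-up values, two prompts of extend lengths 3 and 4 packed back to back:

    import torch

    x = torch.tensor([-1.0, -2.0, -9.0, -0.5, -0.5, -1.0, -9.0])  # toy per-token logprobs
    extend_start_loc = torch.tensor([0, 3])
    extend_seq_lens = torch.tensor([3, 4])

    cums = torch.cumsum(x, dim=0)
    start = extend_start_loc.clamp(min=0, max=x.shape[0] - 1)
    end = (start + extend_seq_lens - 2).clamp(min=0, max=x.shape[0] - 1)
    sum_logp = cums[end] - cums[start] + x[start]          # inclusive sum of x[start..end]
    normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
    print(normalized)  # tensor([-1.5000, -0.6667])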
@@ -79,37 +75,38 @@ class LogitsProcessor(nn.Module):

         return normalized_prompt_logprobs

-    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+    @staticmethod
+    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
         if logits_metadata.forward_mode == ForwardMode.DECODE:
-            decode_top_logprobs = []
+            output_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
                 k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
                 v_cpu = t.values.tolist()
                 p_cpu = t.indices.tolist()
-                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
-            return None, decode_top_logprobs
+                output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, output_top_logprobs
         else:
-            prefill_top_logprobs, decode_top_logprobs = [], []
+            input_top_logprobs, output_top_logprobs = [], []
             pt = 0
             extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
-                    prefill_top_logprobs.append([])
-                    decode_top_logprobs.append([])
+                    input_top_logprobs.append([])
+                    output_top_logprobs.append([])
                     continue
                 k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
-                prefill_top_logprobs.append(
+                input_top_logprobs.append(
                     [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
-                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
                 pt += extend_seq_len

-            return prefill_top_logprobs, decode_top_logprobs
+            return input_top_logprobs, output_top_logprobs

     def forward(
         self,
@@ -136,7 +133,7 @@ class LogitsProcessor(nn.Module):
         last_logits = torch.matmul(last_hidden, weight.T)
         if self.tp_size > 1:
             last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size]
+        last_logits = last_logits[:, : self.config.vocab_size].float()

         if hasattr(self.config, "final_logit_softcapping"):
             last_logits /= self.config.final_logit_softcapping
@@ -149,63 +146,75 @@ class LogitsProcessor(nn.Module):
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=None,
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if logits_metadata.forward_mode == ForwardMode.DECODE:
-
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size]
-
-            all_logprobs = all_logits.float()
-            del all_logits
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

-
-
-
-                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, logits_metadata
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
                 )
-
-
+                if return_top_logprob:
+                    output_top_logprobs = self.get_top_logprobs(
+                        last_logprobs, logits_metadata
+                    )[1]
+                else:
+                    output_top_logprobs = None

-            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
-                    next_token_logprobs=
+                    next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
-                    prefill_token_logprobs=None,
-                    prefill_top_logprobs=None,
-                    decode_top_logprobs=decode_top_logprobs,
+                    input_token_logprobs=None,
+                    input_top_logprobs=None,
+                    output_top_logprobs=output_top_logprobs,
                 )
             else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size].float()
+
+                all_logprobs = all_logits
+                del all_logits
+                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
+                        all_logprobs, logits_metadata
+                    )
+                else:
+                    input_top_logprobs = output_top_logprobs = None
+
                 last_logprobs = all_logprobs[last_index]

                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-                prefill_token_logprobs = all_logprobs[
+                input_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]

                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    prefill_token_logprobs, logits_metadata
+                    input_token_logprobs, logits_metadata
                 )

                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=normalized_prompt_logprobs,
-                    prefill_token_logprobs=prefill_token_logprobs,
-                    prefill_top_logprobs=prefill_top_logprobs,
-                    decode_top_logprobs=decode_top_logprobs,
+                    input_token_logprobs=input_token_logprobs,
+                    input_top_logprobs=input_top_logprobs,
+                    output_top_logprobs=output_top_logprobs,
                 )


sglang/srt/layers/radix_attention.py
CHANGED
@@ -7,8 +7,11 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
-
+from sglang.srt.managers.controller.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)


 class RadixAttention(nn.Module):
sglang/srt/layers/token_attention.py
CHANGED
@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
sglang/srt/managers/controller/cuda_graph_runner.py
CHANGED
@@ -9,7 +9,11 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
+    LogitsMetadata,
+    LogitsProcessor,
+)
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
     ForwardMode,
@@ -185,7 +189,6 @@ class CudaGraphRunner:

     def replay(self, batch: Batch):
         assert batch.out_cache_loc is not None
-        assert not batch.return_logprob
         raw_bs = len(batch.reqs)

         # Pad
@@ -218,23 +221,29 @@ class CudaGraphRunner:
         output = self.output_buffers[bs]

         # Unpad
-        if bs == raw_bs:
-            return output
-        else:
+        if bs != raw_bs:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=(
-                    output.next_token_logprobs[:raw_bs]
-                    if output.next_token_logprobs is not None
-                    else None
-                ),
+                next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=(
-                    output.decode_top_logprobs[:raw_bs]
-                    if output.decode_top_logprobs is not None
-                    else None
-                ),
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
             )
+
+        # Extract logprobs
+        if batch.return_logprob:
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
+            )
+            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            if return_top_logprob:
+                logits_metadata = LogitsMetadata(
+                    forward_mode=ForwardMode.DECODE,
+                    top_logprobs_nums=batch.top_logprobs_nums,
+                )
+                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
+                )[1]
+
         return output
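Because the captured CUDA graph only produces next-token logits, replay() now derives any requested logprobs outside the graph: log_softmax over the returned logits, then the same LogitsProcessor.get_top_logprobs helper for the per-request top-k lists. A toy illustration of that post-processing step (tensor values invented):

    import torch

    next_token_logits = torch.tensor([[2.0, 1.0, 0.0], [0.0, 0.0, 3.0]])
    next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
    top = next_token_logprobs.topk(2, dim=-1)
    # One list of (logprob, token_id) pairs per request, as in output_top_logprobs
    print([list(zip(v, i)) for v, i in zip(top.values.tolist(), top.indices.tolist())])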
sglang/srt/managers/controller/infer_batch.py
CHANGED
@@ -17,6 +17,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

+# Put some global args for easy access
+global_server_args_dict = {
+    "disable_flashinfer": False,
+    "disable_flashinfer_sampling": False,
+    "attention_reduce_in_fp32": False,
+}
+

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
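Moving global_server_args_dict into infer_batch.py gives it an import-safe home: ModelRunner copies the relevant ServerArgs flags into it at startup (see the model_runner.py hunk later in this diff), and kernels such as token_attention read the flags through the same dict at import time. Paraphrased from this diff, with example flag values:

    from sglang.srt.managers.controller.infer_batch import global_server_args_dict

    # ModelRunner.__init__ does roughly this, sourcing the values from ServerArgs:
    global_server_args_dict.update(
        {
            "disable_flashinfer": False,
            "disable_flashinfer_sampling": False,
            "attention_reduce_in_fp32": True,  # example value
        }
    )

    # Consumers read the shared dict, e.g. token_attention.py:
    if global_server_args_dict.get("attention_reduce_in_fp32", False):
        print("attention reduction will run in fp32")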
@@ -124,10 +131,10 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
-        self.prefill_token_logprobs = None
-        self.prefill_top_logprobs = None
-        self.decode_token_logprobs = []
-        self.decode_top_logprobs = []
+        self.input_token_logprobs = None
+        self.input_top_logprobs = None
+        self.output_token_logprobs = []
+        self.output_top_logprobs = []
         # The tokens is prefilled but need to be considered as decode tokens
         # and should be updated for the decode logprobs
         self.last_update_decode_tokens = 0
@@ -244,8 +251,8 @@ class Req:
                 k = k + 1
             else:
                 break
-        self.decode_token_logprobs = self.decode_token_logprobs[:k]
-        self.decode_top_logprobs = self.decode_top_logprobs[:k]
+        self.output_token_logprobs = self.output_token_logprobs[:k]
+        self.output_top_logprobs = self.output_top_logprobs[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k

@@ -376,7 +383,7 @@ class Batch:
                 logit_bias = torch.zeros(
                     (bs, vocab_size), dtype=torch.float32, device=device
                 )
-            logit_bias[i] = int_token_logit_bias
+            logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias

         # Set fields
         self.input_ids = torch.tensor(
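The one-line logit_bias fix above matters when a request's bias vector is shorter than the allocated vocab dimension (presumably because the model's padded vocab_size can exceed the tokenizer's vocabulary); writing into a leading slice avoids the shape-mismatch error that full-row assignment would raise. Toy sizes:

    import torch

    vocab_size, bias_len = 8, 6
    logit_bias = torch.zeros((2, vocab_size), dtype=torch.float32)
    int_token_logit_bias = torch.full((bias_len,), -1e5)
    logit_bias[0][: len(int_token_logit_bias)] = int_token_logit_bias  # works
    # logit_bias[0] = int_token_logit_bias  # RuntimeError: size mismatch (6 vs 8)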
@@ -687,13 +694,21 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)

-
-
-
-
-
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
+                probs, self.top_ks, self.top_ps
+            )

-        if torch.
+        if not torch.all(success):
             warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
             argmax_ids = torch.argmax(probs, dim=-1)
@@ -933,3 +948,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
     max_extend_len = int(torch.max(extend_seq_lens))

     return max_seq_len, max_extend_len, start_loc, prefix_lens
+
+
+def top_k_top_p_sampling_from_probs_torch(
+    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+):
+    """A top-k and top-k sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    except RuntimeError:
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
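The new top_k_top_p_sampling_from_probs_torch above is the pure-PyTorch fallback used when flashinfer sampling is disabled: it masks out tokens beyond each request's top-p mass or top-k rank and then samples from what is left with torch.multinomial. A small usage sketch with invented probabilities (it assumes an environment where sglang 0.2.6 is installed):

    import torch

    from sglang.srt.managers.controller.infer_batch import (
        top_k_top_p_sampling_from_probs_torch,
    )

    probs = torch.tensor(
        [[0.60, 0.30, 0.08, 0.02],   # request 0: peaked distribution
         [0.25, 0.25, 0.25, 0.25]]   # request 1: uniform distribution
    )
    top_ks = torch.tensor([2, 4])
    top_ps = torch.tensor([0.9, 1.0])
    ids, success = top_k_top_p_sampling_from_probs_torch(probs, top_ks, top_ps)
    print(ids, success)  # e.g. tensor([1, 2]) tensor([True, True])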
sglang/srt/managers/controller/model_runner.py
CHANGED
@@ -25,7 +25,12 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -60,7 +65,13 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(self.model_config)
-
+        global_server_args_dict.update(
+            {
+                "disable_flashinfer": server_args.disable_flashinfer,
+                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+            }
+        )

         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
@@ -95,7 +106,7 @@ class ModelRunner:

         # Load the model and create memory pool
         self.load_model()
-        self.init_memory_pool(total_gpu_memory)
+        self.init_memory_pool(total_gpu_memory, server_args.max_num_reqs)
         self.init_cublas()
         self.init_flash_infer()

@@ -108,6 +119,7 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+        monkey_patch_vllm_dummy_weight_loader()
         device_config = DeviceConfig()
         load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
@@ -176,7 +188,7 @@ class ModelRunner:
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

-    def init_memory_pool(self, total_gpu_memory):
+    def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

         if self.max_total_num_tokens <= 0:
@@ -184,11 +196,14 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

-        self.req_to_token_pool = ReqToTokenPool(
-            max(
+        if max_num_reqs is None:
+            max_num_reqs = max(
                 int(self.max_total_num_tokens / self.model_config.context_len * 512),
                 2048,
-            )
+            )
+
+        self.req_to_token_pool = ReqToTokenPool(
+            max_num_reqs,
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
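When server_args.max_num_reqs is not supplied (None), init_memory_pool keeps the previous heuristic as the default: the request-slot count scales with how many tokens fit in the KV cache relative to the context length, with a floor of 2048. With assumed numbers, purely for illustration:

    max_total_num_tokens = 1_000_000   # assumed KV-cache capacity in tokens
    context_len = 8192                 # assumed model context length
    max_num_reqs = max(int(max_total_num_tokens / context_len * 512), 2048)
    print(max_num_reqs)  # 62500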