sglang 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +33 -26
- sglang/api.py +9 -1
- sglang/bench_latency.py +2 -2
- sglang/bench_serving.py +10 -1
- sglang/check_env.py +1 -1
- sglang/lang/backend/litellm.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +4 -4
- sglang/lang/interpreter.py +24 -9
- sglang/lang/ir.py +1 -1
- sglang/srt/constrained/__init__.py +15 -0
- sglang/srt/constrained/base_cache.py +15 -0
- sglang/srt/constrained/fsm_cache.py +36 -1
- sglang/srt/constrained/jump_forward.py +15 -0
- sglang/srt/conversation.py +26 -0
- sglang/srt/hf_transformers_utils.py +18 -1
- sglang/srt/layers/context_flashattention_nopad.py +15 -0
- sglang/srt/layers/extend_attention.py +15 -0
- sglang/srt/layers/fused_moe.py +15 -0
- sglang/srt/layers/linear.py +15 -0
- sglang/srt/layers/logits_processor.py +109 -72
- sglang/srt/layers/quantization/__init__.py +15 -0
- sglang/srt/layers/quantization/fp8.py +15 -0
- sglang/srt/layers/radix_attention.py +21 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
- sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
- sglang/srt/managers/detokenizer_manager.py +16 -1
- sglang/srt/managers/io_struct.py +38 -5
- sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
- sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +85 -25
- sglang/srt/managers/tokenizer_manager.py +99 -57
- sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +177 -81
- sglang/srt/mem_cache/flush_cache.py +33 -0
- sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
- sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
- sglang/srt/mm_utils.py +15 -0
- sglang/srt/model_config.py +20 -0
- sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +42 -18
- sglang/srt/{managers/controller → model_executor}/model_runner.py +51 -16
- sglang/srt/model_loader/model_loader.py +15 -0
- sglang/srt/model_loader/utils.py +16 -1
- sglang/srt/models/chatglm.py +16 -1
- sglang/srt/models/commandr.py +16 -1
- sglang/srt/models/dbrx.py +16 -1
- sglang/srt/models/deepseek.py +16 -1
- sglang/srt/models/deepseek_v2.py +532 -0
- sglang/srt/models/gemma.py +16 -1
- sglang/srt/models/gemma2.py +16 -1
- sglang/srt/models/gpt_bigcode.py +16 -1
- sglang/srt/models/grok.py +16 -1
- sglang/srt/models/internlm2.py +16 -1
- sglang/srt/models/llama2.py +16 -1
- sglang/srt/models/llama_classification.py +19 -4
- sglang/srt/models/llava.py +17 -2
- sglang/srt/models/llavavid.py +17 -2
- sglang/srt/models/minicpm.py +16 -1
- sglang/srt/models/mistral.py +15 -0
- sglang/srt/models/mixtral.py +16 -1
- sglang/srt/models/mixtral_quant.py +16 -1
- sglang/srt/models/qwen.py +16 -1
- sglang/srt/models/qwen2.py +16 -1
- sglang/srt/models/qwen2_moe.py +16 -1
- sglang/srt/models/stablelm.py +16 -1
- sglang/srt/models/yivl.py +15 -0
- sglang/srt/openai_api/adapter.py +545 -160
- sglang/srt/openai_api/protocol.py +65 -1
- sglang/srt/sampling_params.py +20 -4
- sglang/srt/server.py +90 -37
- sglang/srt/server_args.py +76 -17
- sglang/srt/utils.py +15 -0
- sglang/test/test_programs.py +5 -1
- sglang/utils.py +22 -0
- sglang/version.py +1 -1
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/METADATA +40 -12
- sglang-0.2.7.dist-info/RECORD +93 -0
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
- sglang/srt/flush_cache.py +0 -18
- sglang-0.2.5.dist-info/RECORD +0 -92
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
- {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py
CHANGED
@@ -1,7 +1,22 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Logits processing."""
 
 import dataclasses
-from typing import List, Union
+from typing import List, Optional, Union
 
 import torch
 from torch import nn
@@ -10,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
-from sglang.srt.
+from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
@@ -22,23 +37,23 @@ class LogitProcessorOutput:
 
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor
-    # The logprobs of
-
+    # The logprobs of input tokens. shape: [#token, vocab_size]
+    input_token_logprobs: torch.Tensor
 
-    # The logprob and id of the top-k tokens in
-
-    # The logprob and id of the top-k tokens in
-
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    input_top_logprobs: List
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    output_top_logprobs: List
 
 
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    return_logprob: bool
+    return_logprob: bool = False
 
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    top_logprobs_nums: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    top_logprobs_nums: Optional[List[int]] = None
 
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
@@ -58,20 +73,16 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     def _get_normalized_prompt_logprobs(
-        self,
+        self, input_token_logprobs, logits_metadata: LogitsMetadata
     ):
-        logprobs_cumsum = torch.cumsum(
-            prefill_token_logprobs, dim=0, dtype=torch.float32
-        )
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
 
         start = logits_metadata.extend_start_loc.clone()
         end = start + logits_metadata.extend_seq_lens - 2
-        start.clamp_(min=0, max=
-        end.clamp_(min=0, max=
+        start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
-            logprobs_cumsum[end]
-            - logprobs_cumsum[start]
-            + prefill_token_logprobs[start]
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
             (logits_metadata.extend_seq_lens - 1).clamp(min=1)
@@ -79,37 +90,51 @@ class LogitsProcessor(nn.Module):
 
         return normalized_prompt_logprobs
 
-
-
+    @staticmethod
+    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
         if logits_metadata.forward_mode == ForwardMode.DECODE:
-
-
-
-
-
-
-
-            return None,
+            output_top_logprobs = []
+            max_k = max(logits_metadata.top_logprobs_nums)
+            ret = all_logprobs.topk(max_k, dim=1)
+            values = ret.values.tolist()
+            indices = ret.indices.tolist()
+            for i, k in enumerate(logits_metadata.top_logprobs_nums):
+                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
+            return None, output_top_logprobs
         else:
-
+            # TODO: vectorize the code below
+            input_top_logprobs, output_top_logprobs = [], []
             pt = 0
             extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
+
+            max_k = max(logits_metadata.top_logprobs_nums)
+            ret = all_logprobs.topk(max_k, dim=1)
+            values = ret.values.tolist()
+            indices = ret.indices.tolist()
+
            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
-
-
+                    input_top_logprobs.append([])
+                    output_top_logprobs.append([])
                     continue
                 k = logits_metadata.top_logprobs_nums[i]
-
-
-
-
-
+                input_top_logprobs.append(
+                    [
+                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
+                        for j in range(extend_seq_len - 1)
+                    ]
+                )
+                output_top_logprobs.append(
+                    list(
+                        zip(
+                            values[pt + extend_seq_len - 1][:k],
+                            indices[pt + extend_seq_len - 1][:k],
+                        )
+                    )
                 )
-                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
                 pt += extend_seq_len
 
-            return
+            return input_top_logprobs, output_top_logprobs
 
     def forward(
         self,
@@ -136,7 +161,7 @@ class LogitsProcessor(nn.Module):
         last_logits = torch.matmul(last_hidden, weight.T)
         if self.tp_size > 1:
             last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size]
+        last_logits = last_logits[:, : self.config.vocab_size].float()
 
         if hasattr(self.config, "final_logit_softcapping"):
             last_logits /= self.config.final_logit_softcapping
@@ -149,63 +174,75 @@ class LogitsProcessor(nn.Module):
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
-
-
-
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if logits_metadata.forward_mode == ForwardMode.DECODE:
-
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size]
+                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
 
-
-
-
-
-            # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
-            if return_top_logprob:
-                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, logits_metadata
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
                 )
-
-
+                if return_top_logprob:
+                    output_top_logprobs = self.get_top_logprobs(
+                        last_logprobs, logits_metadata
+                    )[1]
+                else:
+                    output_top_logprobs = None
 
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
-                    next_token_logprobs=
+                    next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
-
-
-
+                    input_token_logprobs=None,
+                    input_top_logprobs=None,
+                    output_top_logprobs=output_top_logprobs,
                 )
             else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size].float()
+
+                all_logprobs = all_logits
+                del all_logits
+                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
+                        all_logprobs, logits_metadata
+                    )
+                else:
+                    input_top_logprobs = output_top_logprobs = None
+
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-
+                input_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
 
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-
+                    input_token_logprobs, logits_metadata
                 )
 
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=normalized_prompt_logprobs,
-
-
-
+                    input_token_logprobs=input_token_logprobs,
+                    input_top_logprobs=input_top_logprobs,
+                    output_top_logprobs=output_top_logprobs,
                 )
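The logits-processor rewrite above renames the `prefill_*` fields to `input_*`/`output_*`, upcasts logits to float32, and batches the top-k work into a single `topk` call at the batch-wide maximum k. Below is a minimal standalone sketch of the two recoverable techniques; it is not the sglang API, and the tensor shapes and demo inputs are invented for illustration.

```python
# Sketch of the two logprob computations visible in the new logits_processor.py.
# Assumes all_logprobs is already log-softmaxed; shapes are invented for the demo.
import torch


def get_output_top_logprobs(all_logprobs, top_logprobs_nums):
    """One batched topk at max_k, then per-request truncation to each request's k."""
    max_k = max(top_logprobs_nums)
    ret = all_logprobs.topk(max_k, dim=1)  # a single kernel call for the whole batch
    values, indices = ret.values.tolist(), ret.indices.tolist()
    return [
        list(zip(values[i][:k], indices[i][:k]))
        for i, k in enumerate(top_logprobs_nums)
    ]


def normalized_prompt_logprobs(input_token_logprobs, extend_start_loc, extend_seq_lens):
    """Per-sequence mean prompt logprob via one cumsum and two gathers, no Python loop."""
    cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
    start = extend_start_loc.clone()
    end = start + extend_seq_lens - 2
    start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
    end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
    # Sum of logprobs over [start, end] per sequence, recovered from the cumsum.
    sum_logp = cumsum[end] - cumsum[start] + input_token_logprobs[start]
    return sum_logp / (extend_seq_lens - 1).clamp(min=1)


logprobs = torch.log_softmax(torch.randn(2, 1000), dim=-1)
print(get_output_top_logprobs(logprobs, [2, 5]))

token_logprobs = torch.log(torch.rand(6))  # 6 prompt tokens from 2 sequences
print(normalized_prompt_logprobs(token_logprobs, torch.tensor([0, 3]), torch.tensor([3, 3])))
```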
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 # temporarily adapted from vLLM
 # FIXME: in progress of refactoring the model loader
 
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 # adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/quantization/fp8.py
 # FIXME refactor in progress
 from typing import Any, Dict, List, Optional, Union
sglang/srt/layers/radix_attention.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """Radix attention."""
 
 import torch
@@ -7,8 +22,11 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.
-
+from sglang.srt.model_executor.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 
 
 class RadixAttention(nn.Module):
@@ -85,7 +103,7 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        if not input_metadata.
+        if not input_metadata.flashinfer_use_ragged:
             self.store_kv_cache(k, v, input_metadata)
 
         o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
sglang/srt/layers/token_attention.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -5,7 +20,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
sglang/srt/managers/{controller/manager_multi.py → controller_multi.py}
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
@@ -12,7 +27,7 @@ from enum import Enum, auto
 import numpy as np
 import zmq
 
-from sglang.srt.managers.
+from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
@@ -24,7 +39,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger(
+logger = logging.getLogger(__name__)
 
 
 class LoadBalanceMethod(Enum):
sglang/srt/managers/{controller/manager_single.py → controller_single.py}
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """A controller that manages a group of tensor parallel workers."""
 
 import logging
@@ -7,7 +22,7 @@ from typing import List
 
 import zmq
 
-from sglang.srt.managers.
+from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
     launch_tp_servers,
@@ -16,7 +31,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-logger = logging.getLogger(
+logger = logging.getLogger(__name__)
 
 
 class ControllerSingle:
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
@@ -10,8 +25,8 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
 
sglang/srt/managers/io_struct.py
CHANGED
@@ -1,3 +1,18 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 """
 The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
@@ -7,7 +22,7 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-from sglang.srt.managers.
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
 
@@ -20,7 +35,7 @@ class GenerateReqInput:
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
-    # The sampling_params.
+    # The sampling_params. See descriptions below.
     sampling_params: Union[List[Dict], Dict] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
@@ -30,7 +45,7 @@ class GenerateReqInput:
     logprob_start_len: Optional[Union[List[int], int]] = None
     # The number of top logprobs to return.
     top_logprobs_num: Optional[Union[List[int], int]] = None
-    # Whether to detokenize tokens in logprobs.
+    # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
@@ -64,8 +79,26 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-
-
+            parallel_sample_num_list = []
+            if isinstance(self.sampling_params, dict):
+                parallel_sample_num = self.sampling_params.get("n", 1)
+            elif isinstance(self.sampling_params, list):
+                for sp in self.sampling_params:
+                    parallel_sample_num = sp.get("n", 1)
+                    parallel_sample_num_list.append(parallel_sample_num)
+                parallel_sample_num = max(parallel_sample_num_list)
+                all_equal = all(
+                    element == parallel_sample_num
+                    for element in parallel_sample_num_list
+                )
+                if parallel_sample_num > 1 and (not all_equal):
+                    ## TODO cope with the case that the parallel_sample_num is different for different samples
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
+            else:
+                parallel_sample_num = 1
+            self.parallel_sample_num = parallel_sample_num
 
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
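The new `GenerateReqInput` post-init logic above derives a batch-wide parallel sample count from the OpenAI-style `n` in `sampling_params` and rejects batches that mix different values. A distilled sketch of that rule follows; the helper name is hypothetical.

```python
# Hypothetical helper distilling the new "n" validation in GenerateReqInput:
# a single dict is one request; a list is a batch, which must agree on n > 1.
def resolve_parallel_sample_num(sampling_params):
    if isinstance(sampling_params, dict):
        return sampling_params.get("n", 1)
    nums = [sp.get("n", 1) for sp in sampling_params]
    n = max(nums)
    if n > 1 and any(num != n for num in nums):
        raise ValueError(
            "The parallel_sample_num should be the same for all samples in sample params."
        )
    return n


print(resolve_parallel_sample_num({"n": 4}))              # 4
print(resolve_parallel_sample_num([{"n": 2}, {"n": 2}]))  # 2
```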
sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py}
CHANGED
@@ -1,46 +1,61 @@
-"""
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Request policy scheduler"""
 
 import random
 from collections import defaultdict
 
 
-class
+class PolicyScheduler:
     def __init__(
         self,
-
+        policy,
         max_running_seqs,
         max_prefill_num_tokens,
         max_total_num_tokens,
         tree_cache,
     ):
-        if tree_cache.disable and
+        if tree_cache.disable and policy == "lpm":
             # LMP is meaningless when the tree cache is disabled.
-
+            policy = "fcfs"
 
-        self.
+        self.policy = policy
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self,
-        if self.
+    def get_priority_queue(self, waiting_queue):
+        if self.policy == "lpm":
             # longest prefix match
-
-            return
-        elif self.
+            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            return waiting_queue
+        elif self.policy == "fcfs":
             # first come first serve
-            return
-        elif self.
+            return waiting_queue
+        elif self.policy == "lof":
             # longest output first
-
-            return
-        elif self.
-            random.shuffle(
-            return
-        elif self.
+            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return waiting_queue
+        elif self.policy == "random":
+            random.shuffle(waiting_queue)
+            return waiting_queue
+        elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
-            for req in
+            for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
             node_to_weight = defaultdict(int)
@@ -52,10 +67,10 @@ class ScheduleHeuristic:
             self.get_dfs_priority(
                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
            )
-            assert len(q) == len(
+            assert len(q) == len(waiting_queue)
             return q
         else:
-            raise ValueError(f"Unknown
+            raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
     def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():