sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py
CHANGED
@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import print_warning_once
 
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
-
 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
     across all ranks.
@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
         self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(
-…
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None,
+        )
 
         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
         self.w13_weight = nn.Parameter(
-            torch.empty(
-…
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.hidden_size,
+                dtype=params_dtype,
+            )
+        )
         self.w2_weight = nn.Parameter(
-            torch.empty(
-…
-        set_weight_attrs(
-…
+            torch.empty(
+                self.num_total_experts,
+                self.hidden_size,
+                self.intermediate_size,
+                dtype=params_dtype,
+            )
+        )
+
+        set_weight_attrs(
+            self.w13_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
 
         # Used for fp8.
         self.w13_scale = None
@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):
 
         if self.use_fp8:
             # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
-…
+            self.w13_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
+            self.w2_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
 
             # If loading fp8 checkpoint, pass the weight loaders.
             # If loading an fp16 checkpoint, do not (we will quantize in
             # process_weights_after_loading()
             if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
-…
+                set_weight_attrs(
+                    self.w13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.w2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
 
             # ACT_SCALE (for fp8)
             if quant_config.activation_scheme == "static":
                 if not quant_config.is_checkpoint_fp8_serialized:
                     raise ValueError(
                         "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-…
-                set_weight_attrs(
-…
+                        "was not serialized fp8."
+                    )
+                self.a13_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                self.a2_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+
+                set_weight_attrs(
+                    self.a13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.a2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
+    ):
         tp_rank = get_tensor_model_parallel_rank()
         param_data = param.data
         shard_size = self.intermediate_size
@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
         if weight_name.endswith("w1.weight"):
             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-…
+            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                shard, :
+            ]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
         if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):
 
         # If checkpoint is fp16, quantize here.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-…
+            w13_weight = torch.empty_like(
+                self.w13_weight.data, dtype=torch.float8_e4m3fn
+            )
+            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert
-…
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert
-…
+                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                    self.w13_weight.data[expert, :, :]
+                )
+                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                    self.w2_weight.data[expert, :, :]
+                )
             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module):
             if self.a13_scale is None or self.a2_scale is None:
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
-                    "activation scales are None."
+                    "activation scales are None."
+                )
 
-            if
-                    or not all_close_1d(self.a2_scale)):
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
                 print_warning_once(
                     "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
+                    "Using the maximum across experts for each layer. "
+                )
 
-            self.a13_scale = nn.Parameter(self.a13_scale.max(),
-…
-            self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                         requires_grad=False)
+            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
-…
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+            use_fp8=self.use_fp8,
+            w1_scale=self.w13_scale,
+            w2_scale=self.w2_scale,
+            a1_scale=self.a13_scale,
+            a2_scale=self.a2_scale,
+        )
 
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_size)
 
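The forward hunk above funnels everything through a single fused_moe call. As a rough mental model only (this is not the Triton kernel in sglang.srt.layers.fused_moe, and it ignores the fp8 scale arguments, in-place execution, and tensor parallelism), the routing math is equivalent to the dense PyTorch sketch below; the function name moe_reference and its shape comments are illustrative assumptions taken from the weight shapes constructed earlier in this diff.

```python
# Illustrative sketch, not sglang's kernel: top-k routing with renormalized
# weights over per-expert SwiGLU MLPs, written as a plain expert loop.
import torch
import torch.nn.functional as F


def moe_reference(
    hidden_states: torch.Tensor,  # (num_tokens, hidden_size)
    w13: torch.Tensor,            # (num_experts, 2 * intermediate_size, hidden_size)
    w2: torch.Tensor,             # (num_experts, hidden_size, intermediate_size)
    router_logits: torch.Tensor,  # (num_tokens, num_experts)
    top_k: int,
) -> torch.Tensor:
    probs = F.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    # renormalize=True: the selected experts' weights sum to 1 per token.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(hidden_states)
    intermediate_size = w2.shape[-1]
    for expert in range(w13.shape[0]):
        token_ids, slot = (topk_ids == expert).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue
        x = hidden_states[token_ids]
        gate, up = (x @ w13[expert].t()).split(intermediate_size, dim=-1)
        y = (F.silu(gate) * up) @ w2[expert].t()  # SwiGLU, as in Mixtral
        out.index_add_(0, token_ids, y * topk_weights[token_ids, slot, None])
    return out
```

The real call additionally consumes the per-expert fp8 scales (w1_scale, w2_scale, a1_scale, a2_scale) when use_fp8 is set and writes into hidden_states in place.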
@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
-            quant_config=quant_config
+            quant_config=quant_config,
+        )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping =
-…
+        expert_params_mapping = (
+            [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                    f"experts.{expert_id}.{weight_name}.weight_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the weights for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                    f"experts.{expert_id}.{weight_name}.weight",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the activation scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                    f"experts.{expert_id}.{weight_name}.act_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+        )
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            for
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-                    weight_loader(
-…
-                        expert_id=expert_id)
+                    weight_loader(
+                        param, loaded_weight, weight_name, expert_id=expert_id
+                    )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
                     param = params_dict[name]
-                    weight_loader = getattr(
-…
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
                     weight_loader(param, loaded_weight)
 
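The expert_params_mapping added to load_weights above is three concatenated list comprehensions that translate per-expert checkpoint tensor names (experts.<i>.w1/w2/w3.*) into the stacked parameter names (w13_weight / w2_weight, w13_scale / w2_scale, a13_scale / a2_scale) plus the expert index handed to weight_loader. A minimal standalone sketch of the same construction, assuming a toy count of two local experts rather than the real config value:

```python
# Hypothetical expert count, for illustration only; Mixtral-8x7B uses 8.
num_local_experts = 2

expert_params_mapping = (
    [
        ("w13_scale" if w in ["w1", "w3"] else "w2_scale",
         f"experts.{e}.{w}.weight_scale", e)
        for e in range(num_local_experts)
        for w in ["w1", "w2", "w3"]
    ]
    + [
        ("w13_weight" if w in ["w1", "w3"] else "w2_weight",
         f"experts.{e}.{w}.weight", e)
        for e in range(num_local_experts)
        for w in ["w1", "w2", "w3"]
    ]
    + [
        ("a13_scale" if w in ["w1", "w3"] else "a2_scale",
         f"experts.{e}.{w}.act_scale", e)
        for e in range(num_local_experts)
        for w in ["w1", "w2", "w3"]
    ]
)

for param_name, weight_name, expert_id in expert_params_mapping:
    print(f"{weight_name} -> {param_name} (expert {expert_id})")
# experts.0.w1.weight_scale -> w13_scale (expert 0)
# experts.0.w2.weight_scale -> w2_scale (expert 0)
# experts.0.w3.weight_scale -> w13_scale (expert 0)
# ... followed by the .weight and .act_scale entries for every expert.
```

Each checkpoint tensor therefore matches exactly one (param_name, weight_name, expert_id) triple, which is what lets weight_loader dispatch on suffixes such as w1.weight or act_scale and fill the corresponding expert row.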
sglang/srt/models/mixtral_quant.py
CHANGED
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
sglang/srt/models/qwen.py
CHANGED
@@ -1,6 +1,6 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
-from typing import Any, Dict,
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
sglang/srt/models/qwen2.py
CHANGED
@@ -1,7 +1,7 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
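Both qwen.py and qwen2.py gain only Iterable in their typing imports. Presumably this mirrors the convention visible in mixtral.py above, where load_weights consumes weights as a stream of (name, tensor) pairs; the actual Qwen method bodies are not shown in this view, so the class and signature below are a hypothetical sketch of the style that import supports, not the real code.

```python
from typing import Iterable, Tuple

import torch
from torch import nn


class QWenSketch(nn.Module):
    """Hypothetical stand-in used only to illustrate the Iterable-based signature."""

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Accepting an Iterable lets callers stream checkpoint shards lazily
        # instead of materializing every tensor up front.
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict.get(name)
            if param is not None:
                param.data.copy_(loaded_weight)
```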