sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/linear.py
CHANGED
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
@@ -628,8 +629,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert loaded_shard_id < len(self.output_sizes)
 
         tp_size = get_tensor_model_parallel_world_size()
-        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-        shard_size = self.output_sizes[loaded_shard_id] // tp_size
+
+        if isinstance(param, BlockQuantScaleParameter):
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            shard_offset = (
+                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
+            ) // tp_size
+            shard_size = (
+                (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
+            )
+        else:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // tp_size
 
         param.load_merged_column_weight(
             loaded_weight=loaded_weight,
@@ -795,6 +807,12 @@ class QKVParallelLinear(ColumnParallelLinear):
         shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
         shard_size = self._get_shard_size_mapping(loaded_shard_id)
 
+        if isinstance(param, BlockQuantScaleParameter):
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            shard_offset = (shard_offset + block_n - 1) // block_n
+            shard_size = (shard_size + block_n - 1) // block_n
+
         param.load_qkv_weight(
             loaded_weight=loaded_weight,
             num_heads=self.num_kv_head_replicas,
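Note: the new `BlockQuantScaleParameter` branches convert shard offsets and sizes from weight-row units into scale-block units, because block-wise FP8 stores one scale value per `block_n` rows. A standalone sketch of that arithmetic with hypothetical sizes (not sglang API calls):

# Minimal sketch of the block-quantized scale sharding arithmetic above.
# All numbers are hypothetical; one scale covers block_n consecutive weight rows.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

output_sizes = [4096, 4096]   # e.g. two projections fused into one matrix
block_n = 128                 # rows covered by each quantization scale
tp_size = 2                   # tensor-parallel degree

for shard_id in range(len(output_sizes)):
    # Offsets/sizes expressed in scale-block units rather than weight-row units.
    scale_offset = ceil_div(sum(output_sizes[:shard_id]), block_n) // tp_size
    scale_size = ceil_div(output_sizes[shard_id], block_n) // tp_size
    row_offset = sum(output_sizes[:shard_id]) // tp_size
    row_size = output_sizes[shard_id] // tp_size
    print(f"shard {shard_id}: rows [{row_offset}, {row_offset + row_size}) "
          f"-> scale blocks [{scale_offset}, {scale_offset + scale_size})")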
sglang/srt/layers/logits_processor.py
CHANGED
@@ -39,10 +39,12 @@ class LogitsProcessorOutput:
     # The logprobs of input tokens. shape: [#token, vocab_size]
     input_token_logprobs: torch.Tensor = None
 
-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
-    input_top_logprobs: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
-    output_top_logprobs: List = None
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    input_top_logprobs_val: List = None
+    input_top_logprobs_idx: List = None
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
+    output_top_logprobs_val: List = None
+    output_top_logprobs_idx: List = None
 
 
 @dataclasses.dataclass
@@ -89,76 +91,18 @@ class LogitsMetadata:
 
 
 class LogitsProcessor(nn.Module):
-    def __init__(self, config, skip_all_gather: bool = False):
+    def __init__(
+        self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
+    ):
         super().__init__()
         self.config = config
+        self.logit_scale = logit_scale
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
-
-    def _get_normalized_prompt_logprobs(
-        self,
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        self.final_logit_softcapping = getattr(
+            self.config, "final_logit_softcapping", None
         )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
-    @staticmethod
-    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-        max_k = max(logits_metadata.top_logprobs_nums)
-        ret = all_logprobs.topk(max_k, dim=1)
-        values = ret.values.tolist()
-        indices = ret.indices.tolist()
-
-        if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
-            return None, output_top_logprobs
-        else:
-            input_top_logprobs, output_top_logprobs = [], []
-
-            pt = 0
-            for k, pruned_len in zip(
-                logits_metadata.top_logprobs_nums,
-                logits_metadata.extend_logprob_pruned_lens_cpu,
-            ):
-                if pruned_len <= 0:
-                    input_top_logprobs.append([])
-                    output_top_logprobs.append([])
-                    continue
-
-                input_top_logprobs.append(
-                    [
-                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(pruned_len - 1)
-                    ]
-                )
-                output_top_logprobs.append(
-                    list(
-                        zip(
-                            values[pt + pruned_len - 1][:k],
-                            indices[pt + pruned_len - 1][:k],
-                        )
-                    )
-                )
-                pt += pruned_len
-
-            return input_top_logprobs, output_top_logprobs
 
     def forward(
         self,
@@ -184,38 +128,33 @@ class LogitsProcessor(nn.Module):
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
 
-        if self.config.final_logit_softcapping:
-            last_logits.div_(self.config.final_logit_softcapping)
+        if self.final_logit_softcapping:
+            last_logits.div_(self.final_logit_softcapping)
             torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.config.final_logit_softcapping)
+            last_logits.mul_(self.final_logit_softcapping)
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
             )
         else:
-            last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+            last_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                last_logits, logits_metadata
+            )
 
             if logits_metadata.forward_mode.is_decode():
                 if logits_metadata.return_top_logprob:
-                    output_top_logprobs = self.get_top_logprobs(
-                        last_logprobs, logits_metadata
-                    )[1]
+                    output_top_logprobs_val, output_top_logprobs_idx = (
+                        self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
+                    )
                 else:
-                    output_top_logprobs = None
+                    output_top_logprobs_val = output_top_logprobs_idx = None
                 return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
-                    normalized_prompt_logprobs=None,
-                    input_token_logprobs=None,
-                    input_top_logprobs=None,
-                    output_top_logprobs=output_top_logprobs,
+                    output_top_logprobs_val=output_top_logprobs_val,
+                    output_top_logprobs_idx=output_top_logprobs_idx,
                 )
             else:
                 # Slice the requested tokens to compute logprob
@@ -233,24 +172,35 @@
             all_logits = self._get_logits(states, lm_head)
             if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
+
+            # The LM head's weights may be zero-padded for parallelism. Remove any
+            # extra logits that this padding may have produced.
             all_logits = all_logits[:, : self.config.vocab_size].float()
 
-            if self.config.final_logit_softcapping:
-                all_logits.div_(self.config.final_logit_softcapping)
+            if self.final_logit_softcapping:
+                all_logits.div_(self.final_logit_softcapping)
                 torch.tanh(all_logits, out=all_logits)
-                all_logits.mul_(self.config.final_logit_softcapping)
+                all_logits.mul_(self.final_logit_softcapping)
 
             all_logprobs = all_logits
             del all_logits, hidden_states
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+            all_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                all_logprobs, logits_metadata
+            )
 
             # Get the logprob of top-k tokens
             if logits_metadata.return_top_logprob:
-                input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
-                    all_logprobs, logits_metadata
-                )
+                (
+                    input_top_logprobs_val,
+                    input_top_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                ) = self.get_top_logprobs(all_logprobs, logits_metadata)
             else:
-                input_top_logprobs = output_top_logprobs = None
+                input_top_logprobs_val = input_top_logprobs_idx = (
+                    output_top_logprobs_val
+                ) = output_top_logprobs_idx = None
 
             # Compute the normalized logprobs for the requested tokens.
             # Note that we pad a zero at the end for easy batching.
@@ -273,8 +223,10 @@
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,
                 input_token_logprobs=input_token_logprobs,
-                input_top_logprobs=input_top_logprobs,
-                output_top_logprobs=output_top_logprobs,
+                input_top_logprobs_val=input_top_logprobs_val,
+                input_top_logprobs_idx=input_top_logprobs_idx,
+                output_top_logprobs_val=output_top_logprobs_val,
+                output_top_logprobs_idx=output_top_logprobs_idx,
             )
 
     def _get_logits(
@@ -288,8 +240,94 @@
         else:
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+
+        # Optional scaling factor
+        if self.logit_scale is not None:
+            logits.mul_(self.logit_scale)  # In-place multiply
         return logits
 
+    @staticmethod
+    def _get_normalized_prompt_logprobs(
+        input_token_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+    ):
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
+        pruned_lens = torch.tensor(
+            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+        )
+
+        start = torch.zeros_like(pruned_lens)
+        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+        end = torch.clamp(
+            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+        )
+        sum_logp = (
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
+        return normalized_prompt_logprobs
+
+    @staticmethod
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+        max_k = max(logits_metadata.top_logprobs_nums)
+        ret = all_logprobs.topk(max_k, dim=1)
+        values = ret.values.tolist()
+        indices = ret.indices.tolist()
+
+        if logits_metadata.forward_mode.is_decode():
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
+            for i, k in enumerate(logits_metadata.top_logprobs_nums):
+                output_top_logprobs_val.append(values[i][:k])
+                output_top_logprobs_idx.append(indices[i][:k])
+            return None, None, output_top_logprobs_val, output_top_logprobs_idx
+        else:
+            input_top_logprobs_val, input_top_logprobs_idx = [], []
+            output_top_logprobs_val, output_top_logprobs_idx = [], []
+
+            pt = 0
+            for k, pruned_len in zip(
+                logits_metadata.top_logprobs_nums,
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+            ):
+                if pruned_len <= 0:
+                    input_top_logprobs_val.append([])
+                    input_top_logprobs_idx.append([])
+                    output_top_logprobs_val.append([])
+                    output_top_logprobs_idx.append([])
+                    continue
+
+                input_top_logprobs_val.append(
+                    [values[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                input_top_logprobs_idx.append(
+                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                output_top_logprobs_val.append(
+                    list(
+                        values[pt + pruned_len - 1][:k],
+                    )
+                )
+                output_top_logprobs_idx.append(
+                    list(
+                        indices[pt + pruned_len - 1][:k],
+                    )
+                )
+                pt += pruned_len
+
+            return (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+                output_top_logprobs_val,
+                output_top_logprobs_idx,
+            )
+
+    @staticmethod
+    def compute_temp_top_p_normalized_logprobs(
+        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+    ) -> torch.Tensor:
+        return torch.nn.functional.log_softmax(last_logits, dim=-1)
+
 
 def test():
     all_logprobs = torch.tensor(
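Note: the `LogitsProcessorOutput` refactor replaces the old lists of `(logprob, token_id)` pairs with parallel `*_val` / `*_idx` lists. A small self-contained sketch of how such parallel lists fall out of a top-k call (plain PyTorch, not the sglang helper itself; the `top_logprobs_nums` values are made up):

import torch

# Per-sequence number of requested top logprobs, analogous to top_logprobs_nums.
top_logprobs_nums = [2, 3]

logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)  # [#seq, vocab]
max_k = max(top_logprobs_nums)
values, indices = logprobs.topk(max_k, dim=1)

# Parallel value/index lists instead of zipped (value, id) pairs.
output_top_logprobs_val = [values[i][:k].tolist() for i, k in enumerate(top_logprobs_nums)]
output_top_logprobs_idx = [indices[i][:k].tolist() for i, k in enumerate(top_logprobs_nums)]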
sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py
CHANGED
@@ -12,15 +12,15 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 
 from sglang.srt.layers.custom_op_util import register_custom_op
-from sglang.srt.layers.ep_moe.kernels import (
+from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
     silu_and_mul_triton_kernel,
 )
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
-from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -113,6 +113,7 @@ class EPMoE(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -138,6 +139,7 @@ class EPMoE(torch.nn.Module):
             assert num_expert_group is not None and topk_group is not None
             self.num_expert_group = num_expert_group
             self.topk_group = topk_group
+        self.correction_bias = correction_bias
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -170,13 +172,15 @@ class EPMoE(torch.nn.Module):
             hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
         )
 
-        topk_weights, topk_ids = self.select_experts(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            self.renormalize,
-            self.topk_group,
-            self.num_expert_group,
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            correction_bias=self.correction_bias,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -297,35 +301,6 @@ class EPMoE(torch.nn.Module):
         )
         return output
 
-    def select_experts(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-    ):
-        if self.use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        else:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        return topk_weights, topk_ids.to(torch.int32)
-
     @classmethod
     def make_expert_params_mapping(
         cls,
@@ -644,6 +619,10 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                     "QuantConfig has static quantization, but found "
                     "activation scales are None."
                 )
+            layer.w13_weight_scale = torch.nn.Parameter(
+                torch.max(layer.w13_weight_scale, dim=1).values,
+                requires_grad=False,
+            )
             return
 
     def apply(
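Note: the per-class `select_experts` removed above is superseded by the shared helper in `sglang/srt/layers/moe/topk.py` (added in this release). A rough sketch of the ungrouped softmax-top-k routing such a helper performs (simplified; the real function also covers grouped top-k, a custom routing function, and the new `correction_bias`):

import torch

def naive_topk_routing(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    # router_logits: [num_tokens, num_experts]
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    if renormalize:
        # Make the selected experts' weights sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

weights, ids = naive_topk_routing(torch.randn(4, 8), top_k=2, renormalize=True)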
sglang/srt/layers/moe/fused_moe_native.py
ADDED
@@ -0,0 +1,46 @@
+"""
+Torch-native implementation for FusedMoE. This is used for torch.compile.
+It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
+"""
+
+from typing import Callable, Optional
+
+import torch
+from torch.nn import functional as F
+
+from sglang.srt.layers.moe.topk import select_experts
+
+
+def fused_moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    w13_weights = layer.w13_weight[topk_ids]
+    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
+    w2_weights = layer.w2_weight[topk_ids]
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
+    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
+    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
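Note: in the einsums above, the subscripts read roughly as t = token, a = activated expert slot (top-k), i = hidden size, o = intermediate size. A tiny shape check under assumed dimensions (illustrative only, not part of the package):

import torch

num_tokens, top_k, hidden, intermediate = 3, 2, 8, 16

x = torch.randn(num_tokens, hidden)
w1 = torch.randn(num_tokens, top_k, intermediate, hidden)  # gathered gate proj per chosen expert
w3 = torch.randn(num_tokens, top_k, intermediate, hidden)  # gathered up proj
w2 = torch.randn(num_tokens, top_k, hidden, intermediate)  # gathered down proj
topk_weights = torch.rand(num_tokens, top_k)

x1 = torch.nn.functional.silu(torch.einsum("ti,taoi->tao", x, w1))
x3 = torch.einsum("ti,taoi->tao", x, w3)
expert_outs = torch.einsum("tao,taio->tai", x1 * x3, w2)
out = torch.einsum("tai,ta->ti", expert_outs, topk_weights)
assert out.shape == (num_tokens, hidden)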
sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py
CHANGED
@@ -1,14 +1,12 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional
 
-import sglang.srt.layers.fused_moe_triton.fused_moe  # noqa
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
+import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_experts,
-    fused_topk,
     get_config_file_name,
-    grouped_topk,
 )
-from sglang.srt.layers.fused_moe_triton.layer import (
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoEMethodBase,
     FusedMoeWeightScaleSupported,
@@ -37,8 +35,6 @@ __all__ = [
     "override_config",
     "get_config",
     "fused_moe",
-    "fused_topk",
     "fused_experts",
     "get_config_file_name",
-    "grouped_topk",
 ]
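Note: for downstream code the visible effect is the import-path move; `fused_topk` and `grouped_topk` are no longer re-exported from this package, and top-k selection now lives in `sglang.srt.layers.moe.topk`. A before/after sketch (illustrative):

# sglang 0.4.0.post1
# from sglang.srt.layers.fused_moe_triton import FusedMoE, fused_topk, grouped_topk

# sglang 0.4.1
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts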