sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/ep_moe/layer.py (new file)
@@ -0,0 +1,665 @@
import logging
from typing import Callable, List, Optional, Tuple

import torch
from torch.nn import Module
from vllm import _custom_ops as ops
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod

from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.ep_moe.kernels import (
    grouped_gemm_triton,
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
    run_moe_ep_preproess,
    silu_and_mul_triton_kernel,
)
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.utils import is_hip, set_weight_attrs

logger = logging.getLogger(__name__)


class GroupedGemmRunner(torch.nn.Module):
    flashinfer_gemm_warpper = None

    def __init__(self, device, use_flashinfer: bool = False):
        super().__init__()
        self.device = device
        self.use_flashinfer = use_flashinfer
        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
            GroupedGemmRunner._init_flashinfer_wrapper(device)

    @classmethod
    def _init_flashinfer_wrapper(cls, device):
        from flashinfer import SegmentGEMMWrapper

        workspace_buffer = torch.empty(
            128 * 1024 * 1024, dtype=torch.int8, device=device
        )
        cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)

    # c = a * b
    def forward(
        self,
        a: torch.Tensor,
        b: torch.Tensor,
        c: torch.Tensor,
        batch_size: int,
        weight_column_major: bool,
        seg_indptr: Optional[torch.Tensor] = None,
        weight_indices: Optional[torch.Tensor] = None,
        use_fp8_w8a8: bool = False,
        scale_a: torch.Tensor = None,
        scale_b: torch.Tensor = None,
    ):
        if self.use_flashinfer:
            # TODO: flashinfer
            assert False
            assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
            c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
                x=a,
                weights=b,
                batch_size=batch_size,
                weight_column_major=weight_column_major,
                seg_indptr=seg_indptr,
                weight_indices=weight_indices,
            )
        else:
            assert weight_column_major == True
            c = grouped_gemm_triton(
                a,
                b,
                c,
                batch_size,
                weight_column_major,
                seg_indptr,
                weight_indices,
                use_fp8_w8a8,
                scale_a,
                scale_b,
            )
        return c
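
GroupedGemmRunner always takes the grouped_gemm_triton path today; the flashinfer branch is still a TODO and asserts out. As a rough pure-PyTorch reference for the segment-wise GEMM semantics it relies on (illustrative only: the helper name below is hypothetical and the fp8 scale arguments are ignored):

```python
import torch

def grouped_gemm_reference(a, b, c, seg_indptr, weight_indices):
    # a: [M, K] rows already grouped by segment; b: [num_segments, N, K] column-major weights
    # seg_indptr: [num_segments + 1] row offsets; weight_indices: [num_segments] weight slot per segment
    for i in range(weight_indices.numel()):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        if end > start:
            c[start:end] = a[start:end] @ b[weight_indices[i]].t()
    return c
```

Each contiguous block of rows (the tokens routed to one local expert) is multiplied by that expert's own weight matrix; seg_indptr_cur_rank and weight_indices_cur_rank in EPMoE.forward below supply exactly these per-expert row ranges and weight slots.
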
class EPMoE(torch.nn.Module):
    """
    MoE Expert Parallel Impl


    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
        )
        self.tp_rank = get_tensor_model_parallel_rank()

        self.num_experts = num_experts
        assert self.num_experts % self.tp_size == 0
        self.num_experts_per_partition = self.num_experts // self.tp_size
        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1

        self.top_k = top_k
        self.intermediate_size = intermediate_size
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
            self.use_fp8_w8a8 = False
            self.activation_scheme = None
        else:
            self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                quant_config
            )
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme

        self.quant_method.create_weights(
            layer=self,
            num_experts_per_partition=self.num_experts_per_partition,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

        self.grouped_gemm_runner = None

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        assert self.quant_method is not None

        if self.grouped_gemm_runner is None:
            self.grouped_gemm_runner = GroupedGemmRunner(
                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
            )

        topk_weights, topk_ids = self.select_experts(
            hidden_states,
            router_logits,
            self.top_k,
            self.renormalize,
            self.topk_group,
            self.num_expert_group,
        )

        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
            topk_ids, self.num_experts
        )

        gateup_input = torch.empty(
            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
            device=hidden_states.device,
            dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
        )
        if self.activation_scheme == "dynamic":
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
                .to(torch.float32)
            )
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max

        # PreReorder
        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
            hidden_states,
            gateup_input,
            src2dst,
            topk_ids,
            self.w13_input_scale,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states.shape[1],
            BLOCK_SIZE=512,
        )

        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
        weight_indices_cur_rank = torch.arange(
            0,
            self.num_experts_per_partition,
            device=hidden_states.device,
            dtype=torch.int64,
        )
        # GroupGemm-0
        gateup_output = torch.empty(
            gateup_input.shape[0],
            self.w13_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        gateup_output = self.grouped_gemm_runner(
            a=gateup_input,
            b=self.w13_weight,
            c=gateup_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w13_input_scale,
            scale_b=self.w13_weight_scale,
        )

        # Act
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype,
        )
        if self.w2_input_scale is None:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
                device=hidden_states.device,
            )
        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
            gateup_output,
            down_input,
            gateup_output.shape[1],
            reorder_topk_ids,
            self.w2_input_scale,
            self.start_expert_id,
            self.end_expert_id,
            BLOCK_SIZE=512,
        )

        # GroupGemm-1
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        down_output = self.grouped_gemm_runner(
            a=down_input,
            b=self.w2_weight,
            c=down_output,
            batch_size=self.num_experts_per_partition,
            weight_column_major=True,
            seg_indptr=seg_indptr_cur_rank,
            weight_indices=weight_indices_cur_rank,
            use_fp8_w8a8=self.use_fp8_w8a8,
            scale_a=self.w2_input_scale,
            scale_b=self.w2_weight_scale,
        )

        # PostReorder
        output = torch.empty_like(hidden_states)
        post_reorder_triton_kernel[(hidden_states.size(0),)](
            down_output,
            output,
            src2dst,
            topk_ids,
            topk_weights,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states.size(1),
            BLOCK_SIZE=512,
        )
        return output

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
    ):
        if self.use_grouped_topk:
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
            )
        return topk_weights, topk_ids.to(torch.int32)

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:

        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    "experts.w13_"
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    else "experts.w2_"
                ),
                f"experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
            return
        expert_id = expert_id - self.start_expert_id

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(
                f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
            )

        # Special case for fp8 scales.
        if "scale" in weight_name:
            self._load_fp8_scale(
                param.data, loaded_weight, weight_name, shard_id, expert_id
            )
            return

        expert_data = param.data[expert_id]
        if shard_id == "w2":
            param.data[expert_id] = loaded_weight
        elif shard_id == "w1":
            param.data[expert_id][: self.intermediate_size, :] = loaded_weight
        elif shard_id == "w3":
            param.data[expert_id][self.intermediate_size :, :] = loaded_weight
        else:
            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")

    def _load_fp8_scale(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if (
                param_data[expert_id] != 1
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight
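
The expert sharding set up in EPMoE.__init__ is a plain contiguous partition of the expert ids over tensor-parallel ranks, and weight_loader simply skips experts outside its rank's range. A minimal standalone sketch of that arithmetic (the numbers below are hypothetical, not tied to any particular model config):

```python
# Contiguous expert partitioning across TP ranks, mirroring EPMoE.__init__.
num_experts, tp_size = 64, 8  # assumed example values
assert num_experts % tp_size == 0
num_experts_per_partition = num_experts // tp_size
for tp_rank in range(tp_size):
    start_expert_id = tp_rank * num_experts_per_partition
    end_expert_id = start_expert_id + num_experts_per_partition - 1
    # weight_loader ignores any expert_id outside [start_expert_id, end_expert_id]
    print(tp_rank, start_expert_id, end_expert_id)
```
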
@register_custom_op("sglang_unquantized_ep_moe")
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                2 * intermediate_size,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                hidden_size,
                intermediate_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # scale
        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
        w13_input_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            ones_tensor,
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


class Fp8EPMoEMethod(Fp8MoEMethod):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                2 * intermediate_size,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts_per_partition,
                hidden_size,
                intermediate_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts_per_partition, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update({"quant_method": "tensor"})
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts_per_partition, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:

        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as dtype
            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.num_experts_per_partition,
                    dtype=torch.float32,
                    device=w13_weight.device,
                ),
                requires_grad=False,
            )

            for expert in range(layer.num_experts_per_partition):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.max(layer.w13_weight_scale, dim=1).values,
                requires_grad=False,
            )
            return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
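
For fp16 checkpoints, Fp8EPMoEMethod.process_weights_after_loading quantizes each expert's weights per tensor via vllm's ops.scaled_fp8_quant, which returns the fp8 tensor together with its scale. A hedged pure-PyTorch sketch of the idea (not the vllm kernel, just an illustration; assumes torch >= 2.1 for float8 dtypes):

```python
import torch

def per_tensor_fp8_quant(w: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Scale so the largest magnitude maps near the fp8 dtype's max representable value.
    finfo = torch.finfo(fp8_dtype)
    scale = w.abs().max().clamp(min=1e-12).float() / finfo.max
    w_q = (w / scale).clamp(finfo.min, finfo.max).to(fp8_dtype)
    return w_q, scale  # dequantize as w_q.to(w.dtype) * scale
```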