sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those package versions as they appear in the registry.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -1,13 +1,14 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
-import
+import datetime
+import glob
 import logging
+import os
+import sys
 from enum import Enum
-from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
-from packaging import version as pkg_version
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -29,22 +31,59 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_flashinfer_available,
+    is_hip,
+    next_power_of_2,
+    round_up,
+)
+
+if is_flashinfer_available():
+    from flashinfer import (
+        RoutingMethodType,
+        fp4_quantize,
+        reorder_rows_for_gated_act_gemm,
+        shuffle_matrix_a,
+        shuffle_matrix_sf_a,
+    )
 
 _is_hip = is_hip()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
+
+# Try to import FP4 TRTLLM function if flashinfer is available
+trtllm_fp4_block_scale_moe = None
+if should_use_flashinfer_trtllm_moe():
+    try:
+        from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
+    except ImportError:
+        trtllm_fp4_block_scale_moe = None
+
 logger = logging.getLogger(__name__)
 
 
-
-
-
-
-
-
-
+def _is_fp4_quantization_enabled():
+    """Check if ModelOpt FP4 quantization is enabled."""
+    try:
+        # Use the same simple check that works for class selection
+        quantization = global_server_args_dict.get("quantization")
+        return quantization == "modelopt_fp4"
+    except:
+        return False
+
+
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
 
 
 class FusedMoeWeightScaleSupported(Enum):
@@ -96,6 +135,10 @@ class FusedMoE(torch.nn.Module):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        use_weight_loader_fused: bool = False,
+        with_bias=False,
     ):
         super().__init__()
 
@@ -104,12 +147,15 @@ class FusedMoE(torch.nn.Module):
 
         self.layer_id = layer_id
         self.top_k = top_k
-        self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map_cpu = None
         self.expert_map_gpu = None
 
+        # For activation
+        self.activation_alpha = activation_alpha
+        self.swiglu_limit = swiglu_limit
+
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
             enable_flashinfer_cutlass_moe = False
@@ -124,15 +170,18 @@ class FusedMoE(torch.nn.Module):
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
-            self.expert_map_cpu = torch.full(
+            self.expert_map_cpu = torch.full(
+                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+            )
+            self.expert_map_cpu = torch.full(
+                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+            )
             # Create a expert map for the local experts
             self.expert_map_cpu[
                 self.moe_ep_rank
                 * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-            if not self.enable_flashinfer_cutlass_moe:
-                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
 
         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.moe_tp_size == 0
@@ -154,13 +203,19 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
-            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_cutlass_moe = (
-                    self.enable_flashinfer_cutlass_moe
-                )
         assert self.quant_method is not None
 
         self.quant_config = quant_config
+        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+            "enable_flashinfer_mxfp4_moe", False
+        )
+        if (
+            self.quant_config is not None
+            and self.quant_config.get_name() == "mxfp4"
+            and self.use_enable_flashinfer_mxfp4_moe
+        ):
+            hidden_size = round_up(hidden_size, 256)
+        self.hidden_size = hidden_size
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
@@ -169,7 +224,12 @@ class FusedMoE(torch.nn.Module):
             intermediate_size=self.intermediate_size_per_partition,
             intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
-            weight_loader=
+            weight_loader=(
+                self.weight_loader
+                if not use_weight_loader_fused
+                else self.weight_loader_fused
+            ),
+            with_bias=with_bias,
         )
 
     def _load_per_tensor_weight_scale(
@@ -197,6 +257,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
         # Load grouped weight scales for group quantization
         # or model weights
@@ -207,14 +268,16 @@ class FusedMoE(torch.nn.Module):
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
                 tp_rank=tp_rank,
+                is_bias=is_bias,
             )
-        elif shard_id in ("w1", "w3"):
+        elif shard_id in ("w1", "w3", "w13"):
             self._load_w13(
                 shard_id=shard_id,
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
                 tp_rank=tp_rank,
+                is_bias=is_bias,
             )
 
     def _load_per_channel_weight_scale(
@@ -244,17 +307,30 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
 
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-
+        assert shard_id in {"w1", "w3", "w13"}
+
+        if is_bias:
+            # if this weight is a bias, the last dimension must be the sharded dimension
+            shard_dim = -1
+
+        if shard_id in {"w1", "w3"}:
+            # non-fused version
+            shard_size = expert_data.shape[shard_dim] // 2
+        elif shard_id in {"w13"}:
+            # fused version
+            shard_size = expert_data.shape[shard_dim]
+        else:
+            raise NotImplementedError
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
         # trtllm cutlass kernel assumes differently
-        assert shard_id in ("w1", "w3")
         switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
         if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
             start = shard_size
@@ -273,7 +349,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
-                if self.use_triton_kernels:
+                if not is_bias and self.use_triton_kernels:
+                    # do not transpose for bias
                     loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
@@ -289,6 +366,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
         """Load w2 weights for down projection.
 
@@ -319,7 +397,14 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
-
+        if is_bias:
+            # this expert_data is a bias, not weight,
+            # for w2_weight_bias in TP, it does not need to be sharded
+            shard_size = expert_data.shape[-1]
+        else:
+            # this parameter is a weight matrix
+            # for w2 in TP, it shards the input_features, i.e., shard_dim=2
+            shard_size = expert_data.shape[shard_dim]
 
         if _is_cpu:
             expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
@@ -332,13 +417,9 @@ class FusedMoE(torch.nn.Module):
                 not self.use_presharded_weights,
             )
         else:
-            if not self.use_presharded_weights:
+            if not is_bias and not self.use_presharded_weights:
                 if self.use_triton_kernels:
                     loaded_weight = loaded_weight.transpose(-2, -1)
-                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
-                    raise ValueError(
-                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
-                    )
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -386,9 +467,25 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.Tensor,
         weight_name: str,
         shard_id: str,
-        expert_id: int,
+        expert_id: Optional[int],
     ) -> None:
 
+        # if expert_id is None, then
+        # all the experts are loaded at the same time
+        if (
+            not expert_id
+            and self.quant_config is not None
+            and self.quant_config.get_name() == "mxfp4"
+        ):
+            if "bias" in weight_name:
+                dim1 = loaded_weight.shape[1]
+                param.data[:, :dim1].copy_(loaded_weight)
+            else:
+                dim1 = loaded_weight.shape[1]
+                dim2 = loaded_weight.shape[2]
+                param.data[:, :dim1, :dim2].copy_(loaded_weight)
+            return
+
         global_expert_location_metadata = get_global_expert_location_metadata()
         if global_expert_location_metadata is None:
             self._weight_loader_impl(
@@ -427,6 +524,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
@@ -621,16 +719,111 @@ class FusedMoE(torch.nn.Module):
             )
             return
 
+    def weight_loader_fused(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+    ) -> None:
+        tp_rank = self.moe_tp_rank
+
+        if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
+            if "bias" in weight_name:
+                dim1 = loaded_weight.shape[1]
+                param.data[:, :dim1].copy_(loaded_weight)
+            elif "scale" in weight_name:
+                param.data.copy_(loaded_weight)
+            else:
+                dim1 = loaded_weight.shape[1]
+                dim2 = loaded_weight.shape[2]
+                param.data[:, :dim1, :dim2].copy_(loaded_weight)
+            return
+
+        # compressed-tensors checkpoints with packed weights are stored flipped
+        # TODO: check self.quant_method.quant_config.quant_format
+        # against known CompressionFormat enum values that have this quality
+        loaded_weight = (
+            loaded_weight.t().contiguous()
+            if (
+                self.quant_method.__class__.__name__
+                == "CompressedTensorsWNA16MoEMethod"
+            )
+            else loaded_weight
+        )
+
+        if shard_id not in ("w13", "w2"):
+            raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
+
+        # Fetch the dim to shard the parameter/loaded weight
+        # based on the shard id. This will be whatever
+        # dimension intermediate_size is used.
+        SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
+        SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
+
+        expert_data = param.data
+        is_bias = expert_data.dim() == 2
+
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size is
+        is_transposed = getattr(param, "is_transposed", False)
+
+        if self.use_triton_kernels:
+            is_transposed = True
+        shard_dim = (
+            SHARD_ID_TO_SHARDED_DIM[shard_id]
+            if not is_transposed
+            else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
+        )
+
+        # Case model weights
+        if "weight" in weight_name:
+            self._load_model_weight_or_group_weight_scale(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+                is_bias=is_bias,
+            )
+            return
+        else:
+            logging.warning(
+                f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
+            )
+
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
+        origin_hidden_states_dim = hidden_states.shape[-1]
+        if self.hidden_size != origin_hidden_states_dim:
+            hidden_states = torch.nn.functional.pad(
+                hidden_states,
+                (0, self.hidden_size - origin_hidden_states_dim),
+                mode="constant",
+                value=0.0,
+            )
         assert self.quant_method is not None
 
-        if self.
+        if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
+            if self.expert_map_cpu is not None and self.expert_map_gpu is None:
+                # If we are in EP mode, we need to move the expert map to GPU.
+                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+
+        if self.expert_map_gpu is not None and isinstance(
+            topk_output, StandardTopKOutput
+        ):
             topk_output = topk_output._replace(
                 topk_ids=self.expert_map_gpu[topk_output.topk_ids]
             )
 
         # Matrix multiply.
         with use_symmetric_memory(get_tp_group()) as sm:
+            kwargs = {}
+            if self.activation_alpha is not None:
+                kwargs["activation_alpha"] = self.activation_alpha
+            if self.swiglu_limit is not None:
+                kwargs["swiglu_limit"] = self.swiglu_limit
+
             final_hidden_states = self.quant_method.apply(
                 layer=self,
                 x=hidden_states,
@@ -649,13 +842,14 @@ class FusedMoE(torch.nn.Module):
                     == "ModelOptNvFp4FusedMoEMethod"
                     else {}
                 ),
+                **kwargs,
             )
             sm.tag(final_hidden_states)
 
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-        return final_hidden_states
+        return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
 
     @classmethod
     def make_expert_params_mapping(
@@ -686,6 +880,52 @@ class FusedMoE(torch.nn.Module):
             ]
         ]
 
+    @classmethod
+    def make_expert_params_mapping_fused(
+        cls,
+        ckpt_gate_up_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_gate_up_proj_bias_name: str,
+        ckpt_down_proj_bias_name: str,
+    ):
+        return [
+            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+            (
+                "experts.w13_weight_bias",
+                f"experts.{ckpt_gate_up_proj_bias_name}",
+                "w13",
+            ),
+            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+        ]
+
+    @classmethod
+    def make_expert_params_mapping_fused_mxfp4(
+        cls,
+        ckpt_gate_up_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_gate_up_proj_bias_name: str,
+        ckpt_down_proj_bias_name: str,
+        ckpt_gate_up_proj_scale_name: str,
+        ckpt_down_proj_scale_name: str,
+    ):
+        return [
+            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+            (
+                "experts.w13_weight_bias",
+                f"experts.{ckpt_gate_up_proj_bias_name}",
+                "w13",
+            ),
+            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+            (
+                "experts.w13_weight_scale",
+                f"experts.{ckpt_gate_up_proj_scale_name}",
+                "w13",
+            ),
+            ("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"),
+        ]
+
     @classmethod
     def make_expert_input_scale_params_mapping(
         cls,
@@ -721,8 +961,13 @@ class FlashInferFusedMoE(FusedMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
-    def forward(self, hidden_states: torch.Tensor,
+    def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+        assert self.use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
             self.renormalize
@@ -730,6 +975,14 @@ class FlashInferFusedMoE(FusedMoE):
         assert (
             self.num_fused_shared_experts == 0
         ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+        _, router_logits = topk_output
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
@@ -739,7 +992,135 @@ class FlashInferFusedMoE(FusedMoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
 
-        if self.reduce_results and (self.
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
+
+
+class FlashInferFP4MoE(FusedMoE):
+    """FP4 TRTLLM MoE implementation using FlashInfer."""
+
+    def __init__(self, *args, **kwargs):
+        # Extract DeepSeek-specific parameters
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+
+        # Extract additional TopK parameters that were previously extracted in forward
+        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
+
+        super().__init__(*args, **kwargs)
+
+        # Store DeepSeek parameters
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    # ---------------------------------------------------------------------
+    # Helper: quantize hidden states to FP4 each forward pass
+    # ---------------------------------------------------------------------
+    def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
+        """
+        Quantize hidden states using global scale factor from quantization method.
+
+        Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
+        Only block scales are computed at runtime for efficiency.
+
+        Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32)
+        """
+
+        # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
+        # Only the block scales are computed at runtime
+        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
+            hidden_states,
+            self.w13_input_scale_quant,
+            16,  # sf_vec_size
+            False,  # use_ue8m0
+            False,  # is_sf_swizzled_layout
+        )
+
+        hs_fp4 = hs_fp4_bytes.reshape(
+            hidden_states.shape[0], hidden_states.shape[1] // 2
+        )
+        hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
+
+        return hs_fp4, hs_sf
+
+    def forward(self, hidden_states: torch.Tensor, topk_output):
+        """Forward pass using FP4 TRTLLM kernel.
+
+        Args:
+            hidden_states: Input tensor
+            topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+        """
+
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+
+        _, router_logits = topk_output
+
+        hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
+
+        router_logits = router_logits.to(torch.float32)
+
+        result = trtllm_fp4_block_scale_moe(
+            routing_logits=router_logits,
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=hs_fp4,
+            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
+            gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
+            gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            gemm1_bias=None,
+            gemm1_alpha=None,
+            gemm1_beta=None,
+            gemm1_clamp_limit=None,
+            gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
+            gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            gemm2_bias=None,
+            output1_scale_scalar=self.g1_scale_c.data,
+            output1_scale_gate_scalar=self.g1_alphas.data,
+            output2_scale_scalar=self.g2_alphas.data,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.intermediate_size_per_partition,
+            local_expert_offset=self.moe_ep_rank * self.num_local_experts,
+            local_num_experts=self.num_local_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_local_experts
+            ),
+            routing_method_type=RoutingMethodType.DeepSeekV3,
+            do_finalize=True,
+        )[0]
+
+        return result
+
+
+def get_fused_moe_impl_class():
+    """Factory function to get the appropriate FusedMoE implementation class."""
+    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
+        # Use FP4 variant when FP4 quantization is enabled
+        return FlashInferFP4MoE
+    elif should_use_flashinfer_trtllm_moe():
+        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
+        return FlashInferFusedMoE
+    else:
+        # Default case
+        return FusedMoE