sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -1,25 +1,29 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
-import
+import datetime
+import glob
 import logging
+import os
+import sys
 from enum import Enum
-from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
-from packaging import version as pkg_version
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
-
-    get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -27,22 +31,59 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_flashinfer_available,
+    is_hip,
+    next_power_of_2,
+    round_up,
+)
+
+if is_flashinfer_available():
+    from flashinfer import (
+        RoutingMethodType,
+        fp4_quantize,
+        reorder_rows_for_gated_act_gemm,
+        shuffle_matrix_a,
+        shuffle_matrix_sf_a,
+    )
 
 _is_hip = is_hip()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
+
+# Try to import FP4 TRTLLM function if flashinfer is available
+trtllm_fp4_block_scale_moe = None
+if should_use_flashinfer_trtllm_moe():
+    try:
+        from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
+    except ImportError:
+        trtllm_fp4_block_scale_moe = None
+
 logger = logging.getLogger(__name__)
 
 
-
-
-
-
-
-
-
+def _is_fp4_quantization_enabled():
+    """Check if ModelOpt FP4 quantization is enabled."""
+    try:
+        # Use the same simple check that works for class selection
+        quantization = global_server_args_dict.get("quantization")
+        return quantization == "modelopt_fp4"
+    except:
+        return False
+
+
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
 
 
 class FusedMoeWeightScaleSupported(Enum):
@@ -94,7 +135,10 @@ class FusedMoE(torch.nn.Module):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
-
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        use_weight_loader_fused: bool = False,
+        with_bias=False,
     ):
         super().__init__()
 
@@ -103,16 +147,18 @@ class FusedMoE(torch.nn.Module):
 
         self.layer_id = layer_id
         self.top_k = top_k
-        self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map_cpu = None
         self.expert_map_gpu = None
 
+        # For activation
+        self.activation_alpha = activation_alpha
+        self.swiglu_limit = swiglu_limit
+
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
             enable_flashinfer_cutlass_moe = False
-            enable_ep_moe = False
 
         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         self.moe_ep_size = get_moe_expert_parallel_world_size()
@@ -121,18 +167,21 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
-        if
+        if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
-            self.expert_map_cpu = torch.full(
+            self.expert_map_cpu = torch.full(
+                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+            )
+            self.expert_map_cpu = torch.full(
+                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+            )
             # Create a expert map for the local experts
             self.expert_map_cpu[
                 self.moe_ep_rank
                 * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-            if not self.enable_flashinfer_cutlass_moe:
-                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
 
         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.moe_tp_size == 0
@@ -154,13 +203,19 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
-            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_cutlass_moe = (
-                    self.enable_flashinfer_cutlass_moe
-                )
         assert self.quant_method is not None
 
         self.quant_config = quant_config
+        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+            "enable_flashinfer_mxfp4_moe", False
+        )
+        if (
+            self.quant_config is not None
+            and self.quant_config.get_name() == "mxfp4"
+            and self.use_enable_flashinfer_mxfp4_moe
+        ):
+            hidden_size = round_up(hidden_size, 256)
+        self.hidden_size = hidden_size
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
@@ -169,7 +224,12 @@ class FusedMoE(torch.nn.Module):
             intermediate_size=self.intermediate_size_per_partition,
             intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
-            weight_loader=
+            weight_loader=(
+                self.weight_loader
+                if not use_weight_loader_fused
+                else self.weight_loader_fused
+            ),
+            with_bias=with_bias,
         )
 
     def _load_per_tensor_weight_scale(
@@ -197,6 +257,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
         # Load grouped weight scales for group quantization
         # or model weights
@@ -207,14 +268,16 @@ class FusedMoE(torch.nn.Module):
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
                 tp_rank=tp_rank,
+                is_bias=is_bias,
             )
-        elif shard_id in ("w1", "w3"):
+        elif shard_id in ("w1", "w3", "w13"):
             self._load_w13(
                 shard_id=shard_id,
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
                 tp_rank=tp_rank,
+                is_bias=is_bias,
             )
 
     def _load_per_channel_weight_scale(
@@ -244,17 +307,30 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
 
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-
+        assert shard_id in {"w1", "w3", "w13"}
+
+        if is_bias:
+            # if this weight is a bias, the last dimension must be the sharded dimension
+            shard_dim = -1
+
+        if shard_id in {"w1", "w3"}:
+            # non-fused version
+            shard_size = expert_data.shape[shard_dim] // 2
+        elif shard_id in {"w13"}:
+            # fused version
+            shard_size = expert_data.shape[shard_dim]
+        else:
+            raise NotImplementedError
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
         # trtllm cutlass kernel assumes differently
-        assert shard_id in ("w1", "w3")
         switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
         if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
             start = shard_size
@@ -273,7 +349,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
-                if self.use_triton_kernels:
+                if not is_bias and self.use_triton_kernels:
+                    # do not transpose for bias
                     loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
@@ -289,6 +366,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
         """Load w2 weights for down projection.
 
@@ -319,7 +397,14 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
-
+        if is_bias:
+            # this expert_data is a bias, not weight,
+            # for w2_weight_bias in TP, it does not need to be sharded
+            shard_size = expert_data.shape[-1]
+        else:
+            # this parameter is a weight matrix
+            # for w2 in TP, it shards the input_features, i.e., shard_dim=2
+            shard_size = expert_data.shape[shard_dim]
 
         if _is_cpu:
             expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
@@ -332,13 +417,9 @@ class FusedMoE(torch.nn.Module):
                 not self.use_presharded_weights,
             )
         else:
-            if not self.use_presharded_weights:
+            if not is_bias and not self.use_presharded_weights:
                 if self.use_triton_kernels:
                     loaded_weight = loaded_weight.transpose(-2, -1)
-                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
-                    raise ValueError(
-                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
-                    )
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -386,9 +467,25 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.Tensor,
         weight_name: str,
         shard_id: str,
-        expert_id: int,
+        expert_id: Optional[int],
     ) -> None:
 
+        # if expert_id is None, then
+        # all the experts are loaded at the same time
+        if (
+            not expert_id
+            and self.quant_config is not None
+            and self.quant_config.get_name() == "mxfp4"
+        ):
+            if "bias" in weight_name:
+                dim1 = loaded_weight.shape[1]
+                param.data[:, :dim1].copy_(loaded_weight)
+            else:
+                dim1 = loaded_weight.shape[1]
+                dim2 = loaded_weight.shape[2]
+                param.data[:, :dim1, :dim2].copy_(loaded_weight)
+            return
+
         global_expert_location_metadata = get_global_expert_location_metadata()
         if global_expert_location_metadata is None:
             self._weight_loader_impl(
@@ -427,6 +524,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
@@ -621,38 +719,137 @@ class FusedMoE(torch.nn.Module):
             )
             return
 
+    def weight_loader_fused(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+    ) -> None:
+        tp_rank = self.moe_tp_rank
+
+        if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
+            if "bias" in weight_name:
+                dim1 = loaded_weight.shape[1]
+                param.data[:, :dim1].copy_(loaded_weight)
+            elif "scale" in weight_name:
+                param.data.copy_(loaded_weight)
+            else:
+                dim1 = loaded_weight.shape[1]
+                dim2 = loaded_weight.shape[2]
+                param.data[:, :dim1, :dim2].copy_(loaded_weight)
+            return
+
+        # compressed-tensors checkpoints with packed weights are stored flipped
+        # TODO: check self.quant_method.quant_config.quant_format
+        # against known CompressionFormat enum values that have this quality
+        loaded_weight = (
+            loaded_weight.t().contiguous()
+            if (
+                self.quant_method.__class__.__name__
+                == "CompressedTensorsWNA16MoEMethod"
+            )
+            else loaded_weight
+        )
+
+        if shard_id not in ("w13", "w2"):
+            raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
+
+        # Fetch the dim to shard the parameter/loaded weight
+        # based on the shard id. This will be whatever
+        # dimension intermediate_size is used.
+        SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
+        SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
+
+        expert_data = param.data
+        is_bias = expert_data.dim() == 2
+
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size is
+        is_transposed = getattr(param, "is_transposed", False)
+
+        if self.use_triton_kernels:
+            is_transposed = True
+        shard_dim = (
+            SHARD_ID_TO_SHARDED_DIM[shard_id]
+            if not is_transposed
+            else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
+        )
+
+        # Case model weights
+        if "weight" in weight_name:
+            self._load_model_weight_or_group_weight_scale(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+                is_bias=is_bias,
+            )
+            return
+        else:
+            logging.warning(
+                f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
+            )
+
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
+        origin_hidden_states_dim = hidden_states.shape[-1]
+        if self.hidden_size != origin_hidden_states_dim:
+            hidden_states = torch.nn.functional.pad(
+                hidden_states,
+                (0, self.hidden_size - origin_hidden_states_dim),
+                mode="constant",
+                value=0.0,
+            )
         assert self.quant_method is not None
 
-        if self.
+        if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
+            if self.expert_map_cpu is not None and self.expert_map_gpu is None:
+                # If we are in EP mode, we need to move the expert map to GPU.
+                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+
+        if self.expert_map_gpu is not None and isinstance(
+            topk_output, StandardTopKOutput
+        ):
             topk_output = topk_output._replace(
                 topk_ids=self.expert_map_gpu[topk_output.topk_ids]
             )
 
         # Matrix multiply.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with use_symmetric_memory(get_tp_group()) as sm:
+            kwargs = {}
+            if self.activation_alpha is not None:
+                kwargs["activation_alpha"] = self.activation_alpha
+            if self.swiglu_limit is not None:
+                kwargs["swiglu_limit"] = self.swiglu_limit
+
+            final_hidden_states = self.quant_method.apply(
+                layer=self,
+                x=hidden_states,
+                topk_output=topk_output,
+                activation=self.activation,
+                apply_router_weight_on_input=self.apply_router_weight_on_input,
+                routed_scaling_factor=self.routed_scaling_factor,
+                **(
+                    dict(
+                        tp_rank=self.moe_tp_rank,
+                        tp_size=self.moe_tp_size,
+                        ep_rank=self.moe_ep_rank,
+                        ep_size=self.moe_ep_size,
+                    )
+                    if self.quant_method.__class__.__name__
+                    == "ModelOptNvFp4FusedMoEMethod"
+                    else {}
+                ),
+                **kwargs,
+            )
+            sm.tag(final_hidden_states)
 
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-        return final_hidden_states
+        return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
 
     @classmethod
     def make_expert_params_mapping(
@@ -683,6 +880,52 @@ class FusedMoE(torch.nn.Module):
             ]
         ]
 
+    @classmethod
+    def make_expert_params_mapping_fused(
+        cls,
+        ckpt_gate_up_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_gate_up_proj_bias_name: str,
+        ckpt_down_proj_bias_name: str,
+    ):
+        return [
+            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+            (
+                "experts.w13_weight_bias",
+                f"experts.{ckpt_gate_up_proj_bias_name}",
+                "w13",
+            ),
+            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+        ]
+
+    @classmethod
+    def make_expert_params_mapping_fused_mxfp4(
+        cls,
+        ckpt_gate_up_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_gate_up_proj_bias_name: str,
+        ckpt_down_proj_bias_name: str,
+        ckpt_gate_up_proj_scale_name: str,
+        ckpt_down_proj_scale_name: str,
+    ):
+        return [
+            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+            (
+                "experts.w13_weight_bias",
+                f"experts.{ckpt_gate_up_proj_bias_name}",
+                "w13",
+            ),
+            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+            (
+                "experts.w13_weight_scale",
+                f"experts.{ckpt_gate_up_proj_scale_name}",
+                "w13",
+            ),
+            ("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"),
+        ]
+
     @classmethod
     def make_expert_input_scale_params_mapping(
         cls,
@@ -718,8 +961,13 @@ class FlashInferFusedMoE(FusedMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
-    def forward(self, hidden_states: torch.Tensor,
+    def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+        assert self.use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
             self.renormalize
@@ -727,6 +975,14 @@ class FlashInferFusedMoE(FusedMoE):
         assert (
             self.num_fused_shared_experts == 0
         ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+        _, router_logits = topk_output
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
@@ -736,7 +992,135 @@ class FlashInferFusedMoE(FusedMoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
 
-        if self.reduce_results and (self.
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
+
+
+class FlashInferFP4MoE(FusedMoE):
+    """FP4 TRTLLM MoE implementation using FlashInfer."""
+
+    def __init__(self, *args, **kwargs):
+        # Extract DeepSeek-specific parameters
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+
+        # Extract additional TopK parameters that were previously extracted in forward
+        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
+
+        super().__init__(*args, **kwargs)
+
+        # Store DeepSeek parameters
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    # ---------------------------------------------------------------------
+    # Helper: quantize hidden states to FP4 each forward pass
+    # ---------------------------------------------------------------------
+    def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
+        """
+        Quantize hidden states using global scale factor from quantization method.
+
+        Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
+        Only block scales are computed at runtime for efficiency.
+
+        Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32)
+        """
+
+        # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
+        # Only the block scales are computed at runtime
+        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
+            hidden_states,
+            self.w13_input_scale_quant,
+            16,  # sf_vec_size
+            False,  # use_ue8m0
+            False,  # is_sf_swizzled_layout
+        )
+
+        hs_fp4 = hs_fp4_bytes.reshape(
+            hidden_states.shape[0], hidden_states.shape[1] // 2
+        )
+        hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
+
+        return hs_fp4, hs_sf
+
+    def forward(self, hidden_states: torch.Tensor, topk_output):
+        """Forward pass using FP4 TRTLLM kernel.
+
+        Args:
+            hidden_states: Input tensor
+            topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+        """
+
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+
+        _, router_logits = topk_output
+
+        hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
+
+        router_logits = router_logits.to(torch.float32)
+
+        result = trtllm_fp4_block_scale_moe(
+            routing_logits=router_logits,
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=hs_fp4,
+            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
+            gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
+            gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            gemm1_bias=None,
+            gemm1_alpha=None,
+            gemm1_beta=None,
+            gemm1_clamp_limit=None,
+            gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
+            gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            gemm2_bias=None,
+            output1_scale_scalar=self.g1_scale_c.data,
+            output1_scale_gate_scalar=self.g1_alphas.data,
+            output2_scale_scalar=self.g2_alphas.data,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.intermediate_size_per_partition,
+            local_expert_offset=self.moe_ep_rank * self.num_local_experts,
+            local_num_experts=self.num_local_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_local_experts
+            ),
+            routing_method_type=RoutingMethodType.DeepSeekV3,
+            do_finalize=True,
+        )[0]
+
+        return result
+
+
+def get_fused_moe_impl_class():
+    """Factory function to get the appropriate FusedMoE implementation class."""
+    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
+        # Use FP4 variant when FP4 quantization is enabled
+        return FlashInferFP4MoE
+    elif should_use_flashinfer_trtllm_moe():
+        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
+        return FlashInferFusedMoE
+    else:
+        # Default case
+        return FusedMoE