sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -1,13 +1,14 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
-import
+import datetime
+import glob
 import logging
+import os
+import sys
 from enum import Enum
-from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
-from packaging import version as pkg_version
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -29,22 +31,59 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_flashinfer_available,
+    is_hip,
+    next_power_of_2,
+    round_up,
+)
+
+if is_flashinfer_available():
+    from flashinfer import (
+        RoutingMethodType,
+        fp4_quantize,
+        reorder_rows_for_gated_act_gemm,
+        shuffle_matrix_a,
+        shuffle_matrix_sf_a,
+    )
 
 _is_hip = is_hip()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
+
+# Try to import FP4 TRTLLM function if flashinfer is available
+trtllm_fp4_block_scale_moe = None
+if should_use_flashinfer_trtllm_moe():
+    try:
+        from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
+    except ImportError:
+        trtllm_fp4_block_scale_moe = None
+
 logger = logging.getLogger(__name__)
 
 
-
-
-
-
-
-
-
+def _is_fp4_quantization_enabled():
+    """Check if ModelOpt FP4 quantization is enabled."""
+    try:
+        # Use the same simple check that works for class selection
+        quantization = global_server_args_dict.get("quantization")
+        return quantization == "modelopt_fp4"
+    except:
+        return False
+
+
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
 
 
 class FusedMoeWeightScaleSupported(Enum):
@@ -96,6 +135,10 @@ class FusedMoE(torch.nn.Module):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        use_weight_loader_fused: bool = False,
+        with_bias=False,
     ):
         super().__init__()
 
@@ -110,6 +153,10 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = None
         self.expert_map_gpu = None
 
+        # For activation
+        self.activation_alpha = activation_alpha
+        self.swiglu_limit = swiglu_limit
+
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
             enable_flashinfer_cutlass_moe = False
@@ -124,15 +171,18 @@ class FusedMoE(torch.nn.Module):
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
-            self.expert_map_cpu = torch.full(
+            self.expert_map_cpu = torch.full(
+                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+            )
+            self.expert_map_cpu = torch.full(
+                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
+            )
             # Create a expert map for the local experts
             self.expert_map_cpu[
                 self.moe_ep_rank
                 * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-            if not self.enable_flashinfer_cutlass_moe:
-                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
 
         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.moe_tp_size == 0
@@ -154,13 +204,19 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
-            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_cutlass_moe = (
-                    self.enable_flashinfer_cutlass_moe
-                )
        assert self.quant_method is not None
 
         self.quant_config = quant_config
+        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+            "enable_flashinfer_mxfp4_moe", False
+        )
+        # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
+        if (
+            self.quant_config is not None
+            and self.quant_config.get_name() == "mxfp4"
+            and self.use_enable_flashinfer_mxfp4_moe
+        ):
+            hidden_size = round_up(hidden_size, 256)
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
@@ -169,7 +225,12 @@ class FusedMoE(torch.nn.Module):
             intermediate_size=self.intermediate_size_per_partition,
             intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
-            weight_loader=
+            weight_loader=(
+                self.weight_loader
+                if not use_weight_loader_fused
+                else self.weight_loader_fused
+            ),
+            with_bias=with_bias,
         )
 
     def _load_per_tensor_weight_scale(
@@ -197,6 +258,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
         # Load grouped weight scales for group quantization
         # or model weights
@@ -207,14 +269,16 @@ class FusedMoE(torch.nn.Module):
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
                 tp_rank=tp_rank,
+                is_bias=is_bias,
             )
-        elif shard_id in ("w1", "w3"):
+        elif shard_id in ("w1", "w3", "w13"):
             self._load_w13(
                 shard_id=shard_id,
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
                 tp_rank=tp_rank,
+                is_bias=is_bias,
             )
 
     def _load_per_channel_weight_scale(
@@ -244,17 +308,30 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
 
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-
+        assert shard_id in {"w1", "w3", "w13"}
+
+        if is_bias:
+            # if this weight is a bias, the last dimension must be the sharded dimension
+            shard_dim = -1
+
+        if shard_id in {"w1", "w3"}:
+            # non-fused version
+            shard_size = expert_data.shape[shard_dim] // 2
+        elif shard_id in {"w13"}:
+            # fused version
+            shard_size = expert_data.shape[shard_dim]
+        else:
+            raise NotImplementedError
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
         # trtllm cutlass kernel assumes differently
-        assert shard_id in ("w1", "w3")
         switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
         if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
             start = shard_size
@@ -273,7 +350,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
-                if self.use_triton_kernels:
+                if not is_bias and self.use_triton_kernels:
+                    # do not transpose for bias
                     loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
@@ -289,6 +367,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
         """Load w2 weights for down projection.
 
@@ -319,7 +398,14 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
-
+        if is_bias:
+            # this expert_data is a bias, not weight,
+            # for w2_weight_bias in TP, it does not need to be sharded
+            shard_size = expert_data.shape[-1]
+        else:
+            # this parameter is a weight matrix
+            # for w2 in TP, it shards the input_features, i.e., shard_dim=2
+            shard_size = expert_data.shape[shard_dim]
 
         if _is_cpu:
             expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
@@ -332,13 +418,9 @@ class FusedMoE(torch.nn.Module):
                 not self.use_presharded_weights,
             )
         else:
-            if not self.use_presharded_weights:
+            if not is_bias and not self.use_presharded_weights:
                 if self.use_triton_kernels:
                     loaded_weight = loaded_weight.transpose(-2, -1)
-                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
-                    raise ValueError(
-                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
-                    )
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size
                 )
@@ -386,9 +468,25 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.Tensor,
         weight_name: str,
         shard_id: str,
-        expert_id: int,
+        expert_id: Optional[int],
     ) -> None:
 
+        # if expert_id is None, then
+        # all the experts are loaded at the same time
+        if (
+            not expert_id
+            and self.quant_config is not None
+            and self.quant_config.get_name() == "mxfp4"
+        ):
+            if "bias" in weight_name:
+                dim1 = loaded_weight.shape[1]
+                param.data[:, :dim1].copy_(loaded_weight)
+            else:
+                dim1 = loaded_weight.shape[1]
+                dim2 = loaded_weight.shape[2]
+                param.data[:, :dim1, :dim2].copy_(loaded_weight)
+            return
+
         global_expert_location_metadata = get_global_expert_location_metadata()
         if global_expert_location_metadata is None:
             self._weight_loader_impl(
@@ -427,6 +525,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+
        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return
@@ -621,16 +720,104 @@ class FusedMoE(torch.nn.Module):
             )
             return
 
+    def weight_loader_fused(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+    ) -> None:
+        tp_rank = self.moe_tp_rank
+
+        if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
+            if "bias" in weight_name:
+                dim1 = loaded_weight.shape[1]
+                param.data[:, :dim1].copy_(loaded_weight)
+            elif "scale" in weight_name:
+                param.data.copy_(loaded_weight)
+            else:
+                dim1 = loaded_weight.shape[1]
+                dim2 = loaded_weight.shape[2]
+                param.data[:, :dim1, :dim2].copy_(loaded_weight)
+            return
+
+        # compressed-tensors checkpoints with packed weights are stored flipped
+        # TODO: check self.quant_method.quant_config.quant_format
+        # against known CompressionFormat enum values that have this quality
+        loaded_weight = (
+            loaded_weight.t().contiguous()
+            if (
+                self.quant_method.__class__.__name__
+                == "CompressedTensorsWNA16MoEMethod"
+            )
+            else loaded_weight
+        )
+
+        if shard_id not in ("w13", "w2"):
+            raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
+
+        # Fetch the dim to shard the parameter/loaded weight
+        # based on the shard id. This will be whatever
+        # dimension intermediate_size is used.
+        SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
+        SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
+
+        expert_data = param.data
+        is_bias = expert_data.dim() == 2
+
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size is
+        is_transposed = getattr(param, "is_transposed", False)
+
+        if self.use_triton_kernels:
+            is_transposed = True
+        shard_dim = (
+            SHARD_ID_TO_SHARDED_DIM[shard_id]
+            if not is_transposed
+            else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
+        )
+
+        # Case model weights
+        if "weight" in weight_name:
+            self._load_model_weight_or_group_weight_scale(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+                is_bias=is_bias,
+            )
+            return
+        else:
+            logging.warning(
+                f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
+            )
+
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
+        origin_hidden_states_dim = hidden_states.shape[-1]
         assert self.quant_method is not None
 
-        if self.
+        if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
+            if self.expert_map_cpu is not None and self.expert_map_gpu is None:
+                # If we are in EP mode, we need to move the expert map to GPU.
+                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+
+        if self.expert_map_gpu is not None and isinstance(
+            topk_output, StandardTopKOutput
+        ):
             topk_output = topk_output._replace(
                 topk_ids=self.expert_map_gpu[topk_output.topk_ids]
             )
 
         # Matrix multiply.
         with use_symmetric_memory(get_tp_group()) as sm:
+            kwargs = {}
+            if self.activation_alpha is not None:
+                kwargs["activation_alpha"] = self.activation_alpha
+            if self.swiglu_limit is not None:
+                kwargs["swiglu_limit"] = self.swiglu_limit
+
             final_hidden_states = self.quant_method.apply(
                 layer=self,
                 x=hidden_states,
@@ -649,9 +836,14 @@ class FusedMoE(torch.nn.Module):
                     == "ModelOptNvFp4FusedMoEMethod"
                     else {}
                 ),
+                **kwargs,
             )
             sm.tag(final_hidden_states)
 
+        final_hidden_states = final_hidden_states[
+            ..., :origin_hidden_states_dim
+        ].contiguous()
+
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
@@ -686,6 +878,52 @@ class FusedMoE(torch.nn.Module):
             ]
         ]
 
+    @classmethod
+    def make_expert_params_mapping_fused(
+        cls,
+        ckpt_gate_up_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_gate_up_proj_bias_name: str,
+        ckpt_down_proj_bias_name: str,
+    ):
+        return [
+            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+            (
+                "experts.w13_weight_bias",
+                f"experts.{ckpt_gate_up_proj_bias_name}",
+                "w13",
+            ),
+            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+        ]
+
+    @classmethod
+    def make_expert_params_mapping_fused_mxfp4(
+        cls,
+        ckpt_gate_up_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_gate_up_proj_bias_name: str,
+        ckpt_down_proj_bias_name: str,
+        ckpt_gate_up_proj_scale_name: str,
+        ckpt_down_proj_scale_name: str,
+    ):
+        return [
+            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+            (
+                "experts.w13_weight_bias",
+                f"experts.{ckpt_gate_up_proj_bias_name}",
+                "w13",
+            ),
+            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+            (
+                "experts.w13_weight_scale",
+                f"experts.{ckpt_gate_up_proj_scale_name}",
+                "w13",
+            ),
+            ("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"),
+        ]
+
     @classmethod
     def make_expert_input_scale_params_mapping(
         cls,
@@ -721,8 +959,13 @@ class FlashInferFusedMoE(FusedMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
-    def forward(self, hidden_states: torch.Tensor,
+    def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
+        assert self.use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
         assert self.quant_method is not None
         assert (
             self.renormalize
@@ -730,6 +973,14 @@ class FlashInferFusedMoE(FusedMoE):
         assert (
             self.num_fused_shared_experts == 0
         ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+        _, router_logits = topk_output
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
@@ -739,7 +990,135 @@ class FlashInferFusedMoE(FusedMoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
 
-        if self.reduce_results and (self.
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
+
+
+class FlashInferFP4MoE(FusedMoE):
+    """FP4 TRTLLM MoE implementation using FlashInfer."""
+
+    def __init__(self, *args, **kwargs):
+        # Extract DeepSeek-specific parameters
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+
+        # Extract additional TopK parameters that were previously extracted in forward
+        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
+
+        super().__init__(*args, **kwargs)
+
+        # Store DeepSeek parameters
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    # ---------------------------------------------------------------------
+    # Helper: quantize hidden states to FP4 each forward pass
+    # ---------------------------------------------------------------------
+    def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
+        """
+        Quantize hidden states using global scale factor from quantization method.
+
+        Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
+        Only block scales are computed at runtime for efficiency.
+
+        Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32)
+        """
+
+        # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
+        # Only the block scales are computed at runtime
+        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
+            hidden_states,
+            self.w13_input_scale_quant,
+            16,  # sf_vec_size
+            False,  # use_ue8m0
+            False,  # is_sf_swizzled_layout
+        )
+
+        hs_fp4 = hs_fp4_bytes.reshape(
+            hidden_states.shape[0], hidden_states.shape[1] // 2
+        )
+        hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
+
+        return hs_fp4, hs_sf
+
+    def forward(self, hidden_states: torch.Tensor, topk_output):
+        """Forward pass using FP4 TRTLLM kernel.
+
+        Args:
+            hidden_states: Input tensor
+            topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+        """
+
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+
+        _, router_logits = topk_output
+
+        hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
+
+        router_logits = router_logits.to(torch.float32)
+
+        result = trtllm_fp4_block_scale_moe(
+            routing_logits=router_logits,
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=hs_fp4,
+            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
+            gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
+            gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            gemm1_bias=None,
+            gemm1_alpha=None,
+            gemm1_beta=None,
+            gemm1_clamp_limit=None,
+            gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
+            gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            gemm2_bias=None,
+            output1_scale_scalar=self.g1_scale_c.data,
+            output1_scale_gate_scalar=self.g1_alphas.data,
+            output2_scale_scalar=self.g2_alphas.data,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.intermediate_size_per_partition,
+            local_expert_offset=self.moe_ep_rank * self.num_local_experts,
+            local_num_experts=self.num_local_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_local_experts
+            ),
+            routing_method_type=RoutingMethodType.DeepSeekV3,
+            do_finalize=True,
+        )[0]
+
+        return result
+
+
+def get_fused_moe_impl_class():
+    """Factory function to get the appropriate FusedMoE implementation class."""
+    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
+        # Use FP4 variant when FP4 quantization is enabled
+        return FlashInferFP4MoE
+    elif should_use_flashinfer_trtllm_moe():
+        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
+        return FlashInferFusedMoE
+    else:
+        # Default case
+        return FusedMoE