sglang-0.4.6.post4-py3-none-any.whl → sglang-0.4.6.post5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/ep_moe/kernels.py

@@ -3,10 +3,9 @@ from typing import List, Optional
 
 import torch
 import triton
-import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -653,12 +652,15 @@ def grouped_gemm_triton(
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
+    c_dtype=None,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
 
     if block_shape is not None:
+        a_original = a
+
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         a, scale_a = per_token_group_quant_fp8(a, block_k)
@@ -667,6 +669,8 @@ def grouped_gemm_triton(
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
         assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
 
+        dispose_tensor(a_original)
+
     # TODO: adjust config or tune kernel
     # Reduce block size to prevent L40 shared memory overflow.
     config = {
@@ -680,6 +684,10 @@ def grouped_gemm_triton(
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )
 
+    if c is None:
+        assert c_dtype is not None
+        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
+
    grid = lambda META: (
        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
@@ -783,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
     offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
     mask_s = offset_in_s < SCALE_HIDDEN_SIZE
 
-    for token_id in range(start_token_id, total_token_num, grid_num):
+    for token_id_int32 in range(start_token_id, total_token_num, grid_num):
+        token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
             recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
         )
 
-        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
+            topk_index = topk_idx_int32.to(tl.int64)
             expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
             if expert_id >= 0:
-                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
+                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
+                dest_token_index = dest_token_index_int32.to(tl.int64)
+
                 tl.store(
                     output_index + token_id * output_index_stride0 + topk_index,
-                    dest_token_index,
+                    dest_token_index_int32,
                 )
                 output_tensor_ptr = (
                     output_tensor + dest_token_index * output_tensor_stride0
@@ -894,21 +906,31 @@ def _fwd_kernel_ep_gather(
     topk_num: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-    cur_block = tl.program_id(0)
-    start_cur_token = tl.program_id(1)
+    cur_block_int32 = tl.program_id(0)
+    cur_block = cur_block_int32.to(tl.int64)
+
+    start_cur_token_int32 = tl.program_id(1)
+
     grid_num = tl.num_programs(1)
 
-    for cur_token in range(start_cur_token, total_token_num, grid_num):
+    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
+        cur_token = cur_token_int32.to(tl.int64)
+
         off_d = tl.arange(0, BLOCK_D)
         accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
-        for topk_index in range(0, topk_num):
+
+        for topk_index_int32 in range(0, topk_num):
+            topk_index = topk_index_int32.to(tl.int64)
+
             expert_id = tl.load(
                 recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
             )
             if expert_id >= 0:
-                source_token_index = tl.load(
+                source_token_index_int32 = tl.load(
                     input_index + cur_token * input_index_stride0 + topk_index
                 )
+                source_token_index = source_token_index_int32.to(tl.int64)
+
                 acc_weight = tl.load(
                     recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                 )
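
The kernels.py hunks above implement two memory-saving patterns: `grouped_gemm_triton` gains an optional `c_dtype` argument so the output buffer is allocated inside the call only when the caller passes `c=None`, and `dispose_tensor` releases the unquantized activation (`a_original`) as soon as its fp8 copy exists. A minimal sketch of both ideas, with an assumed stand-in for `dispose_tensor` and a plain `torch.matmul` in place of the Triton grouped GEMM:

```python
import torch


def dispose_tensor(x: torch.Tensor) -> None:
    # Stand-in (assumption, not the sglang implementation): re-point the tensor
    # at an empty storage so its memory can be reclaimed even while Python
    # references to the tensor object still exist.
    x.set_(torch.empty(0, dtype=x.dtype, device=x.device))


def gemm_with_lazy_output(a, b, c=None, c_dtype=None, quantize=None):
    # Optional on-the-fly quantization, mirroring the block_shape branch above:
    # keep a handle to the original activation and free it once the quantized
    # copy has been produced.
    if quantize is not None:
        a_original = a
        a = quantize(a)
        dispose_tensor(a_original)

    # Lazy output allocation: the caller passes c=None plus the desired dtype,
    # and the buffer is materialized only here.
    if c is None:
        assert c_dtype is not None
        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)

    torch.matmul(a, b, out=c)  # placeholder for the Triton grouped GEMM
    return c


out = gemm_with_lazy_output(torch.randn(4, 8), torch.randn(8, 3), c_dtype=torch.float32)
```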

sglang/srt/layers/moe/ep_moe/layer.py

@@ -5,6 +5,9 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 try:
     from deep_gemm import (
@@ -40,7 +43,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -49,7 +52,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
 
 _is_hip = is_hip()
 
@@ -92,6 +95,7 @@ class GroupedGemmRunner(torch.nn.Module):
         scale_a: torch.Tensor = None,
         scale_b: torch.Tensor = None,
         block_shape: Optional[List[int]] = None,
+        c_dtype=None,
     ):
         if self.use_flashinfer:
             # TODO: flashinfer
@@ -119,6 +123,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_a,
                 scale_b,
                 block_shape=block_shape,
+                c_dtype=c_dtype,
             )
         return c
 
@@ -136,6 +141,7 @@ class EPMoE(torch.nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -159,6 +165,7 @@ class EPMoE(torch.nn.Module):
         )
         self.tp_rank = get_tensor_model_parallel_rank()
 
+        self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
         self.num_experts_per_partition = self.num_experts // self.tp_size
@@ -210,6 +217,10 @@ class EPMoE(torch.nn.Module):
         self.grouped_gemm_runner = None
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
 
         if self.grouped_gemm_runner is None:
@@ -229,6 +240,9 @@ class EPMoE(torch.nn.Module):
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
+            expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                layer_id=self.layer_id,
+            ),
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -265,25 +279,21 @@ class EPMoE(torch.nn.Module):
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
+        dispose_tensor(hidden_states)
 
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
             self.num_experts_per_partition,
-            device=hidden_states.device,
+            device=hidden_states_device,
             dtype=torch.int64,
         )
         # GroupGemm-0
-        gateup_output = torch.empty(
-            gateup_input.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         gateup_output = self.grouped_gemm_runner(
             a=gateup_input,
             b=self.w13_weight,
-            c=gateup_output,
+            c=None,
+            c_dtype=hidden_states_dtype,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
@@ -297,6 +307,7 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del gateup_input
 
         # Act
         down_input = torch.empty(
@@ -306,14 +317,14 @@ class EPMoE(torch.nn.Module):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -340,13 +351,14 @@ class EPMoE(torch.nn.Module):
             )
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output
 
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         down_output = self.grouped_gemm_runner(
             a=down_input,
@@ -365,10 +377,13 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del down_input
 
         # PostReorder
-        output = torch.empty_like(hidden_states)
-        post_reorder_triton_kernel[(hidden_states.shape[0],)](
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
             down_output,
             output,
             src2dst,
@@ -377,7 +392,7 @@ class EPMoE(torch.nn.Module):
             self.start_expert_id,
             self.end_expert_id,
             self.top_k,
-            hidden_states.shape[1],
+            hidden_states_shape[1],
             BLOCK_SIZE=512,
         )
         return output
@@ -417,6 +432,28 @@ class EPMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
+    ) -> None:
+        physical_expert_ids = (
+            get_global_expert_location_metadata().logical_to_all_physical(
+                self.layer_id, expert_id
+            )
+        )
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
     ) -> None:
         if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
             return
@@ -460,7 +497,8 @@ class EPMoE(torch.nn.Module):
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
             if (
-                param_data[expert_id] != 1
+                (shard_id == "w1" or shard_id == "w3")
+                and param_data[expert_id] != 1
                 and (param_data[expert_id] - loaded_weight).abs() > 1e-5
             ):
                 raise ValueError(
@@ -534,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # scale
+        layer.register_parameter("w13_input_scale", None)
+        layer.register_parameter("w13_weight_scale", None)
+
         ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-        w13_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_input_scale", w13_input_scale)
-        set_weight_attrs(w13_input_scale, extra_weight_attrs)
 
         w2_input_scale = torch.nn.Parameter(
             ones_tensor,
@@ -549,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_input_scale", w2_input_scale)
         set_weight_attrs(w2_input_scale, extra_weight_attrs)
 
-        w13_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-
         w2_weight_scale = torch.nn.Parameter(
             ones_tensor,
             requires_grad=False,
@@ -802,6 +830,7 @@ class DeepEPMoE(EPMoE):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -821,6 +850,7 @@ class DeepEPMoE(EPMoE):
             top_k,
             hidden_size,
             intermediate_size,
+            layer_id,
             params_dtype,
             renormalize,
             use_grouped_topk,
@@ -881,6 +911,9 @@ class DeepEPMoE(EPMoE):
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
     ):
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         if self.grouped_gemm_runner is None:
@@ -903,18 +936,12 @@ class DeepEPMoE(EPMoE):
         )
 
         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
-
         if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
                 b=self.w13_weight,
-                c=gateup_output,
+                c=None,
+                c_dtype=hidden_states.dtype,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
                 seg_indptr=seg_indptr,
@@ -928,6 +955,13 @@ class DeepEPMoE(EPMoE):
                 ),
                 block_shape=self.block_shape,
             )
+        else:
+            gateup_output = torch.empty(
+                hidden_states.shape[0],
+                self.w13_weight.shape[1],
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
 
         # Act
         down_input = torch.empty(
@@ -937,14 +971,14 @@ class DeepEPMoE(EPMoE):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -961,12 +995,14 @@ class DeepEPMoE(EPMoE):
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
 
+        del gateup_output
+
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         if down_input.shape[0] > 0:
             down_output = self.grouped_gemm_runner(
@@ -1007,11 +1043,9 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
-        gather_out = torch.empty(
-            hidden_states_fp8.shape,
-            device=hidden_states_fp8.device,
-            dtype=torch.bfloat16,
-        )
+        hidden_states_fp8_shape = hidden_states_fp8.shape
+        hidden_states_fp8_device = hidden_states_fp8.device
+        hidden_states_fp8_dtype = hidden_states_fp8.dtype
 
         input_tensor = [
             torch.empty(
@@ -1049,16 +1083,18 @@ class DeepEPMoE(EPMoE):
             m_indices,
             output_index,
         )
+        dispose_tensor(hidden_states_fp8)
 
         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
+        del input_tensor
         down_input = torch.empty(
             (
                 all_tokens,
@@ -1068,14 +1104,16 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
             down_input, scale_block_size
         )
+        del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (down_input_fp8, down_input_scale),
@@ -1083,7 +1121,13 @@ class DeepEPMoE(EPMoE):
             down_output,
             m_indices,
         )
+        del down_input_fp8, down_input_scale
 
+        gather_out = torch.empty(
+            hidden_states_fp8_shape,
+            device=hidden_states_fp8_device,
+            dtype=torch.bfloat16,
+        )
         ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
 
         return gather_out
@@ -1107,6 +1151,7 @@ class DeepEPMoE(EPMoE):
         m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
+        dispose_tensor(hidden_states_fp8[0])
 
         # Act
         down_input = torch.empty(
@@ -1135,6 +1180,7 @@ class DeepEPMoE(EPMoE):
             scale_block_size,
             masked_m,
         )
+        del gateup_output
 
         # GroupGemm-1
         n = self.w2_weight.size(1)
@@ -1150,3 +1196,11 @@ class DeepEPMoE(EPMoE):
         )
 
         return down_output
+
+
+def get_moe_impl_class():
+    if global_server_args_dict["enable_deepep_moe"]:
+        return DeepEPMoE
+    if global_server_args_dict["enable_ep_moe"]:
+        return EPMoE
+    return FusedMoE
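
The final hunk adds a module-level `get_moe_impl_class` helper that centralizes the choice between `DeepEPMoE`, `EPMoE`, and the Triton `FusedMoE` based on `global_server_args_dict`, so model code can ask for the implementation instead of branching on server arguments itself. A self-contained sketch of that dispatch pattern, using stand-in classes and a stand-in args dict rather than the real sglang objects:

```python
# Stand-ins for the real sglang classes and global_server_args_dict; only the
# selection logic below mirrors the helper added at the end of layer.py.
class FusedMoE: ...


class EPMoE: ...


class DeepEPMoE(EPMoE): ...


server_args = {"enable_deepep_moe": False, "enable_ep_moe": True}


def get_moe_impl_class():
    # Most specific backend first: DeepEP-based MoE, then expert-parallel MoE,
    # falling back to the fused Triton MoE.
    if server_args["enable_deepep_moe"]:
        return DeepEPMoE
    if server_args["enable_ep_moe"]:
        return EPMoE
    return FusedMoE


assert get_moe_impl_class() is EPMoE
```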