sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/mxfp4.py

@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -59,8 +61,10 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_hip = is_hip()
 
@@ -283,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         with_bias: bool = False,
         **extra_weight_attrs,
@@ -296,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        intermediate_size_per_partition_after_pad =
+        intermediate_size_per_partition_after_pad = intermediate_size_per_partition
         if _is_sm100_supported:
             if self.use_flashinfer:
                 intermediate_size_per_partition_after_pad = round_up(
-
+                    intermediate_size_per_partition, 256
                 )
                 hidden_size = round_up(hidden_size, 256)
             else:
                 intermediate_size_per_partition_after_pad = round_up(
-
+                    intermediate_size_per_partition, 64
                 )
         elif has_triton_kernels:
             # TODO: this is a hack to make
             # intermediate_size_per_partition_after_pad the same as the
             # per_rank_intermediate_size during weight loading
            intermediate_size_per_partition_after_pad = round_up(
-
+                intermediate_size_per_partition, mxfp4_block
            )
 
-        self.
+        self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad
 
         self.hidden_size = hidden_size
         # Fused gate_up_proj (column parallel)
@@ -410,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         assert (
             layer.w13_weight.dim() == 3
            and layer.w13_weight.shape[0] == self.num_experts
-            and layer.w13_weight.shape[1]
+            and layer.w13_weight.shape[1]
+            == self.intermediate_size_per_partition * 2
            and layer.w13_weight.shape[2] == self.hidden_size // 2
        )
        assert (
            layer.w13_weight_scale.dim() == 3
            and layer.w13_weight_scale.shape[0] == self.num_experts
-            and layer.w13_weight_scale.shape[1]
+            and layer.w13_weight_scale.shape[1]
+            == self.intermediate_size_per_partition * 2
            and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
        )
        assert (
            layer.w2_weight.dim() == 3
            and layer.w2_weight.shape[0] == self.num_experts
            and layer.w2_weight.shape[1] == self.hidden_size
-            and layer.w2_weight.shape[2]
+            and layer.w2_weight.shape[2]
+            == self.intermediate_size_per_partition // 2
        )
        assert (
            layer.w2_weight_scale.dim() == 3
            and layer.w2_weight_scale.shape[1] == self.hidden_size
            and layer.w2_weight_scale.shape[2]
-            == self.
+            == self.intermediate_size_per_partition // sf_block_size
        )
        assert (
            layer.w13_weight_bias.dim() == 2
            and layer.w13_weight_bias.shape[0] == self.num_experts
-            and layer.w13_weight_bias.shape[1]
+            and layer.w13_weight_bias.shape[1]
+            == self.intermediate_size_per_partition * 2
        )
        assert (
            layer.w2_weight_bias.dim() == 2
@@ -511,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             torch.stack(gemm1_scales_mxfp4_shuffled)
             .reshape(
                 self.num_experts,
-                2 * self.
+                2 * self.intermediate_size_per_partition,
                 self.hidden_size // sf_block_size,
             )
             .view(torch.float8_e4m3fn)
@@ -523,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             .reshape(
                 self.num_experts,
                 self.hidden_size,
-                self.
+                self.intermediate_size_per_partition // sf_block_size,
             )
             .view(torch.float8_e4m3fn)
         )
@@ -613,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         return tile_tokens_dim
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
         from sglang.srt.layers.moe.topk import TopKOutputChecker
 
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
+
         if self.use_flashinfer:
             # When bf16 mode is enabled, we don't need to quantize the input,
             # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
@@ -674,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 top_k,
                 None,  # n_group # TODO: support n_group
                 None,  # topk_group # TODO: support topk_group
-                self.
+                self.intermediate_size_per_partition,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
                 None,
@@ -682,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 1,  # routing_method_type, renormalize
                 True,  # do finalize
             )[0]
-            return trtllm_gen_output
+            return StandardCombineInput(hidden_states=trtllm_gen_output)
 
         if self.use_triton_kernels:
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
             if self.with_bias:
-
+                output = self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=self.w13_weight_triton_tensor,
                     w1_pcg=self.w13_precision_config,
@@ -701,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     moe_runner_config=moe_runner_config,
                 )
             else:
-
+                output = self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
                     moe_runner_config=moe_runner_config,
                 )
+            return StandardCombineInput(hidden_states=output)
         else:
-
-
-
-
-
-                w2=layer.w2_weight,
-                topk_output=topk_output,
-                moe_runner_config=moe_runner_config,
-                b1=layer.w13_weight_bias,
-                b2=layer.w2_weight_bias,
+            quant_info = TritonMoeQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_weight_bias=layer.w13_weight_bias,
+                w2_weight_bias=layer.w2_weight_bias,
             )
+            return self.runner.run(dispatch_output, quant_info)
 
 
 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
@@ -798,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
 
         return w, mx_scales
 
-    def process_weights_after_loading(self, layer: Module) -> None:
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
         w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
 
@@ -808,19 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
         layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         topk_weights, topk_ids, _ = topk_output
         if _is_hip:
             topk_weights = topk_weights.to(
                 torch.float32
             )  # aiter's moe_sorting requires topk_weights to be FP32
-
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -831,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
             w2_scale=layer.w2_weight_scale,
             activation=(
                 ActivationType.Silu
-                if moe_runner_config.activation == "silu"
+                if self.moe_runner_config.activation == "silu"
                 else ActivationType.Gelu
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
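
The recurring change in the mxfp4.py hunks above is the new MoE method calling convention: apply() no longer receives hidden states, the top-k routing output, and a MoeRunnerConfig directly, but a single StandardDispatchOutput, and it returns a CombineInput; the runner config is captured earlier by create_moe_runner(). The sketch below only illustrates that convention with simplified stand-in classes; SimpleMoEMethod and the two dataclasses are hypothetical mirrors of the interface shown in the diff, not the real sglang types.

```python
# Minimal sketch of the dispatch -> apply -> combine convention shown above.
# The classes here are simplified stand-ins, not the real sglang types.
from dataclasses import dataclass

import torch


@dataclass
class StandardDispatchOutput:  # stand-in for sglang's dispatch output
    hidden_states: torch.Tensor
    topk_output: tuple  # (topk_weights, topk_ids, extra)


@dataclass
class StandardCombineInput:  # stand-in for sglang's combine input
    hidden_states: torch.Tensor


class SimpleMoEMethod:  # hypothetical method following the new contract
    def create_moe_runner(self, layer, moe_runner_config) -> None:
        # The config is captured once here instead of being passed to apply().
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output
        # Placeholder compute: weight each token by the sum of its routing weights.
        out = x * topk_weights.sum(dim=-1, keepdim=True)
        return StandardCombineInput(hidden_states=out)


method = SimpleMoEMethod()
method.create_moe_runner(layer=None, moe_runner_config={"activation": "silu"})
dispatch = StandardDispatchOutput(
    hidden_states=torch.randn(4, 8),
    topk_output=(torch.rand(4, 2), torch.randint(0, 4, (4, 2)), None),
)
combined = method.apply(layer=None, dispatch_output=dispatch)
print(combined.hidden_states.shape)  # torch.Size([4, 8])
```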
sglang/srt/layers/quantization/quark/quark_moe.py

@@ -10,8 +10,17 @@ from aiter import ActivationType, QuantType, biased_grouped_topk
 from aiter.fused_moe import fused_moe
 from aiter.utility.fp4_utils import e8m0_shuffle
 
+from sglang.srt.layers.moe import MoeRunnerConfig
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
 logger = logging.getLogger(__name__)
 
 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
@@ -19,31 +28,17 @@ __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 OCP_MX_BLOCK_SIZE = 32
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.
-
-
-class QuarkMoEMethod:
-
-
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
+    from sglang.srt.layers.quantization import QuarkConfig
+
+
+class QuarkMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: QuarkConfig):
+        self.quant_config = quant_config
 
     @staticmethod
     def get_moe_method(
-        quant_config:
+        quant_config: QuarkConfig,  # type: ignore # noqa E501 # noqa F821
         module: torch.nn.Module,
         layer_name: str,
     ) -> "QuarkMoEMethod":
@@ -170,16 +165,25 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
         layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output
 
-
+        output = fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -195,3 +199,4 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
             ),
             doweight_stage1=False,
         )
+        return StandardCombineInput(hidden_states=output)
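
The quark_moe.py hunks also drop a class-rebuilding trick: QuarkMoEMethod used to rebase itself onto FusedMoEMethodBase at instantiation time via a __new__ hook and type(), and it now simply subclasses FusedMoEMethodBase with a plain __init__. The toy below reproduces the removed pattern in isolation to show what it did; Base, LegacyMethod, and ModernMethod are made-up names, not sglang code, and the filter here also skips __weakref__ to keep the toy self-contained.

```python
# Standalone toy reproduction of the pattern removed above: a __new__ hook that
# rebuilds the class at instantiation time so the instance inherits from Base.
class Base:
    def process(self) -> str:
        return "processed by " + type(self).__name__


class LegacyMethod:
    # Old style: dynamically create a subclass of Base and instantiate that.
    def __new__(cls, *args, **kwargs):
        new_cls = type(
            cls.__name__,
            (Base,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, quant_config):
        self.quant_config = quant_config


class ModernMethod(Base):
    # New style (as in the diff): inherit directly and keep a plain __init__.
    def __init__(self, quant_config):
        self.quant_config = quant_config


legacy = LegacyMethod({"scheme": "w4a4"})
modern = ModernMethod({"scheme": "w4a4"})
print(isinstance(legacy, Base), isinstance(modern, Base))  # True True
```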
sglang/srt/layers/quantization/unquant.py

@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     LinearMethodBase,
@@ -24,8 +26,10 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
 
@@ -155,7 +159,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         with_bias: bool = False,
         **extra_weight_attrs,
@@ -163,7 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.with_bias = with_bias
 
         # Fused gate_up_proj (column parallel)
-        w13_weight_n, w13_weight_k = 2 *
+        w13_weight_n, w13_weight_k = 2 * intermediate_size_per_partition, hidden_size
         if self.use_triton_kernels:
             w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
@@ -175,7 +179,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         if self.with_bias:
             w13_weight_bias = torch.nn.Parameter(
-                torch.empty(
+                torch.empty(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=torch.float32,
+                ),
                 requires_grad=False,
             )
             layer.register_parameter("w13_weight_bias", w13_weight_bias)
@@ -184,7 +192,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         # down_proj (row parallel)
         w2_weight_n, w2_weight_k = (
             hidden_size,
-
+            intermediate_size_per_partition,
         )
         if self.use_triton_kernels:
             w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
@@ -222,33 +230,40 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         return
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
         return self.forward(
-            x=x,
             layer=layer,
-
-            moe_runner_config=moe_runner_config,
+            dispatch_output=dispatch_output,
         )
 
     def forward_cuda(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
 
         if self.use_triton_kernels:
             if self.with_bias:
                 assert self.triton_kernel_moe_with_bias_forward is not None
-
+                output = self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
@@ -261,13 +276,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 )
             else:
                 assert self.triton_kernel_moe_forward is not None
-
+                output = self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
                     moe_runner_config=moe_runner_config,
                 )
+            return StandardCombineInput(hidden_states=output)
         else:
             if _use_aiter:
                 assert not moe_runner_config.no_combine, "unsupported"
@@ -284,7 +300,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     topk_weights = torch.ones_like(
                         topk_weights, dtype=torch.float32
                     )  # topk_weights must be FP32 (float32)
-
+                output = fused_moe(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -296,28 +312,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                         else ActivationType.Gelu
                     ),
                 )
+                return StandardCombineInput(hidden_states=output)
             else:
-                from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-                    fused_experts,
-                )
 
-
-
-
-
-                    b1=getattr(layer, "w13_weight_bias", None),
+                quant_info = TritonMoeQuantInfo(
+                    w13_weight=layer.w13_weight,
+                    w2_weight=layer.w2_weight,
+                    b13=getattr(layer, "w13_weight_bias", None),
                     b2=getattr(layer, "w2_weight_bias", None),
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
                 )
+                return self.runner.run(dispatch_output, quant_info)
 
     def forward_cpu(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        moe_runner_config = self.moe_runner_config
+
         assert (
             moe_runner_config.activation == "silu"
         ), f"activation = {moe_runner_config.activation} is not supported."
@@ -332,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x, topk_weights = apply_topk_weights_cpu(
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -348,33 +366,103 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 None,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
         else:
             from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 
-
+            output = moe_forward_native(
                 layer,
                 x,
                 topk_output,
                 moe_runner_config,
             )
+            return StandardCombineInput(hidden_states=output)
 
     def forward_npu(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        import torch_npu
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_weights, topk_ids, _ = dispatch_output.topk_output
+
+        original_dtype = x.dtype
+        num_tokens = x.shape[0]
+        topk_weights = topk_weights.to(x.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        num_experts = layer.num_experts
+        top_k = layer.top_k
+        row_idx_len = num_tokens * top_k
+        row_idx = (
+            torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
+            .view(top_k, -1)
+            .permute(1, 0)
+            .contiguous()
+        )
 
-
-
-
-
-
+        hidden_states, expanded_row_idx, expanded_expert_idx = (
+            torch_npu.npu_moe_init_routing(
+                x, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
+            )
+        )
+
+        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+            expanded_expert_idx, num_experts
         )
 
-
+        expert_tokens = expert_tokens.to(torch.int64)
+        if layer.w13_weight.shape[-1] == layer.hidden_size:
+            w13 = layer.w13_weight.transpose(1, 2)
+            w2 = layer.w2_weight.transpose(1, 2)
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[w13],
+            split_item=2,
+            group_list_type=0,
+            group_type=0,
+            group_list=expert_tokens,
+            output_dtype=original_dtype,
+        )[0]
+
+        # act_fn:
+        if self.moe_runner_config.activation == "silu":
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+        else:
+            from sglang.srt.layers.activation import GeluAndMul
+
+            hidden_states = GeluAndMul()(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[w2],
+            split_item=2,
+            group_list_type=0,
+            group_type=0,
+            group_list=expert_tokens,
+            output_dtype=original_dtype,
+        )[0]
+
+        final_hidden_states = torch_npu.npu_moe_finalize_routing(
+            hidden_states,
+            skip1=None,
+            skip2=None,
+            bias=None,
+            scales=topk_weights,
+            expanded_src_to_dst_row=expanded_row_idx,
+            export_for_source_row=topk_ids,
+        )
+
+        return StandardCombineInput(hidden_states=final_hidden_states)
+
+    def forward_tpu(self, *args, **kwargs) -> CombineInput:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
     forward_native = forward_cpu