sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/elementwise.py
CHANGED
@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
         return out_hidden_states, out_scales
     else:
         return out_hidden_states, None
+
+
+# silu on first half of vector
+@triton.jit
+def silu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # silu
+    # cast down before mul to better match training?
+    silu_x1 = x1 * tl.sigmoid(x1)
+    out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def silu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    max_warps = 16 if _is_hip else 32
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
+        ),
+    }
+
+    silu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
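A minimal usage sketch of the new helper (tensor sizes are illustrative; assumes a CUDA device with Triton available): `silu_and_mul_triton` takes a packed `(bs, 2 * hidden_dim)` tensor, applies SiLU to the first half, multiplies by the second half, and returns `(out, None)` when no quantization dtype is requested.

```python
import torch

from sglang.srt.layers.elementwise import silu_and_mul_triton

# (bs, 2 * hidden_dim) input: first half is gated through SiLU,
# second half is the multiplicand, mirroring gelu_and_mul_triton.
hidden_states = torch.randn(4, 2 * 128, dtype=torch.bfloat16, device="cuda")

out, scales = silu_and_mul_triton(hidden_states)  # quantize=None
assert out.shape == (4, 128)
assert scales is None  # scales are only produced when quantize is set
```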
sglang/srt/layers/flashinfer_comm_fusion.py
CHANGED
@@ -5,7 +5,11 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    is_flashinfer_available,
+    supports_custom_op,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -196,6 +200,30 @@ def flashinfer_allreduce_residual_rmsnorm(
     return norm_out, residual_out
 
 
+def fake_flashinfer_allreduce_residual_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
+    trigger_completion_at_end: bool = False,
+    fp32_acc: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    residual_out = torch.empty_like(residual)
+    norm_out = torch.empty_like(input_tensor)
+    return norm_out, residual_out
+
+
+if supports_custom_op():
+    direct_register_custom_op(
+        "flashinfer_allreduce_residual_rmsnorm",
+        flashinfer_allreduce_residual_rmsnorm,
+        mutates_args=["input_tensor", "residual", "weight"],
+        fake_impl=fake_flashinfer_allreduce_residual_rmsnorm,
+    )
+
+
 def cleanup_flashinfer_workspace():
     global _workspace_manager
     if _workspace_manager is not None:
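The `fake_impl` registered above is presumably there so tracing frontends (e.g. `torch.compile`) can infer output shapes without launching the real fused allreduce. A generic sketch of the same pattern using plain `torch.library` (hypothetical op name `demo::allreduce_rmsnorm` and placeholder math; this is not sglang's `direct_register_custom_op` helper):

```python
import torch


@torch.library.custom_op("demo::allreduce_rmsnorm", mutates_args=())
def allreduce_rmsnorm(
    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    # A real implementation would launch the fused communication kernel here.
    return (x + residual) * weight  # placeholder math


@allreduce_rmsnorm.register_fake
def _(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Fake/meta implementation: only report output metadata, never touch data.
    return torch.empty_like(x)
```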
sglang/srt/layers/layernorm.py
CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    supports_custom_op,
 )
 
 _is_cuda = is_cuda()
@@ -202,8 +203,14 @@ class RMSNorm(CustomOp):
                 flashinfer_allreduce_residual_rmsnorm,
             )
 
+            fused_op = (
+                torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
+                if supports_custom_op()
+                else flashinfer_allreduce_residual_rmsnorm
+            )
+
             if get_tensor_model_parallel_world_size() > 1:
-                fused_result =
+                fused_result = fused_op(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
sglang/srt/layers/linear.py
CHANGED
@@ -110,6 +110,20 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+def adjust_shard_offsets(shard_offsets, loaded_weight, dim):
+    actual_weight_size = loaded_weight.size(dim)
+    target_weight_size = shard_offsets[-1][-1] + shard_offsets[-1][-2]
+    if actual_weight_size != target_weight_size:
+        new_shard_offsets = []
+        new_offset = 0
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            actual_shard_size = actual_weight_size * shard_size // target_weight_size
+            new_shard_offsets.append((shard_id, new_offset, actual_shard_size))
+            new_offset += actual_shard_size
+        return new_shard_offsets
+    return shard_offsets
+
+
 class LinearBase(torch.nn.Module):
     """Base linear layer.
 
@@ -535,6 +549,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            packed_dim = getattr(param, "packed_dim", None)
 
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+           if _is_cpu:
+               shard_offsets = adjust_shard_offsets(
+                   shard_offsets, loaded_weight, output_dim
+               )
+
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
@@ -977,6 +996,11 @@ class QKVParallelLinear(ColumnParallelLinear):
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
            packed_dim = getattr(param, "packed_dim", None)
+           if _is_cpu:
+               shard_offsets = adjust_shard_offsets(
+                   shard_offsets, loaded_weight, output_dim
+               )
+
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
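A worked example of the new `adjust_shard_offsets` helper (values are illustrative): when the loaded tensor is narrower along `dim` than the fused layout expects, every shard's offset and size are rescaled proportionally instead of indexing out of range.

```python
import torch

from sglang.srt.layers.linear import adjust_shard_offsets

# Fused q/k/v layout that expects 16 rows in total ...
shard_offsets = [("q", 0, 8), ("k", 8, 4), ("v", 12, 4)]
# ... but the checkpoint tensor only has 8 rows along dim 0.
loaded_weight = torch.empty(8, 32)

print(adjust_shard_offsets(shard_offsets, loaded_weight, dim=0))
# [('q', 0, 4), ('k', 4, 2), ('v', 6, 2)]
```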
sglang/srt/layers/logits_processor.py
CHANGED
@@ -27,7 +27,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-
+    DpPaddingMode,
     attn_tp_all_gather,
     attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
@@ -35,7 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_global_dp_buffer,
     get_local_attention_dp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -108,14 +110,12 @@ class LogitsMetadata:
     # The start position of local hidden states.
     dp_local_start_pos: Optional[torch.Tensor] = None
     dp_local_num_tokens: Optional[torch.Tensor] = None
-
-    # Buffer to gather logits from all ranks.
-    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The gather mode for DP attention
-    dp_padding_mode: Optional[
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for padding
     padded_static_len: int = -1
 
@@ -164,11 +164,10 @@ class LogitsMetadata:
            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
            dp_local_start_pos=forward_batch.dp_local_start_pos,
            dp_local_num_tokens=forward_batch.dp_local_num_tokens,
-
-            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
            global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=
+            dp_padding_mode=DpPaddingMode.SUM_LEN,
        )
 
    def compute_dp_attention_metadata(self):
@@ -188,16 +187,15 @@ class LogitsMetadata:
 
        if self.global_num_tokens_for_logprob_cpu is not None:
            # create a smaller buffer to reduce peak memory usage
-            self.
-                (
-                    sum(self.global_num_tokens_for_logprob_cpu),
-                    self.gathered_buffer.shape[1],
-                ),
-                dtype=self.gathered_buffer.dtype,
-                device=self.gathered_buffer.device,
-            )
+            self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
        else:
-            self.
+            self.global_dp_buffer_len = self.global_dp_buffer_len
+
+        set_dp_buffer_len(
+            self.global_dp_buffer_len,
+            self.dp_local_num_tokens,
+            self.global_num_tokens_for_logprob_cpu,
+        )
 
 
 class LogitsProcessor(nn.Module):
@@ -443,7 +441,7 @@ class LogitsProcessor(nn.Module):
        if self.do_tensor_parallel_all_gather_dp_attn:
            logits_metadata.compute_dp_attention_metadata()
            hidden_states, local_hidden_states = (
-
+                get_global_dp_buffer(),
                hidden_states,
            )
            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
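The net effect of these hunks: `LogitsMetadata` no longer carries a pre-allocated gather buffer; it records only the required length and defers buffer ownership to module-level helpers in `dp_attention` (`set_dp_buffer_len` / `get_global_dp_buffer`). A minimal sketch of that kind of module-level buffer registry, with hypothetical names and a simplified signature (not sglang's implementation):

```python
import torch

_global_buffer: torch.Tensor | None = None


def set_buffer_len(num_tokens: int, hidden_dim: int, dtype: torch.dtype = torch.bfloat16) -> None:
    """Allocate (or reallocate) the shared gather buffer for this forward pass."""
    global _global_buffer
    _global_buffer = torch.empty((num_tokens, hidden_dim), dtype=dtype)


def get_global_buffer() -> torch.Tensor:
    """Return the buffer sized by the most recent set_buffer_len call."""
    assert _global_buffer is not None, "set_buffer_len must be called first"
    return _global_buffer
```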
sglang/srt/layers/moe/__init__.py
ADDED
@@ -0,0 +1,31 @@
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.utils import (
+    DeepEPMode,
+    MoeA2ABackend,
+    MoeRunnerBackend,
+    get_deepep_config,
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_moe_runner_backend,
+    get_tbo_token_distribution_threshold,
+    initialize_moe_config,
+    is_tbo_enabled,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
+
+__all__ = [
+    "DeepEPMode",
+    "MoeA2ABackend",
+    "MoeRunnerConfig",
+    "MoeRunnerBackend",
+    "initialize_moe_config",
+    "get_moe_a2a_backend",
+    "get_moe_runner_backend",
+    "get_deepep_mode",
+    "should_use_flashinfer_trtllm_moe",
+    "should_use_flashinfer_cutlass_moe_fp4_allgather",
+    "is_tbo_enabled",
+    "get_tbo_token_distribution_threshold",
+    "get_deepep_config",
+]
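The new package `__init__` lets callers import the MoE configuration helpers from the package root; the `ep_moe/layer.py` hunks below switch to exactly this import path. A small usage sketch (assumes the MoE config has already been initialized via `initialize_moe_config`):

```python
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend

if get_moe_a2a_backend().is_deepep():
    # DeepEP all-to-all is active; get_deepep_mode() further selects the
    # normal / low-latency / auto dispatch path used by DeepEPMoE below.
    print("DeepEP mode:", get_deepep_mode())
```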
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -1,11 +1,17 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_moe_runner_backend,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import
-    Fp8Config,
-    Fp8MoEMethod,
-    get_tile_tokens_dim,
-)
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
@@ -51,7 +52,6 @@ if not (_is_npu or _is_hip):
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.ops.shuffle import shuffle_weight
 
 logger = logging.getLogger(__name__)
 
@@ -89,12 +89,11 @@ class EPMoE(FusedMoE):
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
        prefix: str = "",
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
-
-
+        gemm1_alpha: Optional[float] = None,
+        gemm1_clamp_limit: Optional[float] = None,
        with_bias: bool = False,
    ):
        super().__init__(
@@ -106,13 +105,12 @@ class EPMoE(FusedMoE):
            top_k=top_k,
            params_dtype=params_dtype,
            quant_config=quant_config,
-            tp_size=tp_size,
            prefix=prefix,
            activation=activation,
            # apply_router_weight_on_input=apply_router_weight_on_input,
            routed_scaling_factor=routed_scaling_factor,
-
-
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
            with_bias=with_bias,
        )
 
@@ -163,7 +161,8 @@ class EPMoE(FusedMoE):
        )
 
        assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
+
        hidden_states_shape = hidden_states.shape
        hidden_states_dtype = hidden_states.dtype
        hidden_states_device = hidden_states.device
@@ -327,8 +326,8 @@ class EPMoE(FusedMoE):
            m_max * self.start_expert_id,
            BLOCK_SIZE=512,
        )
-        if self.routed_scaling_factor is not None:
-            output *= self.routed_scaling_factor
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
        return output
 
 
@@ -349,11 +348,9 @@ class DeepEPMoE(EPMoE):
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
        prefix: str = "",
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
-        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
    ):
        super().__init__(
            num_experts=num_experts,
@@ -364,12 +361,11 @@ class DeepEPMoE(EPMoE):
            num_fused_shared_experts=num_fused_shared_experts,
            params_dtype=params_dtype,
            quant_config=quant_config,
-            tp_size=tp_size,
            prefix=prefix,
            activation=activation,
            routed_scaling_factor=routed_scaling_factor,
        )
-        self.deepep_mode =
+        self.deepep_mode = get_deepep_mode()
 
        # TODO: move to the beginning of the file
        from sglang.srt.distributed.parallel_state import get_tp_group
@@ -383,7 +379,7 @@ class DeepEPMoE(EPMoE):
            num_local_experts=self.num_local_experts,
            hidden_size=hidden_size,
            params_dtype=params_dtype,
-            deepep_mode=deepep_mode,
+            deepep_mode=self.deepep_mode,
            async_finish=True,  # TODO
            return_recv_hook=True,
        )
@@ -458,15 +454,19 @@ class DeepEPMoE(EPMoE):
        )
 
    def moe_impl(self, dispatch_output: DispatchOutput):
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
        if _use_aiter:
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
            return self.forward_aiter(dispatch_output)
        if _is_npu:
+            assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
            return self.forward_npu(dispatch_output)
-        if
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_contiguous(dispatch_output)
-        elif
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_masked(dispatch_output)
        else:
@@ -490,7 +490,7 @@ class DeepEPMoE(EPMoE):
 
    def forward_aiter(
        self,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
    ):
        hidden_states, topk_idx, topk_weights = (
            dispatch_output.hidden_states,
@@ -516,7 +516,7 @@ class DeepEPMoE(EPMoE):
            quant_type=QuantType.per_128x128,
            activation=(
                ActivationType.Silu
-                if self.activation == "silu"
+                if self.moe_runner_config.activation == "silu"
                else ActivationType.Gelu
            ),
            expert_mask=self.expert_mask,
@@ -531,7 +531,7 @@ class DeepEPMoE(EPMoE):
        )
        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
        assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
        if num_recv_tokens_per_expert is None:
            return hidden_states_fp8.bfloat16()
        all_tokens = sum(num_recv_tokens_per_expert)
@@ -652,7 +652,7 @@ class DeepEPMoE(EPMoE):
    ):
        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
        assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
 
        # GroupGemm-0
        num_groups, m, k = hidden_states_fp8[0].size()
@@ -735,7 +735,7 @@ class DeepEPMoE(EPMoE):
        assert isinstance(dispatch_output, AscendDeepEPLLOutput)
        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
        assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
 
        # NOTE: Ascend's Dispatch & Combine does not support FP16
        output_dtype = torch.bfloat16
@@ -782,13 +782,17 @@ class DeepEPMoE(EPMoE):
        return hidden_states
 
 
-def get_moe_impl_class():
-    if
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+    if get_moe_a2a_backend().is_deepep():
        return DeepEPMoE
 
    # NEW: Direct FP4 detection (bypasses EP requirements)
    # Check for FP4 quantization with TRTLLM flag, regardless of EP
-    if
+    if get_moe_runner_backend().is_flashinfer_trtllm():
+        # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
+        # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
+        if quant_config is None:
+            return FusedMoE
        try:
            # Check the quantization argument directly
            quantization = global_server_args_dict.get("quantization")
@@ -803,7 +807,7 @@ def get_moe_impl_class():
 
    if should_use_flashinfer_trtllm_moe():
        return FlashInferFusedMoE
-    if
+    if get_moe_runner_backend().is_flashinfer_cutlass():
        return FusedMoE
    if get_moe_expert_parallel_world_size() > 1:
        return EPMoE
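A hypothetical call site (not taken from the diff) showing the signature change: `get_moe_impl_class` now receives the layer's `quant_config`, so the FlashInfer TRT-LLM path can fall back to `FusedMoE` when no quantization config is present. This sketch assumes server args and the MoE config have already been initialized.

```python
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class

# Passing quant_config=None makes the flashinfer_trtllm branch choose FusedMoE
# instead of FlashInferFP4MoE; the other branches are unaffected.
MoEClass = get_moe_impl_class(quant_config=None)
print(MoEClass.__name__)  # FusedMoE, EPMoE, DeepEPMoE, ... depending on server args
```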
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
 It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
 """
 
-from typing import Callable, Optional
-
 import torch
 from torch.nn import functional as F
 
 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
-from sglang.srt.layers.moe.
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 
 def fused_moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    topk_output:
-
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    inplace: bool = True,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
 ) -> torch.Tensor:
 
-    if apply_router_weight_on_input:
+    if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()
 
     topk_weights, topk_ids, _ = topk_output
@@ -33,12 +27,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    if activation == "silu":
+    if moe_runner_config.activation == "silu":
         x1 = F.silu(x1)
-    elif activation == "gelu":
+    elif moe_runner_config.activation == "gelu":
         x1 = F.gelu(x1)
     else:
-        raise ValueError(f"Unsupported activation: {activation=}")
+        raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -47,16 +41,11 @@ def fused_moe_forward_native(
 def moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    topk_output:
-
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    inplace: bool = True,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
 ) -> torch.Tensor:
 
-    if apply_router_weight_on_input:
+    if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()
 
     topk_weights, topk_ids, _ = topk_output
@@ -72,12 +61,12 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()
 
-    if activation == "silu":
+    if moe_runner_config.activation == "silu":
         act = SiluAndMul()
-    elif activation == "gelu":
+    elif moe_runner_config.activation == "gelu":
         act = GeluAndMul()
     else:
-        raise ValueError(f"Unsupported activation: {activation=}")
+        raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
 
     outputs = []
     start_idx = 0
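To make the einsum contractions in `fused_moe_forward_native` concrete, here is a self-contained toy reproduction of the same math with random weights (shapes and sizes are illustrative; this is not sglang code):

```python
import torch
import torch.nn.functional as F

t, e, a, h, inter = 3, 4, 2, 8, 16        # tokens, experts, top-k, hidden, intermediate
x = torch.randn(t, h)
w13 = torch.randn(e, 2 * inter, h)        # fused gate/up projection per expert
w2 = torch.randn(e, h, inter)             # down projection per expert
topk_weights = torch.softmax(torch.randn(t, a), dim=-1)
topk_ids = torch.randint(0, e, (t, a))

w1, w3 = torch.chunk(w13[topk_ids], 2, dim=2)             # each (t, a, inter, h)
x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1))        # gate path + SiLU
x3 = torch.einsum("ti,taoi -> tao", x, w3)                # up path
expert_outs = torch.einsum("tao,taio -> tai", x1 * x3, w2[topk_ids])
out = torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
print(out.shape)  # torch.Size([3, 8])
```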