sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
```diff
--- a/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
+++ b/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
@@ -11,53 +11,41 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_bool_env_var

 logger = logging.getLogger(__name__)

 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
+    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor

-
-        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import (
-            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
-        from deep_gemm import (
-            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-    else:
-        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import get_col_major_tma_aligned_tensor
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
+_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")


+# TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
     lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs: Tuple[torch.Tensor, torch.Tensor],
     out: torch.Tensor,
     masked_m: torch.Tensor,
     expected_m: int,
-    recipe=None,
 ):
     num_groups, _, k = lhs[0].shape
     _, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED

+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
-
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
             lhs,
             rhs,
             out,
             masked_m,
             expected_m,
-            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
         )


@@ -71,8 +59,11 @@ def grouped_gemm_nt_f8f8bf16_contig(
     num_groups, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG

+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)


 def gemm_nt_f8f8bf16(
@@ -85,8 +76,11 @@ def gemm_nt_f8f8bf16(
     num_groups = 1
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16

+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-
+        deep_gemm.fp8_gemm_nt(
             lhs,
             rhs,
             out,
@@ -108,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
         yield
     finally:
         deep_gemm.set_num_sms(original_num_sms)
+
+
+def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
+    if not _SANITY_CHECK:
+        return
+
+    x, x_scale = x_fp8
+
+    if x_scale.dtype == torch.int:
+        return
+
+    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0
+
+    x_scale_ceil = ceil_to_ue8m0(x_scale)
+    assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
```
```diff
--- a/sglang/srt/layers/quantization/fp8.py
+++ b/sglang/srt/layers/quantization/fp8.py
@@ -30,6 +30,9 @@ except ImportError:

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -64,7 +67,6 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -72,6 +74,8 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_sm90_supported,
+    is_sm100_supported,
     log_info_on_rank0,
     next_power_of_2,
     print_warning_once,
@@ -80,7 +84,11 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        DispatchOutput,
+        StandardDispatchOutput,
+    )
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config

@@ -344,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
+            layer.weight_scale_inv = torch.nn.Parameter(
+                layer.weight_scale_inv.data, requires_grad=False
+            )
             return
         else:
             weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -526,7 +537,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -542,18 +553,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
             # Required by column parallel or enabling merged weights
-            if
+            if intermediate_size_per_partition % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
-                    f"{
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_n = {block_n}."
                 )
             if tp_size > 1:
                 # Required by row parallel
-                if
+                if intermediate_size_per_partition % block_k != 0:
                     raise ValueError(
                         f"The input_size of down's weight = "
-                        f"{
+                        f"{intermediate_size_per_partition} is not divisible by "
                         f"weight quantization block_k = {block_k}."
                     )

@@ -563,7 +574,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight = torch.nn.Parameter(
                 torch.empty(
                     num_experts,
-                    2 *
+                    2 * intermediate_size_per_partition,
                     hidden_size // 8,
                     dtype=params_dtype,
                 ),
@@ -571,20 +582,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
             w2_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts,
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition // 8,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
         else:
             w13_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts,
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    hidden_size,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
             w2_weight = torch.nn.Parameter(
                 torch.empty(
-                    num_experts,
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                    dtype=params_dtype,
                 ),
                 requires_grad=False,
             )
@@ -600,7 +620,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
                     num_experts,
-                    2 * ((
+                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                     (hidden_size + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
@@ -610,7 +630,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 torch.ones(
                     num_experts,
                     (hidden_size + block_n - 1) // block_n,
-                    (
+                    (intermediate_size_per_partition + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
                 requires_grad=False,
@@ -618,11 +638,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
-            if
-                get_bool_env_var("SGLANG_CUTLASS_MOE")
-                and self.cutlass_fp8_supported
-                and (is_sm100_supported() or is_sm90_supported())
-            ):
+            if self.use_cutlass_fused_experts_fp8:
                 self.ab_strides1 = torch.full(
                     (num_experts,),
                     hidden_size,
@@ -631,13 +647,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 )
                 self.c_strides1 = torch.full(
                     (num_experts,),
-                    2 *
+                    2 * intermediate_size_per_partition,
                     device=w13_weight.device,
                     dtype=torch.int64,
                 )
                 self.ab_strides2 = torch.full(
                     (num_experts,),
-
+                    intermediate_size_per_partition,
                     device=w2_weight.device,
                     dtype=torch.int64,
                 )
@@ -690,7 +706,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if _is_hip: # _use_aiter: TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
-                torch.ones(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    dtype=torch.float32,
+                ),
                 requires_grad=False,
             )
             w2_weight_scale1 = torch.nn.Parameter(
@@ -983,14 +1003,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             )
         torch.cuda.empty_cache()

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
+        dispatch_output: DispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config

         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
@@ -1000,7 +1029,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )

-
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -1016,6 +1045,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 None, # a2_scale
                 True, # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)

         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
@@ -1026,7 +1056,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.no_combine,
             )
             if ret is not None:
-                return ret
+                return StandardCombineInput(hidden_states=ret)

         if self.use_cutlass_fused_experts_fp8:
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
@@ -1055,17 +1085,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
-
-
-
-
-
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_output=topk_output,
-                moe_runner_config=moe_runner_config,
+            return StandardCombineInput(hidden_states=output)
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-
+            w13_scale=(
                 layer.w13_weight_scale_inv
                 if self.block_quant
                 else layer.w13_weight_scale
@@ -1073,20 +1099,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w2_scale=(
                 layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
             ),
-
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
         )
+        return self.runner.run(dispatch_output, quant_info)

     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
-
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-
-
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        activation = self.moe_runner_config.activation
+        routed_scaling_factor = self.moe_runner_config.routed_scaling_factor

         from flashinfer.fused_moe import trtllm_fp8_block_scale_moe

@@ -1107,10 +1135,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             and topk_config.topk_group is not None
         ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"

-
-
-
-        correction_bias
+        correction_bias = (
+            None
+            if topk_config.correction_bias is None
+            else topk_config.correction_bias.to(x.dtype)
+        )
+
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=correction_bias,
```
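
The `Fp8MoEMethod` hunks above are part of a wider interface change in 0.5.2: quantized MoE methods now take a dispatcher output and return a combine input instead of raw tensors, with the Triton runner built once in `create_moe_runner`. A rough mock of the new call shape, using stand-in dataclasses rather than the real `sglang.srt.layers.moe` types (the class names are borrowed from the diff; the bodies here are simplified placeholders, not sglang's implementation):

```python
from dataclasses import dataclass

import torch


@dataclass
class StandardDispatchOutput:  # stand-in for the token dispatcher's output type
    hidden_states: torch.Tensor
    topk_output: tuple


@dataclass
class StandardCombineInput:  # stand-in for what apply() now hands back to the dispatcher
    hidden_states: torch.Tensor


class ToyMoEMethod:
    def create_moe_runner(self, layer, moe_runner_config):
        # 0.5.2 splits runner construction out of apply(); the config is cached on the method.
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
        x = dispatch_output.hidden_states
        # ... the real method runs the quantized grouped GEMM here ...
        return StandardCombineInput(hidden_states=x)


method = ToyMoEMethod()
method.create_moe_runner(layer=None, moe_runner_config={"activation": "silu"})
out = method.apply(None, StandardDispatchOutput(torch.zeros(4, 8), topk_output=()))
print(type(out).__name__)  # StandardCombineInput
```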
```diff
--- a/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/sglang/srt/layers/quantization/fp8_kernel.py
@@ -298,7 +298,7 @@ def _per_token_group_quant_8bit_raw(
     )

     if scale_ue8m0:
-        from deep_gemm
+        from deep_gemm import transform_sf_into_required_layout

         assert group_size == 128
         x_s = transform_sf_into_required_layout(
@@ -338,7 +338,7 @@ def _per_token_group_quant_8bit_fuse_silu_and_mul(
     # scale_ue8m0=scale_ue8m0,
     # )

-    from deep_gemm
+    from deep_gemm import transform_sf_into_required_layout

     from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd

```
```diff
--- a/sglang/srt/layers/quantization/fp8_utils.py
+++ b/sglang/srt/layers/quantization/fp8_utils.py
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.
+from sglang.srt.utils import is_sm100_supported

 try:
     from vllm import _custom_ops as ops
@@ -45,7 +45,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant

     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)

@@ -248,11 +248,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )

-    # NOTE(alcanderian): Useless when scale is packed to int32
-    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
-    #     _check_ue8m0("x_scale", x_scale)
-    #     _check_ue8m0("weight_scale", ws)
-
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -261,11 +256,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)


-def _check_ue8m0(name, x):
-    x_ceil = ceil_to_ue8m0(x)
-    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
-
-
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -459,7 +449,7 @@ def _requant_weight_ue8m0(
         import deep_gemm.utils.layout

         sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.
+        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
         return sf

     out_s = _transform_scale(out_s, mn=out_w.shape[-2])
@@ -652,25 +642,49 @@ def apply_fp8_linear(
             use_per_token_if_dynamic
             and not per_tensor_weights
             and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
         ):
-            #
-            #
-            #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # into this sector means use dynamic per-token-per-channel quant
+            # per-token scale quant for input matrix, every row(one token) have one scale factor
+            # per-channel scale quant for weight matrix, every col(one channel) have one scale factor
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ -> input tensor, shape = (m, k)
+                # WQ -> weight tensor, shape = (n, k), with preshuffe get better perf
+                # x_scale -> input scale tensor, shape = (m, 1)
+                # w_scale -> weight scale tensor, shape = (n ,1)
+                # dtype -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
         else:
             # Fallback for channelwise case, where we use unfused DQ
             # due to limitations with scaled_mm
```
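
The new rowwise branch in `apply_fp8_linear` (either `gemm_a8w8_bpreshuffle` via aiter on ROCm, or rowwise `torch._scaled_mm`) applies one dynamic scale per input row (per token) and one per weight output channel. A device-agnostic sketch of that scaling scheme, emulated with int8-style rounding instead of FP8 so it runs anywhere; all tensor names here are local to the example:

```python
import torch

torch.manual_seed(0)
m, k, n = 4, 64, 8
x = torch.randn(m, k)  # activations: one scale per row (per token)
w = torch.randn(n, k)  # weights: one scale per output channel

x_scale = x.abs().amax(dim=1, keepdim=True) / 127.0  # shape (m, 1)
w_scale = w.abs().amax(dim=1, keepdim=True) / 127.0  # shape (n, 1)

xq = torch.round(x / x_scale).clamp(-127, 127)  # emulated quantized input
wq = torch.round(w / w_scale).clamp(-127, 127)  # emulated quantized weight

# Rowwise-scaled GEMM: integer-like matmul, then rescale rows by x_scale and
# columns by w_scale, which is what the fused kernels do internally.
y = (xq @ wq.t()) * x_scale * w_scale.t()

print((y - x @ w.t()).abs().max())  # small quantization error vs. the fp32 GEMM
```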
```diff
--- a/sglang/srt/layers/quantization/gptq.py
+++ b/sglang/srt/layers/quantization/gptq.py
@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )

 from sglang.srt.utils import is_cuda

@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

-
-
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1

         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-
-
-
-
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency

         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):

         topk_weights, topk_ids, router_logits = topk_output

-
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
```