sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/moe_runner/triton.py
@@ -0,0 +1,448 @@
+from __future__ import annotations
+
+import functools
+import os
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+import triton.language as tl
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    MoeQuantInfo,
+    MoeRunnerConfig,
+    MoeRunnerCore,
+    RunnerInput,
+    RunnerOutput,
+    register_fused_func,
+    register_post_permute,
+    register_pre_permute,
+)
+from sglang.srt.layers.moe.utils import MoeRunnerBackend
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
+
+_is_hip = is_hip()
+_is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+_use_aiter = bool(int(os.getenv("SGLANG_MOE_USE_AITER", "0")))
+_MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
+
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, silu_and_mul
+elif _is_cpu and _is_cpu_amx_available:
+    pass
+elif _is_hip:
+    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+
+    if _use_aiter:
+        try:
+            from aiter import moe_sum
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+
+
+if _is_cuda or _is_hip:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+
+@dataclass
+class TritonRunnerInput(RunnerInput):
+
+    hidden_states: torch.Tensor
+    topk_weights: torch.Tensor
+    topk_ids: torch.Tensor
+    sorted_token_ids: torch.Tensor
+    expert_ids: torch.Tensor
+    num_tokens_post_padded: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.TRITON
+
+
+@dataclass
+class TritonRunnerOutput(RunnerOutput):
+
+    hidden_states: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.TRITON
+
+
+@dataclass
+class TritonMoeQuantInfo(MoeQuantInfo):
+    w13_weight: torch.Tensor
+    w2_weight: torch.Tensor
+    b13: Optional[torch.Tensor] = None
+    b2: Optional[torch.Tensor] = None
+    use_fp8_w8a8: bool = False
+    use_int8_w8a8: bool = False
+    use_int8_w8a16: bool = False
+    use_int4_w4a16: bool = False
+    per_channel_quant: bool = False
+    w13_scale: Optional[torch.Tensor] = None
+    w2_scale: Optional[torch.Tensor] = None
+    w13_zp: Optional[torch.Tensor] = None
+    w2_zp: Optional[torch.Tensor] = None
+    a13_scale: Optional[torch.Tensor] = None
+    a2_scale: Optional[torch.Tensor] = None
+    block_shape: Optional[List[int]] = None
+
+
+class TritonRunnerCore(MoeRunnerCore):
+
+    def __init__(self, config: MoeRunnerConfig):
+        super().__init__(config)
+
+    def run(
+        self,
+        runner_input: TritonRunnerInput,
+        quant_info: TritonMoeQuantInfo,
+        running_state: dict,
+    ) -> TritonRunnerOutput:
+
+        # TODO: move these functions to the triton runner
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            invoke_fused_moe_kernel,
+            moe_sum_reduce_torch_compile,
+            moe_sum_reduce_triton,
+            swiglu_with_alpha_and_limit,
+        )
+
+        hidden_states = runner_input.hidden_states
+        topk_weights = runner_input.topk_weights
+        topk_ids = runner_input.topk_ids
+        sorted_token_ids = runner_input.sorted_token_ids
+        expert_ids = runner_input.expert_ids
+        num_tokens_post_padded = runner_input.num_tokens_post_padded
+
+        w13 = quant_info.w13_weight
+        w2 = quant_info.w2_weight
+        b13 = quant_info.b13
+        b2 = quant_info.b2
+        a13_scale = quant_info.a13_scale
+        a2_scale = quant_info.a2_scale
+        w13_scale = quant_info.w13_scale
+        w2_scale = quant_info.w2_scale
+        w13_zp = quant_info.w13_zp
+        w2_zp = quant_info.w2_zp
+        block_shape = quant_info.block_shape
+        per_channel_quant = quant_info.per_channel_quant
+        use_fp8_w8a8 = quant_info.use_fp8_w8a8
+        use_int8_w8a8 = quant_info.use_int8_w8a8
+        use_int8_w8a16 = quant_info.use_int8_w8a16
+        use_int4_w4a16 = quant_info.use_int4_w4a16
+
+        activation = self.config.activation
+        no_combine = self.config.no_combine
+        inplace = self.config.inplace
+        gemm1_alpha = self.config.gemm1_alpha
+        gemm1_limit = self.config.gemm1_clamp_limit
+        routed_scaling_factor = self.config.routed_scaling_factor
+        apply_router_weight_on_input = self.config.apply_router_weight_on_input
+
+        M = hidden_states.shape[0]
+        E, N, _ = w13.shape
+        compute_type = (
+            tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+        )
+
+        intermediate_cache1 = torch.empty(
+            (M, topk_ids.shape[1], N),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+        invoke_fused_moe_kernel(
+            hidden_states,
+            w13,
+            b13,
+            intermediate_cache1,
+            a13_scale,
+            w13_scale,
+            w13_zp,
+            topk_weights,
+            topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            apply_router_weight_on_input,
+            topk_ids.shape[1],
+            running_state["config"],
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
+            block_shape=block_shape,
+        )
+
+        intermediate_cache2 = torch.empty(
+            (M * topk_ids.shape[1], N // 2),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+        if activation == "silu":
+            if gemm1_alpha is not None:
+                assert gemm1_limit is not None
+                intermediate_cache2 = swiglu_with_alpha_and_limit(
+                    intermediate_cache1.view(-1, N),
+                    gemm1_alpha,
+                    gemm1_limit,
+                )
+            elif _is_cuda:
+                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                vllm_ops.silu_and_mul(
+                    intermediate_cache2, intermediate_cache1.view(-1, N)
+                )
+        elif activation == "gelu":
+            assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
+            assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
+            if _is_cuda:
+                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                vllm_ops.gelu_and_mul(
+                    intermediate_cache2, intermediate_cache1.view(-1, N)
+                )
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")
+
+        intermediate_cache3 = torch.empty(
+            (M, topk_ids.shape[1], w2.shape[1]),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+        if no_combine:
+            assert not inplace
+            out_hidden_states = torch.empty(
+                (M, topk_ids.shape[1], w2.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        elif inplace:
+            out_hidden_states = hidden_states
+        else:
+            out_hidden_states = torch.empty_like(hidden_states)
+
+        invoke_fused_moe_kernel(
+            intermediate_cache2,
+            w2,
+            b2,
+            (
+                intermediate_cache3
+                if not no_combine and topk_ids.shape[1] != 1
+                else out_hidden_states.unsqueeze(0)
+            ),
+            a2_scale,
+            w2_scale,
+            w2_zp,
+            topk_weights,
+            topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            not apply_router_weight_on_input,
+            1,
+            running_state["config"],
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
+            block_shape=block_shape,
+        )
+
+        if routed_scaling_factor is None:
+            routed_scaling_factor = 1.0
+
+        if no_combine:
+            pass
+        elif _is_cuda:
+            if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
+                pass  # we write directly into out_hidden_states
+            elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
+                torch.add(
+                    intermediate_cache3[:, 0],
+                    intermediate_cache3[:, 1],
+                    out=out_hidden_states,
+                ).squeeze(dim=1)
+            else:
+                # According to micro benchmark results, torch.compile can get better performance for small token.
+                if M <= 32:
+                    moe_sum_reduce_torch_compile(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states,
+                        routed_scaling_factor,
+                    )
+                else:
+                    moe_sum_reduce_triton(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states,
+                        routed_scaling_factor,
+                    )
+        elif _is_hip:
+            if _use_aiter:
+                moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states,
+                )
+            else:
+                vllm_ops.moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states,
+                )
+        else:
+            vllm_ops.moe_sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                out_hidden_states,
+            )
+
+        return TritonRunnerOutput(
+            hidden_states=out_hidden_states,
+        )
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.TRITON
+
+
+@register_fused_func("none", "triton")
+def fused_experts_none_to_triton(
+    dispatch_output: StandardDispatchOutput,
+    quant_info: TritonMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+) -> StandardCombineInput:
+    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+    output = fused_experts(
+        hidden_states=dispatch_output.hidden_states,
+        w1=quant_info.w13_weight,
+        w2=quant_info.w2_weight,
+        topk_output=dispatch_output.topk_output,
+        moe_runner_config=runner_config,
+        b1=quant_info.b13,
+        b2=quant_info.b2,
+        use_fp8_w8a8=quant_info.use_fp8_w8a8,
+        use_int8_w8a8=quant_info.use_int8_w8a8,
+        use_int8_w8a16=quant_info.use_int8_w8a16,
+        use_int4_w4a16=quant_info.use_int4_w4a16,
+        per_channel_quant=quant_info.per_channel_quant,
+        w1_scale=quant_info.w13_scale,
+        w2_scale=quant_info.w2_scale,
+        w1_zp=quant_info.w13_zp,
+        w2_zp=quant_info.w2_zp,
+        a1_scale=quant_info.a13_scale,
+        a2_scale=quant_info.a2_scale,
+        block_shape=quant_info.block_shape,
+    )
+
+    return StandardCombineInput(
+        hidden_states=output,
+    )
+
+
+@register_pre_permute("standard", "triton")
+def pre_permute_standard_to_triton(
+    dispatch_output: StandardDispatchOutput,
+    quant_info: TritonMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> TritonRunnerInput:
+
+    # NOTE: this is dead code as a fused func for standard format is registered.
+    # This is left here for testing and examples.
+
+    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+        get_config_dtype_str,
+        moe_align_block_size,
+        try_get_optimal_moe_config,
+    )
+    from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+    hidden_states, topk_output = dispatch_output
+
+    assert TopKOutputChecker.format_is_standard(topk_output)
+
+    num_tokens = hidden_states.shape[0]
+    num_local_experts = runner_config.num_local_experts
+
+    if (
+        not (quant_info.use_fp8_w8a8 or quant_info.use_int8_w8a8)
+        or quant_info.block_shape is not None
+        or _use_aiter
+    ):
+        padding_size = 0
+    else:
+        padding_size = _MOE_PADDING_SIZE
+
+    config_dtype = get_config_dtype_str(
+        use_fp8_w8a8=quant_info.use_fp8_w8a8,
+        use_int8_w8a8=quant_info.use_int8_w8a8,
+        use_int8_w8a16=quant_info.use_int8_w8a16,
+        use_int4_w4a16=quant_info.use_int4_w4a16,
+        dtype=hidden_states.dtype,
+    )
+
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        quant_info.w13_weight.shape,
+        (
+            num_local_experts,
+            quant_info.w2_weight.shape[1],
+            quant_info.w2_weight.shape[2] - padding_size,
+        ),
+        topk_output.topk_ids.shape[1],
+        config_dtype,
+        block_shape=quant_info.block_shape,
+    )
+
+    config = get_config_func(num_tokens)
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_output.topk_ids, config["BLOCK_SIZE_M"], num_local_experts
+    )
+
+    running_state["config"] = config
+
+    return TritonRunnerInput(
+        hidden_states=hidden_states,
+        topk_weights=topk_output.topk_weights,
+        topk_ids=topk_output.topk_ids,
+        sorted_token_ids=sorted_token_ids,
+        expert_ids=expert_ids,
+        num_tokens_post_padded=num_tokens_post_padded,
+    )
+
+
+@register_post_permute("triton", "standard")
+def post_permute_triton_to_standard(
+    runner_output: TritonRunnerOutput,
+    quant_info: TritonMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> StandardCombineInput:
+
+    # NOTE: this is dead code as a fused func for standard format is registered.
+    # This is left here for testing and examples.
+
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+    return StandardCombineInput(
+        hidden_states=runner_output.hidden_states,
+    )
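
The new `moe_runner/triton.py` module splits the Triton MoE path into a pre-permute step, a `TritonRunnerCore.run` kernel invocation, and a post-permute step, all registered against the standard token-dispatcher format. The sketch below (not part of the diff) shows roughly how these pieces compose; it assumes the `register_*` decorators return the wrapped functions unchanged and that, in practice, the `MoeRunner` added in `moe_runner/runner.py` is what drives these steps. The `dispatch_output`, weight tensors, and `MoeRunnerConfig` instance are placeholders a real caller would supply.

```python
# Sketch only: composing the hooks defined in moe_runner/triton.py by hand.
# All imported names come from the diff above.
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import (
    TritonMoeQuantInfo,
    TritonRunnerCore,
    post_permute_triton_to_standard,
    pre_permute_standard_to_triton,
)


def run_triton_moe(dispatch_output, w13_weight, w2_weight, runner_config: MoeRunnerConfig):
    # Unquantized case: only the expert weights are set; the other
    # TritonMoeQuantInfo fields keep their defaults.
    quant_info = TritonMoeQuantInfo(w13_weight=w13_weight, w2_weight=w2_weight)
    running_state: dict = {}

    runner_input = pre_permute_standard_to_triton(
        dispatch_output, quant_info, runner_config, running_state
    )
    runner_output = TritonRunnerCore(runner_config).run(
        runner_input, quant_info, running_state
    )
    return post_permute_triton_to_standard(
        runner_output, quant_info, runner_config, running_state
    )
```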
sglang/srt/layers/moe/token_dispatcher/__init__.py
@@ -1,29 +1,41 @@
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     BaseDispatcherConfig,
+    CombineInput,
+    CombineInputChecker,
+    CombineInputFormat,
     DispatchOutput,
     DispatchOutputChecker,
     DispatchOutputFormat,
 )
 from sglang.srt.layers.moe.token_dispatcher.deepep import (
-    AscendDeepEPLLOutput,
     DeepEPConfig,
     DeepEPDispatcher,
+    DeepEPLLCombineInput,
     DeepEPLLOutput,
+    DeepEPNormalCombineInput,
     DeepEPNormalOutput,
 )
-from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
+from sglang.srt.layers.moe.token_dispatcher.standard import (
+    StandardCombineInput,
+    StandardDispatchOutput,
+)
 
 __all__ = [
-    "AscendDeepEPLLOutput",
     "BaseDispatcher",
     "BaseDispatcherConfig",
+    "CombineInput",
+    "CombineInputChecker",
+    "CombineInputFormat",
     "DispatchOutput",
     "DispatchOutputFormat",
     "DispatchOutputChecker",
     "StandardDispatchOutput",
+    "StandardCombineInput",
     "DeepEPConfig",
    "DeepEPDispatcher",
     "DeepEPNormalOutput",
     "DeepEPLLOutput",
+    "DeepEPLLCombineInput",
+    "DeepEPNormalCombineInput",
 ]
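
With the expanded `token_dispatcher/__init__.py` exports, the combine-side protocol types are importable from the package root. A small illustrative helper (not from the diff) using only the re-exported names:

```python
# Illustrative only: using the combine-side types re-exported in 0.5.2.
from sglang.srt.layers.moe.token_dispatcher import CombineInput, CombineInputChecker


def is_standard_combine_input(combine_input: CombineInput) -> bool:
    # format_is_standard narrows CombineInput to StandardCombineInput via TypeGuard.
    return CombineInputChecker.format_is_standard(combine_input)
```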
sglang/srt/layers/moe/token_dispatcher/base.py (renamed from base_dispatcher.py)
@@ -1,18 +1,23 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from enum import Enum, auto
+from enum import Enum
 from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
 
 import torch
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        AscendDeepEPLLOutput,
+        DeepEPLLCombineInput,
         DeepEPLLOutput,
+        DeepEPNormalCombineInput,
         DeepEPNormalOutput,
+        StandardCombineInput,
         StandardDispatchOutput,
     )
+    from sglang.srt.layers.moe.topk import TopKOutput
+
+# ------------------------------ Dispatch Output -------------------------------------
 
 
 class DispatchOutputChecker:
@@ -41,19 +46,12 @@ class DispatchOutputChecker:
     ) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
         return dispatch_output.format.is_deepep()
 
-    @staticmethod
-    def format_is_ascent_ll(
-        dispatch_output: DispatchOutput,
-    ) -> TypeGuard[AscendDeepEPLLOutput]:
-        return dispatch_output.format.is_ascent_ll()
-
 
 class DispatchOutputFormat(Enum):
 
-    STANDARD = auto()
-    DEEPEP_NORMAL = auto()
-    DEEPEP_LL = auto()
-    ASCENT_LL = auto()
+    STANDARD = "standard"
+    DEEPEP_NORMAL = "deepep_normal"
+    DEEPEP_LL = "deepep_ll"
 
     def is_standard(self) -> bool:
         return self == DispatchOutputFormat.STANDARD
@@ -70,18 +68,68 @@ class DispatchOutputFormat(Enum):
             DispatchOutputFormat.DEEPEP_LL,
         ]
 
-    def is_ascent_ll(self) -> bool:
-        return self == DispatchOutputFormat.ASCENT_LL
-
 
 @runtime_checkable
 class DispatchOutput(Protocol):
     """Protocol for dispatch outputs in different formats."""
 
+    # TODO: add hidden_states to the protocol
+
     @property
     def format(self) -> DispatchOutputFormat: ...
 
 
+# ------------------------------ Combine Input -------------------------------------
+
+
+class CombineInputChecker:
+    @staticmethod
+    def format_is_standard(
+        combine_input: CombineInput,
+    ) -> TypeGuard[StandardCombineInput]:
+        return combine_input.format == CombineInputFormat.STANDARD
+
+    @staticmethod
+    def format_is_deepep_normal(
+        combine_input: CombineInput,
+    ) -> TypeGuard[DeepEPNormalCombineInput]:
+        return combine_input.format == CombineInputFormat.DEEPEP_NORMAL
+
+    @staticmethod
+    def format_is_deepep_ll(
+        combine_input: CombineInput,
+    ) -> TypeGuard[DeepEPLLCombineInput]:
+        return combine_input.format == CombineInputFormat.DEEPEP_LL
+
+    @staticmethod
+    def format_is_deepep(
+        combine_input: CombineInput,
+    ) -> TypeGuard[Union[DeepEPNormalCombineInput, DeepEPLLCombineInput]]:
+        return combine_input.format in [
+            CombineInputFormat.DEEPEP_NORMAL,
+            CombineInputFormat.DEEPEP_LL,
+        ]
+
+
+class CombineInputFormat(Enum):
+    STANDARD = "standard"
+    DEEPEP_NORMAL = "deepep_normal"
+    DEEPEP_LL = "deepep_ll"
+
+
+@runtime_checkable
+class CombineInput(Protocol):
+    """Protocol for combine inputs in different formats."""
+
+    # TODO: add hidden_states to the protocol
+
+    @property
+    def format(self) -> CombineInputFormat: ...
+
+
+# ------------------------------ Base Dispatcher -------------------------------------
+
+
 class BaseDispatcherConfig(ABC):
     """Base class for dispatcher configs."""
 
@@ -92,9 +140,11 @@ class BaseDispatcher(ABC):
     """Base class for dispatchers."""
 
     @abstractmethod
-    def dispatch(
+    def dispatch(
+        self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs
+    ) -> DispatchOutput:
         pass
 
     @abstractmethod
-    def combine(self,
+    def combine(self, combine_input: CombineInput, **kwargs) -> torch.Tensor:
         pass
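
`CombineInput` in the rewritten `token_dispatcher/base.py` is a `runtime_checkable` structural protocol, so any object exposing a `format` property conforms. Below is a minimal hypothetical conforming type, for illustration only; the real implementations are `StandardCombineInput` in `standard.py` and the DeepEP variants in `deepep.py`.

```python
# Hypothetical minimal type satisfying the CombineInput protocol from
# token_dispatcher/base.py; real implementations live in standard.py / deepep.py.
from dataclasses import dataclass

import torch

from sglang.srt.layers.moe.token_dispatcher.base import (
    CombineInput,
    CombineInputChecker,
    CombineInputFormat,
)


@dataclass
class DummyCombineInput:
    hidden_states: torch.Tensor

    @property
    def format(self) -> CombineInputFormat:
        return CombineInputFormat.STANDARD


ci = DummyCombineInput(hidden_states=torch.zeros(1, 8))
assert isinstance(ci, CombineInput)  # structural check via runtime_checkable
assert CombineInputChecker.format_is_standard(ci)  # narrows via TypeGuard
```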