sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w4afp8.py

@@ -17,12 +17,19 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_npu, set_weight_attrs
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -133,7 +140,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer: EPMoE,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -145,7 +152,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-
+                intermediate_size_per_partition * 2,
                 hidden_size // 2,
                 dtype=torch.int8,
             ),
@@ -159,7 +166,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.empty(
                 num_experts,
                 hidden_size,
-
+                intermediate_size_per_partition // 2,
                 dtype=torch.int8,
             ),
             requires_grad=False,
@@ -173,7 +180,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
-                2 *
+                2 * intermediate_size_per_partition,
                 hidden_size // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
@@ -186,7 +193,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             torch.zeros(
                 num_experts,
                 hidden_size,
-
+                intermediate_size_per_partition // self.quant_config.group_size,
                 dtype=torch.float32,
             ),
             requires_grad=False,
@@ -220,13 +227,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         self.c_strides1 = torch.full(
             (num_experts, 3),
-            2 *
+            2 * intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
         self.a_strides2 = torch.full(
             (num_experts, 3),
-
+            intermediate_size_per_partition,
             device=device,
             dtype=torch.int64,
         )
@@ -282,16 +289,22 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: EPMoE,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
-        # TODO(ch-wan): move it out of this class
         from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
@@ -328,6 +341,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
-        if moe_runner_config.routed_scaling_factor is not None:
-            output *= moe_runner_config.routed_scaling_factor
-        return output
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
+        return StandardCombineInput(hidden_states=output)
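This file and the two quantization modules that follow all move to the same per-method contract: `apply()` no longer receives the hidden states, top-k output, and runner config as separate arguments, but a single `StandardDispatchOutput`, and it returns a `CombineInput` instead of a bare tensor, with the runner config captured once via the new `create_moe_runner` hook. The following is a minimal, self-contained sketch of that contract; `_DispatchOutput` and `_CombineInput` are local stand-ins for `StandardDispatchOutput` and `StandardCombineInput` (which live in `sglang.srt.layers.moe.token_dispatcher`), and the expert computation is a placeholder rather than the real cutlass_w4a8_moe kernel.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import torch


@dataclass
class _DispatchOutput:
    # Stand-in for StandardDispatchOutput; the diff reads exactly these two fields.
    hidden_states: torch.Tensor
    topk_output: Tuple[torch.Tensor, torch.Tensor, Any]


@dataclass
class _CombineInput:
    # Stand-in for StandardCombineInput, which apply() now returns.
    hidden_states: torch.Tensor


class _SketchMoEMethod:
    """Skeleton of the per-method contract visible in this diff."""

    def create_moe_runner(self, layer: torch.nn.Module, moe_runner_config) -> None:
        # The runner config is captured once here instead of being passed to
        # every apply() call.
        self.moe_runner_config = moe_runner_config

    def apply(
        self, layer: torch.nn.Module, dispatch_output: _DispatchOutput
    ) -> _CombineInput:
        x = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output
        # Placeholder for the real quantized expert kernel.
        output = x
        scaling: Optional[float] = getattr(
            self.moe_runner_config, "routed_scaling_factor", None
        )
        if scaling is not None:
            output = output * scaling
        return _CombineInput(hidden_states=output)
```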
sglang/srt/layers/quantization/w8a8_fp8.py

@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -26,8 +28,10 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_fp8_fnuz = is_fp8_fnuz()
 
@@ -209,7 +213,7 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -218,7 +222,10 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=fp8_dtype,
             ),
             requires_grad=False,
         )
@@ -226,14 +233,21 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=fp8_dtype,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -266,25 +280,26 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
-
-
-        layer.
-        layer.w2_weight,
-        topk_output=topk_output,
-        moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
             per_channel_quant=True,
-
-            w2_scale=
-
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)
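The other change repeated across these MoE methods is that `create_weights` now takes an `intermediate_size_per_partition` argument and spells out the expert weight shapes one dimension per line. A small standalone sketch of the resulting shapes, assuming the intermediate dimension is sharded evenly across tensor-parallel ranks (the concrete sizes below are illustrative, not taken from sglang; the fp8/int8 dtypes are omitted so the snippet runs anywhere):

```python
import torch

# Illustrative sizes; only the shape relationships come from the diff above.
num_experts = 8
hidden_size = 1024
intermediate_size = 4096
tp_size = 4
intermediate_size_per_partition = intermediate_size // tp_size  # 1024

# Shapes allocated by W8A8FP8MoEMethod.create_weights in this diff:
w13_weight = torch.empty(num_experts, 2 * intermediate_size_per_partition, hidden_size)
w2_weight = torch.empty(num_experts, hidden_size, intermediate_size_per_partition)
w13_weight_scale = torch.ones(num_experts, 2 * intermediate_size_per_partition, 1)

assert w13_weight.shape == (8, 2048, 1024)
assert w2_weight.shape == (8, 1024, 1024)
assert w13_weight_scale.shape == (8, 2048, 1)
```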
sglang/srt/layers/quantization/w8a8_int8.py

@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -339,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
            ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
            _amx_process_weight_after_loading(layer, ["weight"])
-
-
-        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
 
     def create_weights(
@@ -417,7 +420,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -428,7 +431,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
@@ -436,14 +442,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -472,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
            ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-
-
-
-        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
             layer.w13_weight_scale.data, requires_grad=False
         )
@@ -483,23 +495,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         if use_intel_amx_backend(layer):
            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
            topk_weights, topk_ids, _ = topk_output
            x, topk_weights = apply_topk_weights_cpu(
-                moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
            )
-
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                x,
                layer.w13_weight,
                layer.w2_weight,
@@ -515,20 +534,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                layer.w2_input_scale,  # a2_scale
                True,  # is_vnni
            )
+            return StandardCombineInput(hidden_states=output)
 
-
-
-        layer.
-        layer.w2_weight,
-        topk_output=topk_output,
-        moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
            use_int8_w8a8=True,
            per_channel_quant=True,
-
-            w2_scale=
-
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
+        return self.runner.run(dispatch_output, quant_info)
 
 
 class NPU_W8A8LinearMethodImpl:
@@ -900,7 +918,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -914,21 +932,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         # weight
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
         w2_weight = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
         # scale
         w13_weight_scale = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
@@ -941,7 +969,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # offset
         w13_weight_offset = torch.nn.Parameter(
-            torch.empty(
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_offset", w13_weight_offset)
@@ -973,18 +1003,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
-
+        output = npu_fused_experts(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
@@ -994,3 +1031,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
            topk_ids=topk_ids,
            top_k=topk_ids.shape[1],
        )
+        return StandardCombineInput(hidden_states=output)
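Taken together, the Triton-backed methods in this diff (W8A8FP8MoEMethod and W8A8Int8MoEMethod) follow the same two-step pattern: build the runner once in `create_moe_runner`, then at `apply()` time wrap the layer's weights and scales in a `TritonMoeQuantInfo` and hand both to the runner. The following is a condensed sketch of that flow, using only the calls and keyword fields that appear in the hunks above; it is not a drop-in implementation, and any fields beyond those shown are not assumed.

```python
# Condensed from the W8A8FP8MoEMethod changes above; requires sglang >= 0.5.3rc0.
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo


class _TritonBackedMoEMethod:
    def create_moe_runner(self, layer, moe_runner_config):
        # Runner and config are created once, not per forward call.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(self, layer, dispatch_output):
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_fp8_w8a8=True,  # the int8 method passes use_int8_w8a8=True instead
            per_channel_quant=True,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
        # The runner consumes the dispatch output and returns a CombineInput.
        return self.runner.run(dispatch_output, quant_info)
```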