sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/cutlass_w4a8_moe.py

@@ -91,18 +91,10 @@ def cutlass_w4a8_moe(
     assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
     assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
     assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
-    assert (
-        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
-        and w1_scale.shape[2] == w1_q.shape[1] * 4
-    ), "W1 scale shape mismatch"
-    assert (
-        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
-        and w2_scale.shape[2] == w2_q.shape[1] * 4
-    ), "W2 scale shape mismatch"
 
     assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
-    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
     num_experts = w1_q.size(0)
     m = a.size(0)
@@ -155,8 +147,8 @@ def cutlass_w4a8_moe(
         k,
     )
 
-    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.
-    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
 
     cutlass_w4a8_moe_mm(
         c1,
@@ -174,7 +166,7 @@ def cutlass_w4a8_moe(
         topk,
     )
 
-    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
     silu_and_mul(c1, intermediate)
 
     intermediate_q = torch.empty(
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+        normal_expert_mask = expert_indices >= num_experts
+        expert_indices[normal_expert_mask] = -1
+        expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
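The new `zero_experts_compute_triton` helper treats routing slots whose index is at or beyond `num_experts` as "zero experts": their contribution is just the input scaled by the router weight, computed by `compute_identity_kernel`, while those entries are masked out of `expert_indices`/`expert_scales` in place so the regular expert path skips them. A minimal usage sketch, where the shapes, dtypes, and the two trailing zero-expert slots are illustrative assumptions (requires a CUDA device with Triton):

    import torch

    from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton

    # Hypothetical sizes: 4 tokens, hidden size 512, top-2 routing over
    # 8 real experts plus 2 trailing zero-expert slots.
    num_tokens, hidden_dim, top_k, num_experts = 4, 512, 2, 8
    hidden_states = torch.randn(num_tokens, hidden_dim, device="cuda")
    expert_scales = torch.rand(num_tokens, top_k, device="cuda")
    expert_indices = torch.randint(0, num_experts + 2, (num_tokens, top_k), device="cuda")

    # Returns hidden_states scaled by the summed router weights of the zero
    # experts; mutates expert_indices / expert_scales for the normal expert path.
    identity_out = zero_experts_compute_triton(
        expert_indices, expert_scales, num_experts, "identity", hidden_states
    )
    print(identity_out.shape)  # torch.Size([4, 512])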
sglang/srt/layers/moe/ep_moe/layer.py

@@ -35,7 +35,6 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -114,9 +113,6 @@ class EPMoE(FusedMoE):
             with_bias=with_bias,
         )
 
-        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
-        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
-
         self.intermediate_size = intermediate_size
 
         if isinstance(quant_config, Fp8Config):
@@ -232,7 +228,7 @@ class EPMoE(FusedMoE):
             (
                 _cast_to_e8m0_with_rounding_up(gateup_input_scale)
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                     gateup_input_scale
                 )
             ),
@@ -248,7 +244,6 @@ class EPMoE(FusedMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -290,9 +285,7 @@ class EPMoE(FusedMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -304,7 +297,6 @@ class EPMoE(FusedMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -461,7 +453,7 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
-            assert DispatchOutputChecker.
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
@@ -667,7 +659,6 @@ class DeepEPMoE(EPMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])
 
@@ -708,9 +699,7 @@ class DeepEPMoE(EPMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -722,64 +711,130 @@ class DeepEPMoE(EPMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
 
         return down_output
 
     def forward_npu(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        if TYPE_CHECKING:
-            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
-        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
 
+        import torch_npu
+
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
+        group_list_type = 1
 
-
-
+        def _forward_normal(dispatch_output: DeepEPNormalOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPNormalOutput)
+            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
+
+            if isinstance(hidden_states, tuple):
+                per_token_scale = hidden_states[1]
+                hidden_states = hidden_states[0]
+            else:
+                # dynamic quant
+                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states
+                )
 
-
-
+            group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
+                hidden_states.device
+            )
 
-
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                scale=[self.w13_weight_scale.to(output_dtype)],
+                per_token_scale=[per_token_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            return hidden_states
 
-            per_token_scale=[
-            (the remaining lines of the previous NPU implementation were not preserved in this diff rendering)
+        def _forward_ll(dispatch_output: DeepEPLLOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPLLOutput)
+            hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
+
+            per_token_scale = hidden_states[1]
+            hidden_states = hidden_states[0]
+
+            group_list = group_list.to(torch.int64)
+
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=torch.int32,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=self.w13_weight_scale.to(torch.float32),
+                activation_scale=per_token_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=group_list,
+                activate_left=True,
+                quant_mode=1,
+            )
 
-
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            return hidden_states
+
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            return _forward_normal(dispatch_output)
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            return _forward_ll(dispatch_output)
+        else:
+            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
 
 
 def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
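The rewritten `forward_npu` accepts both DeepEP dispatch formats. A small sketch of the two unpack patterns it distinguishes, with field meanings inferred from the hunk above (the actual output types are defined in `sglang.srt.layers.moe.token_dispatcher`, and the helper name here is hypothetical):

    from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

    def split_npu_dispatch(dispatch_output):
        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
            # Normal dispatch: activations plus per-expert receive counts,
            # which become the grouped-matmul group_list.
            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
            return hidden_states, num_recv_tokens_per_expert
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            # Low-latency dispatch: (quantized activations, per-token scale)
            # plus a ready-made group_list.
            (hidden_states, per_token_scale), _, _, group_list, _ = dispatch_output
            return hidden_states, group_list
        raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")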
sglang/srt/layers/moe/fused_moe_native.py

@@ -8,16 +8,18 @@ from torch.nn import functional as F
 
 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
 from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 
 def fused_moe_forward_native(
     layer: torch.nn.Module,
-    x: torch.Tensor,
-    topk_output: StandardTopKOutput,
-    moe_runner_config: MoeRunnerConfig,
+    dispatch_output: StandardDispatchOutput,
 ) -> torch.Tensor:
 
+    x, topk_output = dispatch_output
+    moe_runner_config = layer.moe_runner_config
+
     if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()
 
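For downstream callers, this signature change means the activations and routing result now travel together. A hedged before/after sketch; the positional `StandardDispatchOutput(x, topk_output)` construction is an assumption based on how the function unpacks it:

    from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
    from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput

    def run_native_moe(layer, x, topk_output):
        # 0.5.1.post2 passed the pieces separately:
        #   fused_moe_forward_native(layer, x, topk_output, moe_runner_config)
        # 0.5.2 bundles (x, topk_output); the runner config is read from
        # layer.moe_runner_config inside the function.
        dispatch_output = StandardDispatchOutput(x, topk_output)
        return fused_moe_forward_native(layer, dispatch_output)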
sglang/srt/layers/moe/fused_moe_triton/__init__.py

@@ -1,16 +1,18 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional
 
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-    fused_experts,
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
     get_config_file_name,
-    moe_align_block_size,
     try_get_optimal_moe_config,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
 )
+from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
+    moe_align_block_size,
+)
 
 _config: Optional[Dict[str, Any]] = None
 
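Code that imported the relocated helpers through their old paths needs updating; a sketch of the migration implied by this hunk:

    # Before (0.5.1.post2), everything lived in fused_moe:
    #   from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    #       fused_experts, get_config_file_name, moe_align_block_size,
    #       try_get_optimal_moe_config,
    #   )

    # After (0.5.2), the kernels, config helpers, and alignment helper are split:
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
        get_config_file_name,
        try_get_optimal_moe_config,
    )
    from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
        moe_align_block_size,
    )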
(one of the new fused_moe_triton tuning-config JSON files listed above)

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}