sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_dp_device,
+    get_dp_dtype,
+    get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
@@ -46,10 +49,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file, use_intel_amx_backend
+from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend

 logger = logging.getLogger(__name__)

+_is_npu = is_npu()
+

 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -67,7 +72,10 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None
     # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
-
+    # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
+    next_token_token_ids_logprobs_val: Optional[
+        List[Union[List[float], torch.Tensor]]
+    ] = None
     next_token_token_ids_logprobs_idx: Optional[List] = None

     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -180,10 +188,13 @@ class LogitsMetadata:
             )
         else:
             dp_local_start_pos = cumtokens[dp_rank - 1]
-            dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]

         self.dp_local_start_pos = dp_local_start_pos
-        self.dp_local_num_tokens =
+        self.dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
+
+        hidden_size = get_dp_hidden_size()
+        dtype = get_dp_dtype()
+        device = get_dp_device()

         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
@@ -191,10 +202,13 @@ class LogitsMetadata:
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len

-
-
-
-
+        self.gathered_buffer = torch.empty(
+            (
+                self.global_dp_buffer_len,
+                hidden_size,
+            ),
+            dtype=dtype,
+            device=device,
         )


@@ -441,7 +455,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
@@ -517,7 +531,12 @@ class LogitsProcessor(nn.Module):
         logits = logits[:, : self.config.vocab_size].float()

         if self.final_logit_softcapping:
-
+            if not _is_npu:
+                fused_softcap(logits, self.final_logit_softcapping)
+            else:
+                logits = self.final_logit_softcapping * torch.tanh(
+                    logits / self.final_logit_softcapping
+                )

         return logits

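The NPU branch above replaces the fused softcap kernel with plain tanh soft-capping. A minimal CPU-only sketch of that math, assuming only torch (`cap` stands in for `final_logit_softcapping`):

```python
import torch

cap = 30.0                        # stands in for final_logit_softcapping
logits = torch.randn(4, 128) * 100

# Eager soft-cap, as in the NPU fallback: squashes logits into (-cap, cap)
capped = cap * torch.tanh(logits / cap)
assert capped.abs().max() < cap

# Small logits pass through almost unchanged, since tanh(x) ~ x for |x| << 1
small = torch.tensor([0.01, -0.02])
assert torch.allclose(cap * torch.tanh(small / cap), small, atol=1e-4)
```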
sglang/srt/layers/moe/__init__.py

@@ -1,4 +1,4 @@
-from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
 from sglang.srt.layers.moe.utils import (
     DeepEPMode,
     MoeA2ABackend,
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
 __all__ = [
     "DeepEPMode",
     "MoeA2ABackend",
+    "MoeRunner",
     "MoeRunnerConfig",
     "MoeRunnerBackend",
     "initialize_moe_config",

sglang/srt/layers/moe/cutlass_w4a8_moe.py

@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
         k,
     )

-    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.
-    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)

     cutlass_w4a8_moe_mm(
         c1,
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
         topk,
     )

-    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
     silu_and_mul(c1, intermediate)

     intermediate_q = torch.empty(

sglang/srt/layers/moe/ep_moe/kernels.py

@@ -1416,7 +1416,7 @@ def zero_experts_compute_triton(
     zero_expert_scales[zero_expert_mask] = 0.0

     normal_expert_mask = expert_indices >= num_experts
-    expert_indices[normal_expert_mask] =
+    expert_indices[normal_expert_mask] = -1
     expert_scales[normal_expert_mask] = 0.0

     output = torch.zeros_like(hidden_states).to(hidden_states.device)

sglang/srt/layers/moe/ep_moe/layer.py

@@ -1,9 +1,11 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch
+import triton
+import triton.language as tl

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe import (
@@ -31,11 +33,18 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.
+from sglang.srt.offloader import get_offloader
+from sglang.srt.utils import (
+    ceil_div,
+    dispose_tensor,
+    get_bool_env_var,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
-        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -454,12 +463,14 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
-            assert DispatchOutputChecker.
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             return self.forward_npu(dispatch_output)
         if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            if get_moe_runner_backend().is_flashinfer_cutedsl():
+                return self.forward_flashinfer_cutedsl(dispatch_output)
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -534,6 +545,24 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128

+        # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
+        w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        w2_weight_fp8 = (
+            self.w2_weight,
+            (
+                self.w2_weight_scale_inv
+                if self.use_block_quant
+                else self.w2_weight_scale
+            ),
+        )
+
         hidden_states_fp8_shape = hidden_states_fp8.shape
         hidden_states_fp8_device = hidden_states_fp8.device
         hidden_states_fp8_dtype = hidden_states_fp8.dtype
@@ -564,12 +593,17 @@ class DeepEPMoE(EPMoE):
         )
         output_index = torch.empty_like(topk_idx)

-
-
-
-
-
-
+        if get_offloader().forbid_copy_engine_usage:
+            num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
+                num_recv_tokens_per_expert
+            )
+        else:
+            num_recv_tokens_per_expert_gpu = torch.tensor(
+                num_recv_tokens_per_expert,
+                dtype=torch.int32,
+                pin_memory=True,
+                device="cpu",
+            ).cuda(non_blocking=True)
         expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

         ep_scatter(
@@ -594,7 +628,7 @@ class DeepEPMoE(EPMoE):
         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
             input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
-            input_tensor,
+            input_tensor, w13_weight_fp8, gateup_output, m_indices
         )
         del input_tensor
         down_input = torch.empty(
@@ -624,7 +658,7 @@ class DeepEPMoE(EPMoE):
             down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
-
+            w2_weight_fp8,
             down_output,
             m_indices,
         )
@@ -639,6 +673,22 @@ class DeepEPMoE(EPMoE):

         return gather_out

+    def forward_flashinfer_cutedsl(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        hidden_states, _, _, masked_m, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        output = self.quant_method.apply_without_routing_weights(
+            layer=self,
+            x=hidden_states,
+            masked_m=masked_m,
+            moe_runner_config=self.moe_runner_config,
+        )
+        return output
+
     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLOutput,
@@ -718,66 +768,127 @@ class DeepEPMoE(EPMoE):

     def forward_npu(
         self,
-        dispatch_output: DeepEPLLOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        if TYPE_CHECKING:
-            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
-        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"

+        import torch_npu
+
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
+        group_list_type = 1

-
-
+        def _forward_normal(dispatch_output: DeepEPNormalOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPNormalOutput)
+            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
+
+            if isinstance(hidden_states, tuple):
+                per_token_scale = hidden_states[1]
+                hidden_states = hidden_states[0]
+            else:
+                # dynamic quant
+                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states
+                )

-
-
+            group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
+                hidden_states.device
+            )

-
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                scale=[self.w13_weight_scale.to(output_dtype)],
+                per_token_scale=[per_token_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]

-
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w13_weight],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=torch.int32,
-        )[0]
-
-        # act_fn: swiglu
-        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=hidden_states,
-            weight_scale=self.w13_weight_scale.to(torch.float32),
-            activation_scale=pertoken_scale,
-            bias=None,
-            quant_scale=None,
-            quant_offset=None,
-            group_index=seg_indptr,
-            activate_left=True,
-            quant_mode=1,
-        )
-
-        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)],
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=seg_indptr,
-            output_dtype=output_dtype,
-        )[0]
+            return hidden_states

-
+        def _forward_ll(dispatch_output: DeepEPLLOutput):
+            if TYPE_CHECKING:
+                assert isinstance(dispatch_output, DeepEPLLOutput)
+            hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
+
+            per_token_scale = hidden_states[1]
+            hidden_states = hidden_states[0]
+
+            group_list = group_list.to(torch.int64)
+
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w13_weight],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=torch.int32,
+            )[0]
+
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=self.w13_weight_scale.to(torch.float32),
+                activation_scale=per_token_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=group_list,
+                activate_left=True,
+                quant_mode=1,
+            )

+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states],
+                weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)],
+                per_token_scale=[swiglu_out_scale],
+                split_item=2,
+                group_list_type=group_list_type,
+                group_type=0,
+                group_list=group_list,
+                output_dtype=output_dtype,
+            )[0]

-
+            return hidden_states
+
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
+            return _forward_normal(dispatch_output)
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
+            return _forward_ll(dispatch_output)
+        else:
+            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
+
+
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
     if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE

@@ -790,8 +901,7 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
         return FusedMoE
     try:
         # Check the quantization argument directly
-
-        if quantization == "modelopt_fp4":
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             from sglang.srt.layers.moe.fused_moe_triton.layer import (
                 FlashInferFP4MoE,
             )
@@ -800,10 +910,20 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     except:
         pass

-    if should_use_flashinfer_trtllm_moe():
+    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
+        # FIXME: FlashInferFusedMoE only supports fp8 quant now
         return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
         return EPMoE
     return FusedMoE
+
+
+def copy_list_to_gpu_no_ce(arr: List[int]):
+    from sgl_kernel.elementwise import copy_to_gpu_no_ce
+
+    tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+    tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+    copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+    return tensor_gpu
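For context on the new expert-count transfer path: when the offloader forbids copy-engine usage, `copy_list_to_gpu_no_ce` moves the host list with the dedicated `sgl_kernel` elementwise kernel; otherwise the else branch above uses the standard pinned-memory asynchronous copy. A minimal sketch of that standard pattern, assuming only torch and a CUDA device (values are illustrative):

```python
import torch

def list_to_gpu(arr):
    # Pin the host staging buffer so the host-to-device copy can run
    # asynchronously on the copy engine; mirrors the else branch above.
    cpu = torch.tensor(arr, dtype=torch.int32, pin_memory=True, device="cpu")
    return cpu.cuda(non_blocking=True)

if torch.cuda.is_available():
    num_recv_tokens_per_expert = [7, 0, 3, 12]  # illustrative values only
    print(list_to_gpu(num_recv_tokens_per_expert))
```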
sglang/srt/layers/moe/flashinfer_cutedsl_moe.py (new file)

@@ -0,0 +1,156 @@
+from typing import Any, Dict, Optional
+
+import torch
+from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
+from sgl_kernel.gemm import (
+    scaled_fp4_grouped_quant,
+    silu_and_mul_scaled_fp4_grouped_quant,
+)
+
+
+def get_cute_dtype(input: torch.Tensor) -> str:
+    if input.dtype == torch.bfloat16:
+        return "bfloat16"
+    elif input.dtype == torch.float16:
+        return "float16"
+    elif input.dtype == torch.float32:
+        return "float32"
+    else:
+        raise ValueError(f"Unsupported cute dtype {input.dtype}")
+
+
+def flashinfer_cutedsl_moe_masked(
+    hidden_states: torch.Tensor,
+    input_global_scale: torch.Tensor,
+    w1: torch.Tensor,
+    w1_blockscale: torch.Tensor,
+    w1_alpha,
+    w2: torch.Tensor,
+    a2_global_scale: torch.Tensor,
+    w2_blockscale: torch.Tensor,
+    w2_alpha,
+    masked_m: torch.Tensor,
+):
+    """
+    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
+    kernels.
+
+    Args:
+        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        input_global_scale (torch.Tensor): (l,)
+        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
+        w1_blockscale (torch.Tensor): blockscale factors, e4m3,
+        w1_alpha (torch.Tensor): (l,)
+        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
+        a2_global_scale (torch.Tensor): (l,)
+        w2_blockscale (torch.Tensor): blockscale factors, e4m3,
+        w2_alpha (torch.Tensor): (l,)
+        masked_m (torch.Tensor): Masked dimension indices
+
+    Notes:
+        - Assumes max(masked_m) <= m.
+    """
+
+    # === Assertions on dtypes ===
+    assert (
+        input_global_scale.dtype == torch.float32
+    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
+    assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
+    assert (
+        w1_blockscale.dtype == torch.float8_e4m3fn
+    ), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
+    assert (
+        w1_alpha.dtype == torch.float32
+    ), f"w1_alpha must be float32, got {w1_alpha.dtype}"
+    assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
+    assert (
+        a2_global_scale.dtype == torch.float32
+    ), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
+    assert (
+        w2_blockscale.dtype == torch.float8_e4m3fn
+    ), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
+    assert (
+        w2_alpha.dtype == torch.float32
+    ), f"w2_alpha must be float32, got {w2_alpha.dtype}"
+
+    # === Assertions on shapes ===
+    n = w2.shape[-1] * 2  # intermediate dimension
+    num_experts, m, k = hidden_states.shape
+
+    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
+    assert (
+        w1.shape[-1] * 2 == k
+    ), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
+    assert w2.shape[-2:] == (
+        k,
+        n // 2,
+    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"
+
+    assert input_global_scale.shape == (
+        num_experts,
+    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
+    assert w1_alpha.shape == (
+        num_experts,
+    ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
+    assert a2_global_scale.shape == (
+        num_experts,
+    ), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
+    assert w2_alpha.shape == (
+        num_experts,
+    ), f"w2_alpha must be (l,), got {w2_alpha.shape}"
+
+    aq, aq_sf = scaled_fp4_grouped_quant(
+        hidden_states,
+        input_global_scale,
+        masked_m,
+    )
+    gateup_output = torch.empty(
+        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
+    )
+    gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel
+    sf_vec_size = 16
+    assert aq_sf.dtype == torch.float8_e4m3fn
+    assert aq.dtype == torch.uint8
+    ab_dtype = "float4_e2m1fn"
+    sf_dtype = "float8_e4m3fn"
+
+    c_dtype = get_cute_dtype(hidden_states)
+
+    # Gemm1
+
+    grouped_gemm_nt_masked(
+        (aq, aq_sf),
+        (w1.permute(1, 2, 0), w1_blockscale),
+        gateup_output,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w1_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w1_alpha),
+    )  # in logical [m, n, l]
+
+    # SILU and quantization
+    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
+        gateup_output.permute(2, 0, 1),
+        a2_global_scale,
+        masked_m,
+    )
+
+    # Gemm2
+    out = torch.empty_like(hidden_states)
+    out = out.permute(1, 2, 0)  # requirement of kernel
+    grouped_gemm_nt_masked(
+        (diq, diq_sf),
+        (w2.permute(1, 2, 0), w2_blockscale),
+        out,
+        masked_m,
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=w2_alpha.view(1, 1, num_experts),
+        alpha_dtype=get_cute_dtype(w2_alpha),
+    )  # in logical [m, k, l]
+    return out.permute(2, 0, 1)
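To make the `masked_m` convention in `flashinfer_cutedsl_moe_masked` concrete: `hidden_states` is laid out per expert as `[num_experts, m, k]`, and only the first `masked_m[e]` rows of each expert slice are valid. A plain-PyTorch illustration of that layout (sizes are made up; this does not call the fp4 kernels):

```python
import torch

num_experts, m, k = 4, 8, 16                      # illustrative sizes
hidden_states = torch.randn(num_experts, m, k, dtype=torch.bfloat16)

# masked_m[e] = number of valid rows for expert e; rows beyond it are padding.
masked_m = torch.tensor([8, 3, 0, 5], dtype=torch.int32)
assert int(masked_m.max()) <= m                   # invariant noted in the docstring

for e in range(num_experts):
    valid_rows = hidden_states[e, : int(masked_m[e])]  # the part the masked GEMMs consume
    print(f"expert {e}: {tuple(valid_rows.shape)}")
```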
sglang/srt/layers/moe/fused_moe_native.py

@@ -8,16 +8,18 @@ from torch.nn import functional as F

 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput
 from sglang.srt.layers.moe.topk import StandardTopKOutput


 def fused_moe_forward_native(
     layer: torch.nn.Module,
-
-    topk_output: StandardTopKOutput,
-    moe_runner_config: MoeRunnerConfig,
+    dispatch_output: StandardDispatchOutput,
 ) -> torch.Tensor:

+    x, topk_output = dispatch_output
+    moe_runner_config = layer.moe_runner_config
+
     if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()

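The signature change above means `fused_moe_forward_native` now receives a single dispatch output and unpacks it, with the runner config read from the layer instead of being passed in. A minimal stand-in illustrating just the unpacking pattern (the real `StandardDispatchOutput` lives in `sglang.srt.layers.moe.token_dispatcher` and may carry more fields; the names here are hypothetical):

```python
from typing import NamedTuple

import torch

class DispatchOutputSketch(NamedTuple):
    hidden_states: torch.Tensor   # "x" in fused_moe_forward_native
    topk_output: tuple            # e.g. (topk_weights, topk_ids)

dispatch_output = DispatchOutputSketch(
    hidden_states=torch.randn(2, 8),
    topk_output=(torch.rand(2, 2), torch.zeros(2, 2, dtype=torch.long)),
)

x, topk_output = dispatch_output  # same unpacking as in the diff above
```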