sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff compares the contents of publicly available package versions as published to their respective public registries; it is provided for informational purposes only.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,

@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 *
-        # where the first
-
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start +=
+                start += intermediate_size_per_partition
 
         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,
-
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
 
+        return self.runner.run(dispatch_output, quant_info)
+
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
@@ -517,6 +531,39 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return ["hf_quant_config.json"]
 
+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         # Handle two different config formats:
@@ -549,7 +596,7 @@ class ModelOptFp4Config(QuantizationConfig):
                 else:
                     kv_cache_quant_algo = "auto"
 
-                group_size =
+                group_size = ModelOptFp4Config.common_group_size(config)
                 exclude_modules = config.get("ignore", [])
             else:
                 # Fall back to nested format (hf_quant_config.json - legacy format)

@@ -559,7 +606,7 @@ class ModelOptFp4Config(QuantizationConfig):
                 kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                 if not kv_cache_quant_algo:
                     kv_cache_quant_algo = "auto"
-                group_size =
+                group_size = ModelOptFp4Config.common_group_size(config)
                 exclude_modules = quant_config.get("exclude_modules", [])
         except (ValueError, KeyError):
             raise ValueError(
@@ -595,10 +642,22 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
+                return True
         return False
 
     def get_quant_method(
@@ -1203,8 +1262,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
 
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately
 

@@ -1221,7 +1278,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
             # Both flashinfer cutlass and regular cutlass use same processing for w2
-            logger.info_once("Applied weight processing for both w13 and w2")
 
         # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
@@ -1238,21 +1294,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
         if self.enable_flashinfer_cutlass_moe:
             assert (

@@ -1305,13 +1372,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 

@@ -1332,4 +1398,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-
+
+        return StandardCombineInput(hidden_states=output)
sglang/srt/layers/quantization/moe_wna16.py

@@ -9,6 +9,8 @@ import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,

@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 
 def get_weight_perm(num_bits: int):

@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
 
-
-
-            layer.
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     @staticmethod
     def get_weight_loader(layer, weight_loader):