sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -7,8 +7,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
+from sglang.srt.layers.moe import (
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -30,10 +35,11 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():
@@ -105,18 +111,52 @@ class ModelOptFp8Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
-
-
-
-
-
-
-
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "FP8", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "FP8", ...}
+        # In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
+        # For legacy reasons, we keep hf_quant_config.json for now.
+
+        # Initialize variables
+        kv_cache_quant_method = None
+        exclude_modules = None
+
+        # Try flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # For kv_cache, check if kv_cache_scheme exists and extract algo
+            kv_cache_scheme = config.get("kv_cache_scheme")
+            if (
+                kv_cache_scheme
+                and kv_cache_scheme.get("type") == "float"
+                and kv_cache_scheme.get("num_bits") == 8
+            ):
+                kv_cache_quant_method = "FP8"
 
+            # Map 'ignore' field to 'exclude_modules'
+            exclude_modules = config.get("ignore")
+        else:
+            # Fall back to nested format (hf_quant_config.json - legacy format)
+            try:
+                quantization_section = cls.get_from_keys(config, ["quantization"])
+                quant_method = quantization_section.get("quant_algo")
+                kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo")
+                exclude_modules = quantization_section.get("exclude_modules")
+            except ValueError:
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
+                )
+        if quant_method is None:
+            raise ValueError(
+                "Cannot find 'quant_algo' in the model's quantization config. "
+            )
         if "FP8" not in quant_method:
             raise ValueError(
-                "
-                "
+                "ModelOptFp8Config only supports static FP8 quantization in SGLang. "
+                "For FP4 quantization, use ModelOptFp4Config. "
+                "Check the quantization config for your model's configuration."
             )
 
         return cls(
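For orientation, here is a minimal sketch of the two quantization-config layouts this from_config now accepts. The keys mirror exactly what the code above reads; the lm_head exclusion and other values are illustrative placeholders, not taken from a real checkpoint.

# 1. Legacy nested layout (hf_quant_config.json)
legacy_cfg = {
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8",
        "exclude_modules": ["lm_head"],
    }
}

# 2. Flat layout (the "quantization_config" section of config.json)
flat_cfg = {
    "quant_algo": "FP8",
    "kv_cache_scheme": {"type": "float", "num_bits": 8},
    "ignore": ["lm_head"],
}

# Either dict is expected to yield an equivalent ModelOptFp8Config:
# ModelOptFp8Config.from_config(legacy_cfg)
# ModelOptFp8Config.from_config(flat_cfg)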
@@ -422,12 +462,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
@@ -436,15 +471,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=False,  # ModelOpt uses per-tensor quantization
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
         )
 
 
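The signature change in the two hunks above recurs throughout this release: the per-call keyword arguments are folded into a single MoeRunnerConfig passed to apply(). A rough sketch of the idea, assuming the config carries at least the fields the old keywords covered; the dataclass below is an illustrative stand-in, not the real class from sglang.srt.layers.moe.moe_runner.

from dataclasses import dataclass
from typing import Optional

@dataclass
class MoeRunnerConfigSketch:  # hypothetical stand-in for MoeRunnerConfig
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None

# Before: method.apply(layer, x, topk_output, activation="silu", inplace=True, ...)
# After:  method.apply(layer, x, topk_output, moe_runner_config=cfg)
cfg = MoeRunnerConfigSketch(activation="silu")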
@@ -486,22 +519,63 @@ class ModelOptFp4Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
-
-
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "NVFP4", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "NVFP4", ...}
+        # In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
+        # For legacy reasons, we keep hf_quant_config.json for now.
+
+        # Initialize variables
+        kv_cache_quant_algo = None
+        group_size = None
+        exclude_modules = []
+
+        # Try flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # Note: FP4 models in config.json format may not have all the detailed fields
+            # that are present in hf_quant_config.json, so we need to handle defaults
+            kv_cache_quant_algo = config.get("kv_cache_quant_algo")
+            if not kv_cache_quant_algo:
+                # For config.json format, derive from kv_cache_scheme if available
+                kv_cache_scheme = config.get("kv_cache_scheme")
+                if (
+                    kv_cache_scheme
+                    and kv_cache_scheme.get("type") == "float"
+                    and kv_cache_scheme.get("num_bits") == 8
+                ):
+                    kv_cache_quant_algo = "FP8"
+                else:
+                    kv_cache_quant_algo = "auto"
+
+            group_size = config.get("group_size")
+            exclude_modules = config.get("ignore", [])
+        else:
+            # Fall back to nested format (hf_quant_config.json - legacy format)
+            try:
+                quant_config = cls.get_from_keys(config, ["quantization"])
+                quant_method = quant_config["quant_algo"]
+                kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
+                if not kv_cache_quant_algo:
+                    kv_cache_quant_algo = "auto"
+                group_size = quant_config.get("group_size")
+                exclude_modules = quant_config.get("exclude_modules", [])
+            except (ValueError, KeyError):
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
+                )
+
         if not quant_method in ["FP8", "NVFP4"]:
             raise ValueError(
                 f"ModelOpt currently only supports: FP8, NVFP4"
                 " quantizations in sglang. Please check the "
-                "
-                "quant configuration."
+                "quantization config for your model's configuration."
             )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
-
-        if not kv_cache_quant_algo:
-            kv_cache_quant_algo = "auto"
-        group_size = quant_config["group_size"]
-        exclude_modules = quant_config["exclude_modules"]
-        if not (group_size and kv_cache_quant_algo and exclude_modules):
+
+        if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
             logger.warning(
                 f"group_size: {group_size},"
                 f"kv_cache_quant_algo: {kv_cache_quant_algo},"
@@ -509,8 +583,7 @@ class ModelOptFp4Config(QuantizationConfig):
             )
             raise ValueError(
                 "NVFP4 quantization requires group size and "
-                "kv_cache_quant_algo specified in "
-                "hf_quant_config.json"
+                "kv_cache_quant_algo specified in the quantization config"
             )
         return cls(
             is_checkpoint_nvfp4_serialized,
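A minimal nested (hf_quant_config.json-style) NVFP4 config that passes the group_size / kv_cache_quant_algo check above; the values are assumptions for illustration only.

nvfp4_cfg = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,  # assumed; NVFP4 block scales cover 16 elements elsewhere in this file
        "exclude_modules": ["lm_head"],
    }
}
# ModelOptFp4Config.from_config(nvfp4_cfg) would not trigger the ValueError above.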
@@ -737,11 +810,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " above."
             )
         self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+        self._cache_permute_indices = {}
 
     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
         """Access the global enable_flashinfer_cutlass_moe setting."""
-        return
+        return get_moe_runner_backend().is_flashinfer_cutlass()
 
     def create_weights(
         self,
@@ -810,6 +886,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
 
+        # Only use `swizzle_blockscale` for shapes, not for real content
+        layer.w13_blockscale_swizzled = Parameter(
+            self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
+        )
+
         w2_weight_scale = ModelWeightParameter(
             data=torch.empty(
                 layer.num_local_experts,
@@ -824,6 +905,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
+        layer.w2_blockscale_swizzled = Parameter(
+            self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
+        )
+
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         extra_weight_attrs.update(
@@ -900,10 +985,15 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             e2m1_and_ufp8sf_scale_to_float,
             fp4_quantize,
             next_positive_power_of_2,
+            nvfp4_block_scale_interleave,
             reorder_rows_for_gated_act_gemm,
             shuffle_matrix_a,
             shuffle_matrix_sf_a,
         )
+        from flashinfer.fused_moe.core import (
+            _maybe_get_cached_w2_permute_indices,
+            _maybe_get_cached_w3_w1_permute_indices,
+        )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
         epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
@@ -927,50 +1017,66 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             num_experts, hidden_size, intermediate_size // 16
         )  # fp8 scaling factors
 
-        # Reorder rows of W1 and scales for fused gated activation
-        gemm1_weights_fp4_interleaved = []
-        gemm1_scales_fp4_interleaved = []
-        for i in range(num_experts):
-            gemm1_weights_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
-            )
-            gemm1_scales_fp4_interleaved.append(
-                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
-            )
-
-        # Stack weights and scales for all experts
-        gemm1_weights_fp4_interleaved = torch.stack(
-            gemm1_weights_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
-        gemm1_scales_fp4_interleaved = torch.stack(
-            gemm1_scales_fp4_interleaved
-        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
-
-        # Shuffle weights and scaling factors for transposed mma output
         gemm1_weights_fp4_shuffled = []
         gemm1_scales_fp4_shuffled = []
         gemm2_weights_fp4_shuffled = []
         gemm2_scales_fp4_shuffled = []
         for i in range(num_experts):
+            # Calculate the permute indices for the following:
+            # 1. Reorder rows of W1 and scales for fused gated activation
+            # 2. Shuffle weights and scaling factors for transposed mma output
+            # for both w3_w1 and w2 weights and scale factors
+            permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm1_weights_fp4_shuffled.append(
-
-
-            )
+                gemm1_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices(
+                self._cache_permute_indices,
+                gemm1_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm1_scales_fp4_shuffled.append(
-
-
+                nvfp4_block_scale_interleave(
+                    gemm1_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm1_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
+            permute_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_weights_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+            )
             gemm2_weights_fp4_shuffled.append(
-
-
-            )
+                gemm2_weights_fp4[i]
+                .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)]
+                .contiguous()
+            )
+
+            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                self._cache_permute_indices,
+                gemm2_scales_linear_fp4[i].view(torch.uint8),
+                epilogue_tile_m,
+                num_elts_per_sf=16,
             )
             gemm2_scales_fp4_shuffled.append(
-
-                    gemm2_scales_linear_fp4[i]
+                nvfp4_block_scale_interleave(
+                    gemm2_scales_linear_fp4[i]
+                    .view(torch.uint8)[
+                        permute_sf_indices.to(gemm2_scales_linear_fp4.device)
+                    ]
+                    .contiguous()
                 )
             )
 
@@ -1106,16 +1212,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         # Process w13 weights
         w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
-        layer.w13_blockscale_swizzled
-            w13_blockscale_swizzled, requires_grad=False
-        )
+        layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled)
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # Process w2 weights
         w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
-        layer.w2_blockscale_swizzled
-            w2_blockscale_swizzled, requires_grad=False
-        )
+        layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled)
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         # Both flashinfer cutlass and regular cutlass use same processing for w2
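Together with the create_weights hunks earlier, this change switches from building a new Parameter during post-processing to pre-allocating it (shape only, via swizzle_blockscale) and filling it in place with .data.copy_(). A standalone sketch of that pattern with made-up shapes:

import torch
from torch.nn.parameter import Parameter

# At weight-creation time: allocate with the final shape; contents are not used yet.
blockscale_swizzled = Parameter(torch.empty(4, 8), requires_grad=False)

# At weight-processing time: fill in place instead of registering a new Parameter.
blockscale_swizzled.data.copy_(torch.randn(4, 8))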
@@ -1138,21 +1240,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer:
+        layer: FusedMoE,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        ep_rank: Optional[int] = None,
-        ep_size: Optional[int] = None,
-        tp_rank: Optional[int] = None,
-        tp_size: Optional[int] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
@@ -1161,20 +1256,41 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
-                not apply_router_weight_on_input
+                not moe_runner_config.apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
-
             topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
 
+            output_dtype = x.dtype
+            x_sf = None
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+
+                # Quantize before comm, swizzle after.
+                if x.shape[0] > 0:
+                    x, x_sf = fp4_quantize(
+                        x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
+                    )
+                else:
+                    x_col = x.shape[1]
+                    x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
+                    x_sf = torch.zeros(
+                        0, x_col // 16, dtype=torch.uint8, device=x.device
+                    )
+                topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
+                    [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
+                )
+                x_sf = nvfp4_block_scale_interleave(x_sf)
+
             output = flashinfer_cutlass_fused_moe(
-                x,
-                topk_ids.to(torch.int),
-                topk_weights,
-                layer.w13_weight.view(torch.long),
-                layer.w2_weight.view(torch.long),
-
+                input=x,
+                token_selected_experts=topk_ids.to(torch.int),
+                token_final_scales=topk_weights,
+                fc1_expert_weights=layer.w13_weight.view(torch.long),
+                fc2_expert_weights=layer.w2_weight.view(torch.long),
+                output_dtype=output_dtype,
+                input_sf=x_sf,
                 quant_scales=[
                     layer.w13_input_scale_quant,
                     layer.w13_blockscale_swizzled.view(torch.int32),
@@ -1183,14 +1299,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                     layer.w2_blockscale_swizzled.view(torch.int32),
                     layer.g2_alphas,
                 ],
-                ep_size=
-                ep_rank=
-                tp_size=
-                tp_rank=
+                ep_size=layer.moe_ep_size,
+                ep_rank=layer.moe_ep_rank,
+                tp_size=layer.moe_tp_size,
+                tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-
-
+            # Scale by routed_scaling_factor is fused into select_experts.
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                output, global_output = get_local_dp_buffer(), output
+                get_tp_group().reduce_scatterv(
+                    global_output, output=output, sizes=get_dp_global_num_tokens()
+                )
             return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1209,8 +1329,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             params=layer.cutlass_moe_params,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
-
-        output *= routed_scaling_factor
+        # Scale by routed_scaling_factor is fused into select_experts.
         return output
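The zero-token placeholders in the allgather branch above follow NVFP4's packing: two 4-bit values per uint8 byte and one scale factor per 16-element block (matching num_elts_per_sf=16). Worked through with an assumed hidden size:

hidden_size = 4096               # assumed example width
packed_cols = hidden_size // 2   # 2048 uint8 columns holding the fp4 payload
scale_cols = hidden_size // 16   # 256 columns of per-block scale factors
print(packed_cols, scale_cols)   # 2048 256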
sglang/srt/layers/quantization/moe_wna16.py

@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 
@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
-        assert
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.w13_qweight,
             layer.w2_qweight,
             topk_output=topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
             w1_scale=layer.w13_scales,
@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
             w1_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     @staticmethod
@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
             )
 
             if "w13_qzeros" in weight_name:
-                tensor = loaded_weight.view(
-
-                ]
+                tensor = loaded_weight.view(
+                    layer.moe_tp_size, -1, loaded_weight.size(1)
+                )[tp_rank]
                 if shard_id == "w1":
                     param.data[expert_id, : shard_size // 2] = tensor
                 else:
                     param.data[expert_id, shard_size // 2 :] = tensor
             elif "w2_qzeros" in weight_name:
                 param.data[expert_id] = loaded_weight.view(
-                    loaded_weight.size(0), layer.
+                    loaded_weight.size(0), layer.moe_tp_size, -1
                 )[:, tp_rank]
             else:
                 weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
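The reworked w13_qzeros branch shards along the first dimension by viewing the tensor as (moe_tp_size, rows_per_rank, cols) and keeping one slice per rank. A small self-contained check of that arithmetic with arbitrary sizes:

import torch

moe_tp_size, tp_rank = 4, 1
loaded_weight = torch.arange(8 * 6).reshape(8, 6)  # stand-in for a qzeros tensor with 8 rows
tensor = loaded_weight.view(moe_tp_size, -1, loaded_weight.size(1))[tp_rank]
print(tensor.shape)  # torch.Size([2, 6]) -> this rank keeps rows 2..3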