sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py +165 -68

@@ -7,8 +7,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
+from sglang.srt.layers.moe import (
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,

@@ -30,10 +35,11 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():
@@ -105,18 +111,52 @@ class ModelOptFp8Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
-
-
-
-
-
-
-
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "FP8", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "FP8", ...}
+        # In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
+        # For legacy reasons, we keep hf_quant_config.json for now.
+
+        # Initialize variables
+        kv_cache_quant_method = None
+        exclude_modules = None
+
+        # Try flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # For kv_cache, check if kv_cache_scheme exists and extract algo
+            kv_cache_scheme = config.get("kv_cache_scheme")
+            if (
+                kv_cache_scheme
+                and kv_cache_scheme.get("type") == "float"
+                and kv_cache_scheme.get("num_bits") == 8
+            ):
+                kv_cache_quant_method = "FP8"
 
+            # Map 'ignore' field to 'exclude_modules'
+            exclude_modules = config.get("ignore")
+        else:
+            # Fall back to nested format (hf_quant_config.json - legacy format)
+            try:
+                quantization_section = cls.get_from_keys(config, ["quantization"])
+                quant_method = quantization_section.get("quant_algo")
+                kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo")
+                exclude_modules = quantization_section.get("exclude_modules")
+            except ValueError:
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
+                )
+        if quant_method is None:
+            raise ValueError(
+                "Cannot find 'quant_algo' in the model's quantization config. "
+            )
         if "FP8" not in quant_method:
             raise ValueError(
-                "
-                "
+                "ModelOptFp8Config only supports static FP8 quantization in SGLang. "
+                "For FP4 quantization, use ModelOptFp4Config. "
+                "Check the quantization config for your model's configuration."
             )
 
         return cls(
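For reference, the two quantization-config layouts that the new ModelOptFp8Config.from_config accepts look roughly like this (an illustrative Python sketch based on the comments in the hunk above; the concrete field values and the "lm_head" entry are made up, not taken from a real checkpoint):

    # Legacy nested layout (hf_quant_config.json)
    legacy_cfg = {
        "quantization": {
            "quant_algo": "FP8",
            "kv_cache_quant_algo": "FP8",
            "exclude_modules": ["lm_head"],  # hypothetical entry
        }
    }

    # Flat layout (config.json "quantization_config"); this format is tried first
    flat_cfg = {
        "quant_algo": "FP8",
        "kv_cache_scheme": {"type": "float", "num_bits": 8},  # read as kv_cache_quant_method = "FP8"
        "ignore": ["lm_head"],  # mapped to exclude_modules
    }

    # Both should yield equivalent configs:
    # ModelOptFp8Config.from_config(flat_cfg)
    # ModelOptFp8Config.from_config(legacy_cfg)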
@@ -422,12 +462,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 

@@ -436,15 +471,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=False,  # ModelOpt uses per-tensor quantization
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
         )
 
 
@@ -486,22 +519,63 @@ class ModelOptFp4Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
-
-
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "NVFP4", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "NVFP4", ...}
+        # In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
+        # For legacy reasons, we keep hf_quant_config.json for now.
+
+        # Initialize variables
+        kv_cache_quant_algo = None
+        group_size = None
+        exclude_modules = []
+
+        # Try flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # Note: FP4 models in config.json format may not have all the detailed fields
+            # that are present in hf_quant_config.json, so we need to handle defaults
+            kv_cache_quant_algo = config.get("kv_cache_quant_algo")
+            if not kv_cache_quant_algo:
+                # For config.json format, derive from kv_cache_scheme if available
+                kv_cache_scheme = config.get("kv_cache_scheme")
+                if (
+                    kv_cache_scheme
+                    and kv_cache_scheme.get("type") == "float"
+                    and kv_cache_scheme.get("num_bits") == 8
+                ):
+                    kv_cache_quant_algo = "FP8"
+                else:
+                    kv_cache_quant_algo = "auto"
+
+            group_size = config.get("group_size")
+            exclude_modules = config.get("ignore", [])
+        else:
+            # Fall back to nested format (hf_quant_config.json - legacy format)
+            try:
+                quant_config = cls.get_from_keys(config, ["quantization"])
+                quant_method = quant_config["quant_algo"]
+                kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
+                if not kv_cache_quant_algo:
+                    kv_cache_quant_algo = "auto"
+                group_size = quant_config.get("group_size")
+                exclude_modules = quant_config.get("exclude_modules", [])
+            except (ValueError, KeyError):
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
+                )
+
         if not quant_method in ["FP8", "NVFP4"]:
             raise ValueError(
                 f"ModelOpt currently only supports: FP8, NVFP4"
                 " quantizations in sglang. Please check the "
-                "
-                "quant configuration."
+                "quantization config for your model's configuration."
             )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
-
-        if not kv_cache_quant_algo:
-            kv_cache_quant_algo = "auto"
-        group_size = quant_config["group_size"]
-        exclude_modules = quant_config["exclude_modules"]
-        if not (group_size and kv_cache_quant_algo and exclude_modules):
+
+        if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
             logger.warning(
                 f"group_size: {group_size},"
                 f"kv_cache_quant_algo: {kv_cache_quant_algo},"

@@ -509,8 +583,7 @@ class ModelOptFp4Config(QuantizationConfig):
             )
             raise ValueError(
                 "NVFP4 quantization requires group size and "
-                "kv_cache_quant_algo specified in "
-                "hf_quant_config.json"
+                "kv_cache_quant_algo specified in the quantization config"
             )
         return cls(
             is_checkpoint_nvfp4_serialized,
@@ -741,8 +814,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
         """Access the global enable_flashinfer_cutlass_moe setting."""
-        return
+        return get_moe_runner_backend().is_flashinfer_cutlass()
 
     def create_weights(
         self,

@@ -811,6 +886,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
 
+        # Only use `swizzle_blockscale` for shapes, not for real content
+        layer.w13_blockscale_swizzled = Parameter(
+            self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
+        )
+
         w2_weight_scale = ModelWeightParameter(
             data=torch.empty(
                 layer.num_local_experts,

@@ -825,6 +905,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
+        layer.w2_blockscale_swizzled = Parameter(
+            self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
+        )
+
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         extra_weight_attrs.update(
@@ -1128,16 +1212,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         # Process w13 weights
         w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
-        layer.w13_blockscale_swizzled = Parameter(
-            w13_blockscale_swizzled, requires_grad=False
-        )
+        layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled)
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # Process w2 weights
         w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
-        layer.w2_blockscale_swizzled = Parameter(
-            w2_blockscale_swizzled, requires_grad=False
-        )
+        layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled)
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         # Both flashinfer cutlass and regular cutlass use same processing for w2
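The two create_weights hunks above pre-register w13_blockscale_swizzled and w2_blockscale_swizzled (using swizzle_blockscale only to get the right shape), and the hunk above then fills them in-place with .data.copy_(...) instead of rebuilding the Parameter after loading. Reduced to a generic PyTorch sketch (not sglang code), the pattern is:

    import torch
    from torch.nn.parameter import Parameter

    # Register a correctly shaped placeholder up front ...
    placeholder = Parameter(torch.empty(4, 4), requires_grad=False)

    # ... then fill it in-place once the real values are available, so the
    # Parameter object itself (and any existing reference to it) stays stable.
    real_values = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    placeholder.data.copy_(real_values)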
@@ -1160,21 +1240,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer:
+        layer: FusedMoE,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        ep_rank: Optional[int] = None,
-        ep_size: Optional[int] = None,
-        tp_rank: Optional[int] = None,
-        tp_size: Optional[int] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):

@@ -1183,20 +1256,41 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
-                not apply_router_weight_on_input
+                not moe_runner_config.apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
-
             topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
 
+            output_dtype = x.dtype
+            x_sf = None
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+
+                # Quantize before comm, swizzle after.
+                if x.shape[0] > 0:
+                    x, x_sf = fp4_quantize(
+                        x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
+                    )
+                else:
+                    x_col = x.shape[1]
+                    x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
+                    x_sf = torch.zeros(
+                        0, x_col // 16, dtype=torch.uint8, device=x.device
+                    )
+                topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
+                    [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
+                )
+                x_sf = nvfp4_block_scale_interleave(x_sf)
+
             output = flashinfer_cutlass_fused_moe(
-                x,
-                topk_ids.to(torch.int),
-                topk_weights,
-                layer.w13_weight.view(torch.long),
-                layer.w2_weight.view(torch.long),
-
+                input=x,
+                token_selected_experts=topk_ids.to(torch.int),
+                token_final_scales=topk_weights,
+                fc1_expert_weights=layer.w13_weight.view(torch.long),
+                fc2_expert_weights=layer.w2_weight.view(torch.long),
+                output_dtype=output_dtype,
+                input_sf=x_sf,
                 quant_scales=[
                     layer.w13_input_scale_quant,
                     layer.w13_blockscale_swizzled.view(torch.int32),

@@ -1205,14 +1299,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                     layer.w2_blockscale_swizzled.view(torch.int32),
                     layer.g2_alphas,
                 ],
-                ep_size=
-                ep_rank=
-                tp_size=
-                tp_rank=
+                ep_size=layer.moe_ep_size,
+                ep_rank=layer.moe_ep_rank,
+                tp_size=layer.moe_tp_size,
+                tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-
-
+            # Scale by routed_scaling_factor is fused into select_experts.
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                output, global_output = get_local_dp_buffer(), output
+                get_tp_group().reduce_scatterv(
+                    global_output, output=output, sizes=get_dp_global_num_tokens()
+                )
             return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

@@ -1231,8 +1329,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             params=layer.cutlass_moe_params,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
-
-        output *= routed_scaling_factor
+        # Scale by routed_scaling_factor is fused into select_experts.
         return output
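A recurring change in these hunks (and in moe_wna16.py below) is that the loose apply keyword arguments — activation, apply_router_weight_on_input, inplace, no_combine, routed_scaling_factor, plus the ep/tp rank and size hints — are folded into a single moe_runner_config argument imported from the new sglang/srt/layers/moe/moe_runner package listed above. Only activation and apply_router_weight_on_input are actually read in the hunks shown here. A minimal stand-in to illustrate the call-site change (the dataclass below is hypothetical, not the actual MoeRunnerConfig definition):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SketchMoeRunnerConfig:  # illustrative stand-in only
        activation: str = "silu"
        apply_router_weight_on_input: bool = False
        inplace: bool = True
        no_combine: bool = False
        routed_scaling_factor: Optional[float] = None

    cfg = SketchMoeRunnerConfig(activation="silu")

    # Old style: method.apply(layer, x, topk_output, activation="silu", inplace=True, ...)
    # New style: method.apply(layer, x, topk_output, moe_runner_config=cfg)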
sglang/srt/layers/quantization/moe_wna16.py +10 -15

@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 

@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
-        assert
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp

@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.w13_qweight,
             layer.w2_qweight,
             topk_output=topk_output,
-
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
             w1_scale=layer.w13_scales,

@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
             w1_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     @staticmethod

@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
             )
 
             if "w13_qzeros" in weight_name:
-                tensor = loaded_weight.view(
-
-                ]
+                tensor = loaded_weight.view(
+                    layer.moe_tp_size, -1, loaded_weight.size(1)
+                )[tp_rank]
                 if shard_id == "w1":
                     param.data[expert_id, : shard_size // 2] = tensor
                 else:
                     param.data[expert_id, shard_size // 2 :] = tensor
             elif "w2_qzeros" in weight_name:
                 param.data[expert_id] = loaded_weight.view(
-                    loaded_weight.size(0), layer.
+                    loaded_weight.size(0), layer.moe_tp_size, -1
                 )[:, tp_rank]
             else:
                 weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)