sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
+    can_auto_enable_marlin_fp8,
     cutlass_fp8_supported,
     dispatch_w8a8_block_fp8_linear,
     input_to_float8,
@@ -79,6 +80,7 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 
@@ -208,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase):
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
-        self.use_marlin = …
-        …
-        self.use_marlin = False
+        self.use_marlin = False
+        if _is_cuda and MARLIN_FP8_AVAILABLE:
+            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+            auto_enable = can_auto_enable_marlin_fp8()
+            self.use_marlin = force_marlin or auto_enable
 
         self.block_quant = self.quant_config.weight_block_size is not None
-        if self.block_quant:
-            # Marlin doesn't support block-wise fp8
-            self.use_marlin = False
 
         self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
 
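Note: the hunk above replaces the old toggle (and the unconditional disable for block-wise FP8) with a two-part rule: Marlin weight-only FP8 is considered only on CUDA builds where the Marlin FP8 kernel is available, and is turned on either by the `SGLANG_FORCE_FP8_MARLIN` environment variable or by the new `can_auto_enable_marlin_fp8()` heuristic. A minimal standalone sketch of that decision follows; the booleans stand in for sglang's `_is_cuda`, `MARLIN_FP8_AVAILABLE`, and `can_auto_enable_marlin_fp8`, which are not reimplemented here.

```python
import os


def should_use_marlin_fp8(
    is_cuda: bool, marlin_fp8_available: bool, can_auto_enable: bool
) -> bool:
    """Sketch of the enablement rule introduced in the hunk above."""
    use_marlin = False
    if is_cuda and marlin_fp8_available:
        # SGLANG_FORCE_FP8_MARLIN forces the Marlin path; otherwise defer to
        # the auto-enable heuristic (e.g. GPUs without native FP8 support).
        force = os.environ.get("SGLANG_FORCE_FP8_MARLIN", "").lower() in ("1", "true")
        use_marlin = force or can_auto_enable
    return use_marlin


# Off by default, on when forced via the environment variable.
assert should_use_marlin_fp8(True, True, can_auto_enable=False) is False
os.environ["SGLANG_FORCE_FP8_MARLIN"] = "1"
assert should_use_marlin_fp8(True, True, can_auto_enable=False) is True
```

Block-wise FP8 no longer disables Marlin outright; the later `process_weights_after_loading` hunk instead passes `not self.block_quant` to `prepare_fp8_layer_for_marlin`.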
@@ -331,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if _is_fp8_fnuz:
@@ -341,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
-
                 layer.input_scale = None
             elif _is_cpu:
                 assert (
@@ -351,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase):
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-                layer.weight = …
-                layer.weight_scale_inv = …
-                    …
-                )
-            return
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
-        …
+            # If checkpoint not serialized fp8, quantize the weights.
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                if self.cutlass_fp8_supported or self.use_marlin:
+                    # apply per-channel quantization default as
+                    # cutlass sgl-kernel and marlin only support per-channel scale
+                    qweight, weight_scale = per_token_group_quant_fp8(
+                        layer.weight, layer.weight.shape[-1]
+                    )
+                    weight_scale = weight_scale.t().contiguous()
+                else:
+                    # per-tensor quantization
+                    qweight, weight_scale = input_to_float8(layer.weight)
+
+                # Update the layer with the new values.
+                layer.weight = Parameter(qweight.t(), requires_grad=False)
+                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+                layer.input_scale = None
 
-        …
-        if self.cutlass_fp8_supported or self.use_marlin:
-            # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
-            qweight, weight_scale = per_token_group_quant_fp8(
-                layer.weight, layer.weight.shape[-1]
-            )
-            weight_scale = weight_scale.t().contiguous()
+            # If checkpoint is fp8, handle that there are N scales for N
+            # shards in a fused module
             else:
-                …
-            # Update the layer with the new values.
-            layer.weight = Parameter(qweight.t(), requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            layer.input_scale = None
-
-        # If checkpoint is fp8, handle that there are N scales for N
-        # shards in a fused module
-        else:
-            layer.weight_scale = torch.nn.Parameter(
-                layer.weight_scale.data, requires_grad=False
-            )
-            if (
-                hasattr(self.quant_config, "activation_scheme")
-                and self.quant_config.activation_scheme == "static"
-            ) or (
-                hasattr(self.quant_config, "linear_activation_scheme")
-                and self.quant_config.linear_activation_scheme == "static"
-            ):
-                layer.input_scale = torch.nn.Parameter(
-                    layer.input_scale.data, requires_grad=False
+                layer.weight_scale = Parameter(
+                    layer.weight_scale.data, requires_grad=False
                 )
+                if (
+                    hasattr(self.quant_config, "activation_scheme")
+                    and self.quant_config.activation_scheme == "static"
+                ) or (
+                    hasattr(self.quant_config, "linear_activation_scheme")
+                    and self.quant_config.linear_activation_scheme == "static"
+                ):
+                    layer.input_scale = Parameter(
+                        layer.input_scale.data, requires_grad=False
+                    )
 
-            …
+                # cutlass sgl-kernel and marlin only support per-channel scale
+                if self.cutlass_fp8_supported or self.use_marlin:
+                    weight = layer.weight
+                    weight_scale = convert_to_channelwise(
+                        layer.weight_scale, layer.logical_widths
+                    )
+                else:
+                    # Dequant -> Quant with max scale so we can run per tensor.
+                    weight = layer.weight
+                    weight_scale = layer.weight_scale
+                    # If ROCm, normalize the weights and scales to e4m3fnuz
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, input_scale = (
+                            normalize_e4m3fn_to_e4m3fnuz(
+                                weight=weight,
+                                weight_scale=weight_scale,
+                                input_scale=layer.input_scale,
+                            )
+                        )
+                        if input_scale is not None:
+                            layer.input_scale = Parameter(
+                                input_scale, requires_grad=False
+                            )
+
+                    weight_scale, weight = requantize_with_max_scale(
                         weight=weight,
                         weight_scale=weight_scale,
-                        …
+                        logical_widths=layer.logical_widths,
                     )
-            if input_scale is not None:
-                layer.input_scale = Parameter(input_scale, requires_grad=False)
-
-            weight_scale, weight = requantize_with_max_scale(
-                weight=weight,
-                weight_scale=weight_scale,
-                logical_widths=layer.logical_widths,
-            )
 
-            …
+                    # Update layer with new values.
+                    layer.weight = Parameter(weight.t(), requires_grad=False)
+                    layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+                    if (
+                        hasattr(self.quant_config, "activation_scheme")
+                        and self.quant_config.activation_scheme == "static"
+                    ) or (
+                        hasattr(self.quant_config, "linear_activation_scheme")
+                        and self.quant_config.linear_activation_scheme == "static"
+                    ):
+                        layer.input_scale = Parameter(
+                            layer.input_scale.max(), requires_grad=False
+                        )
 
         if self.use_marlin:
-            …
+            if self.block_quant:
+                layer.weight_block_size = self.quant_config.weight_block_size
+            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
             # Activations not quantized for marlin.
             del layer.input_scale
 
@@ -444,7 +444,6 @@ class Fp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         if self.use_marlin:
             return apply_fp8_marlin_linear(
                 input=x,
@@ -515,6 +514,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.use_cutlass_fused_experts_fp8 = (
+            get_bool_env_var("SGLANG_CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )
 
     def create_weights(
         self,
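Note: the gate for the cutlass fused-experts FP8 path is now computed once in the constructor as `use_cutlass_fused_experts_fp8` instead of being rebuilt on every `apply` call (see the hunk around old line 1030 below). A hedged sketch of the equivalent condition, with the hardware checks passed in as plain booleans rather than calling sglang's `is_sm90_supported` / `is_sm100_supported` helpers:

```python
import os


def cutlass_fused_experts_fp8_enabled(
    cutlass_fp8_supported: bool,
    block_quant: bool,
    is_sm90: bool,
    is_sm100: bool,
) -> bool:
    """Sketch of the use_cutlass_fused_experts_fp8 flag from the hunk above."""
    opted_in = os.environ.get("SGLANG_CUTLASS_MOE", "").lower() in ("1", "true")
    # All conditions must hold: explicit opt-in, cutlass FP8 support,
    # block-wise FP8 weights, and a Hopper (SM90) or Blackwell (SM100) GPU.
    return opted_in and cutlass_fp8_supported and block_quant and (is_sm100 or is_sm90)


print(cutlass_fused_experts_fp8_enabled(True, True, is_sm90=True, is_sm100=False))
```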
@@ -961,6 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
+
             # ROCm (_use_aiter): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
@@ -982,12 +988,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        …
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
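Note: the loose per-call keyword arguments (`activation`, `apply_router_weight_on_input`, `inplace`, `no_combine`, `routed_scaling_factor`) are folded into a single `moe_runner_config` parameter. The real `MoeRunnerConfig` lives in `sglang.srt.layers.moe.moe_runner` and is only type-imported here; the dataclass below is a hypothetical stand-in that just shows how the removed keywords map onto config fields.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfigSketch:
    # One field per keyword argument removed from Fp8MoEMethod.apply above.
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None


# Callers now build the config once and thread it through apply(), e.g.:
cfg = MoeRunnerConfigSketch(activation="silu", routed_scaling_factor=2.5)
print(cfg.apply_router_weight_on_input, cfg.routed_scaling_factor)
```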
@@ -996,7 +997,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                apply_router_weight_on_input, topk_weights, x
+                moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
 
             return torch.ops.sgl_kernel.fused_experts_cpu(
@@ -1021,18 +1022,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer,
                 x,
                 topk_output,
-                activation,
-                no_combine,
+                moe_runner_config.activation,
+                moe_runner_config.no_combine,
             )
             if ret is not None:
                 return ret
 
-        if (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        ):
+        if self.use_cutlass_fused_experts_fp8:
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
             topk_weights, topk_ids, _ = topk_output
@@ -1059,9 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
-            # …
-            if routed_scaling_factor is not None:
-                output *= routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
             return output
         # Expert fusion with FP8 quantization
         return fused_experts(
@@ -1069,9 +1063,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            …
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv
@@ -1084,30 +1076,44 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        …
-        …
-        activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+        activation = moe_runner_config.activation
+        routed_scaling_factor = moe_runner_config.routed_scaling_factor
+
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+        assert TopKOutputChecker.format_is_bypassed(topk_output)
+        router_logits = topk_output.router_logits
+        topk_config = topk_output.topk_config
         assert (
             activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
         a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
 
+        assert (
+            topk_config.num_expert_group is not None
+            and topk_config.topk_group is not None
+        ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
+
+        if topk_config.correction_bias is not None:
+            correction_bias = topk_config.correction_bias.to(x.dtype)
+        else:
+            correction_bias = None
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
-            routing_bias=…,
+            routing_bias=correction_bias,
             hidden_states=a_q,
             hidden_states_scale=a_sf_t,
             gemm1_weights=layer.w13_weight,
@@ -1115,15 +1121,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             gemm2_weights=layer.w2_weight,
             gemm2_weights_scale=layer.w2_weight_scale_inv,
             num_experts=layer.num_experts,
-            top_k=…,
-            n_group=…,
-            topk_group=…,
+            top_k=topk_config.top_k,
+            n_group=topk_config.num_expert_group,
+            topk_group=topk_config.topk_group,
             intermediate_size=layer.w2_weight.shape[2],
             local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
             local_num_experts=layer.num_local_experts,
-            routed_scaling_factor=…,
+            routed_scaling_factor=(
+                routed_scaling_factor if routed_scaling_factor is not None else 1.0
+            ),
             tile_tokens_dim=get_tile_tokens_dim(
-                x.shape[0],
+                x.shape[0], topk_config.top_k, layer.num_experts
             ),
             routing_method_type=2,  # DeepSeek-styled routing method
             use_shuffled_weight=False,