sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
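The hunks below are from sglang/srt/layers/quantization/fp8.py (inferred from the Fp8LinearMethod, Fp8MoEMethod, and Fp8KVCacheMethod classes they modify).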
@@ -8,15 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
-from sglang.srt.layers.quantization.utils import (
-    all_close_1d,
-    convert_to_channelwise,
-    is_layer_skipped,
-    per_tensor_dequantize,
-    requantize_with_max_scale,
-)
-
 try:
     from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
         apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
 except ImportError:
     MARLIN_FP8_AVAILABLE = False
 
-    def
-        raise ImportError(
+    def dummy_func(*args, **kwargs):
+        raise ImportError(
+            "marlin FP8 requires some operators from vllm. Please install vllm."
+        )
 
-
-        raise ImportError("vllm is not installed")
+    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
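Taken together, the two hunks above replace a fallback that raised at import time with the usual optional-dependency stub: when the vllm marlin helpers cannot be imported, both entry points are bound to a function that raises only if the marlin path is actually used. A minimal, self-contained sketch of that pattern follows; the MARLIN_FP8_AVAILABLE = True assignment inside the try block is assumed from context and is not shown in the hunk.

try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )

    MARLIN_FP8_AVAILABLE = True  # assumed: set when the optional import succeeds
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    def dummy_func(*args, **kwargs):
        # Deferred failure: raises only when a marlin code path is actually exercised.
        raise ImportError(
            "marlin FP8 requires some operators from vllm. Please install vllm."
        )

    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func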
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    scaled_fp8_quant,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
@@ -57,29 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    convert_to_channelwise,
+    is_layer_skipped,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
     is_hip,
-    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
-ACTIVATION_SCHEMES = ["static", "dynamic"]
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
 if _is_hip:
-    from aiter
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
-
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
-
-
-else:
-    from vllm import _custom_ops as vllm_ops
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = logging.getLogger(__name__)
 
@@ -242,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
         layer.logical_widths = output_partition_sizes
-
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype
@@ -326,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.weight_scale_inv.data, requires_grad=False
             )
             return
+
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             if self.cutlass_fp8_supported or self.use_marlin:
@@ -390,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
             )
 
         if self.use_marlin:
-
-
-
-            del layer.input_scale
-        except ImportError:
-            self.use_marlin = False
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
 
     def apply(
         self,
@@ -405,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
 
         if self.use_marlin:
-
-
-
-
-
-
-
-
-
-            )
-        except ImportError:
-            self.use_marlin = False
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
 
         if self.block_quant:
             return apply_w8a8_block_fp8_linear(
@@ -487,7 +483,7 @@ class Fp8MoEMethod:
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
@@ -515,7 +511,7 @@ class Fp8MoEMethod:
         )
 
         # WEIGHTS
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
            # INT4 MoE weight - INT32 packed
            w13_weight = torch.nn.Parameter(
                torch.empty(
@@ -616,7 +612,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -648,7 +644,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             self.process_weights_hip_int4(layer)
             return
 
@@ -705,20 +701,12 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             for expert in range(layer.num_experts):
-
-                    w13_weight[expert, :, :]
-
-
-                    w2_weight[expert, :, :]
-
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
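The loop above now calls sglang's own scaled_fp8_quant for every expert instead of vllm_ops.scaled_fp8_quant, producing an FP8 weight plus one per-tensor scale per expert. The sketch below is a plain-PyTorch illustration of that contract, not the fused kernel the diff actually uses; per_tensor_fp8_quant and its 1e-12 clamp are illustrative choices.

import torch

def per_tensor_fp8_quant(weight: torch.Tensor):
    # Mirror the semantics of scaled_fp8_quant(weight) -> (fp8_weight, scale):
    # one scale for the whole tensor, then a saturating cast to float8_e4m3fn.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    q = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale

num_experts, n, k = 4, 128, 256
w13 = torch.randn(num_experts, n, k, dtype=torch.bfloat16)
w13_q = torch.empty_like(w13, dtype=torch.float8_e4m3fn)
w13_scale = torch.empty(num_experts, dtype=torch.float32)
for e in range(num_experts):
    # One quantized weight and one scalar scale per expert, as in the hunk above.
    w13_q[e], w13_scale[e] = per_tensor_fp8_quant(w13[e].float())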
@@ -795,18 +783,10 @@ class Fp8MoEMethod:
                     layer.w13_weight[expert_id][start : start + shard_size, :],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
-
-
-
-
-                    ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                else:
-                    (
-                        layer.w13_weight[expert_id][start : start + shard_size, :],
-                        _,
-                    ) = vllm_ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id]
-                    )
+                (
+                    layer.w13_weight[expert_id][start : start + shard_size, :],
+                    _,
+                ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                 start += shard_size
 
             layer.w13_weight_scale = torch.nn.Parameter(
@@ -822,12 +802,14 @@ class Fp8MoEMethod:
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            permute_weight(layer.w13_weight.data),
+            # permute_weight(layer.w13_weight.data),
+            shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            permute_weight(layer.w2_weight.data),
+            # permute_weight(layer.w2_weight.data),
+            shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
@@ -867,12 +849,14 @@ class Fp8MoEMethod:
 
         if get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -908,6 +892,7 @@ class Fp8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -923,41 +908,14 @@ class Fp8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
-        if _is_hip
-
-
-
-
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
-                activation=activation,
-            )
-        if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
-            assert not no_combine, f"{no_combine=} is not supported."
-            if self.block_quant:
-                return asm_moe(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
-                    expert_mask=None,
-                )
-            else:
-                return asm_moe(
+        if _is_hip:
+            if get_bool_env_var("USE_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+                assert not no_combine, f"{no_combine=} is not supported."
+                return ck_moe_2stages_win4(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -965,34 +923,71 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            if get_bool_env_var("CK_MOE"):
+                assert not no_combine, f"{no_combine=} is not supported."
+                if self.block_quant:
+                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    assert (
+                        activation == "silu"
+                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    return asm_moe(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale_inv,
+                        layer.w2_weight_scale_inv,
+                        block_shape=tuple(self.quant_config.weight_block_size),
+                        expert_mask=None,
+                    )
+                else:
+                    return ck_moe_2stages(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale1,
+                        layer.w2_weight_scale1,
+                        activation=(
+                            ActivationType.Silu
+                            if activation == "silu"
+                            else ActivationType.Gelu
+                        ),
+                    )
+
+        # Expert fusion with FP8 quantization
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace and not no_combine,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_fp8_w8a8=True,
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            ),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
+            no_combine=no_combine,
+        )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
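After this rewrite, apply() dispatches in a fixed order on ROCm before falling back to the Triton path: INT4 weights go to ck_moe_2stages_win4, CK_MOE routes block-quantized FP8 to asm_moe (silu only) and everything else to ck_moe_2stages, and any remaining case uses fused_experts. Below is a condensed sketch of that control flow only; the placeholder strings and the function select_fp8_moe_kernel are illustrative stand-ins for the real kernel calls, and the environment-variable checks are folded into plain booleans.

def select_fp8_moe_kernel(activation: str, block_quant: bool,
                          is_hip: bool, use_int4_weight: bool, ck_moe: bool) -> str:
    # Placeholder strings stand in for the actual kernel calls shown in the hunk.
    if is_hip:
        if use_int4_weight:
            return "ck_moe_2stages_win4"   # INT4 MoE weights, FP8 compute
        if ck_moe:
            if block_quant:
                assert activation == "silu", "asm_moe block-quant path is silu-only"
                return "asm_moe"
            return "ck_moe_2stages"
    return "fused_experts"                 # default Triton FP8 W8A8 path

assert select_fp8_moe_kernel("silu", True, True, False, True) == "asm_moe"
assert select_fp8_moe_kernel("gelu", False, False, False, False) == "fused_experts"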