sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -52,6 +52,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     apply_w8a8_block_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
+    is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -235,7 +236,7 @@ class Fp8LinearMethod(LinearMethodBase):
                         f"{input_size_per_partition} is not divisible by "
                         f"weight quantization block_k = {block_k}."
                     )
-            # Required by
+            # Required by column parallel or enabling merged weights
             if (
                 tp_size > 1 and output_size // output_size_per_partition == tp_size
             ) or len(output_partition_sizes) > 1:
@@ -470,6 +471,7 @@ class Fp8MoEMethod:
     def __init__(self, quant_config):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
+        self.cutlass_fp8_supported = cutlass_fp8_supported()

     def create_weights(
         self,
@@ -491,7 +493,7 @@ class Fp8MoEMethod:
                 self.quant_config.weight_block_size[1],
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by
+            # Required by column parallel or enabling merged weights
             if intermediate_size % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
@@ -568,6 +570,63 @@ class Fp8MoEMethod:
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
+            if (
+                get_bool_env_var("CUTLASS_MOE")
+                and self.cutlass_fp8_supported
+                and is_sm100_supported()
+            ):
+                self.ab_strides1 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides1 = torch.full(
+                    (num_experts,),
+                    2 * intermediate_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.ab_strides2 = torch.full(
+                    (num_experts,),
+                    intermediate_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides2 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.workspace = torch.empty(
+                    90000, device=w13_weight.device, dtype=torch.uint8
+                )
+                self.a_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.out_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.a_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.expert_offsets = torch.empty(
+                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes1 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes2 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
+
         else:
             # Allocate 2 scales for w1 and w3 respectively.
             # They will be combined to a single scale after weight loading.
@@ -913,6 +972,37 @@ class Fp8MoEMethod:
         if ret is not None:
             return ret

+        if (
+            get_bool_env_var("CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and is_sm100_supported()
+        ):
+            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
+
+            return cutlass_fused_experts(
+                x,
+                layer.w13_weight.transpose(1, 2),
+                layer.w2_weight.transpose(1, 2),
+                layer.w13_weight_scale_inv.transpose(1, 2),
+                layer.w2_weight_scale_inv.transpose(1, 2),
+                topk_weights,
+                topk_ids,
+                self.ab_strides1,
+                self.c_strides1,
+                self.ab_strides2,
+                self.c_strides2,
+                self.workspace,
+                self.a_ptr,
+                self.b_ptr,
+                self.out_ptr,
+                self.a_scales_ptr,
+                self.b_scales_ptr,
+                self.expert_offsets,
+                self.problem_sizes1,
+                self.problem_sizes2,
+                use_fp8_blockscale=True,
+            )
         # Expert fusion with FP8 quantization
         return fused_experts(
             x,
sglang/srt/layers/quantization/fp8_kernel.py

@@ -104,7 +104,7 @@ def _per_token_group_quant_fp8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    #
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -342,7 +342,7 @@ def _static_quant_fp8(
     y_s_repeat_ptr,
     # Stride of input
     y_stride,
-    #
+    # Columns of input
     N,
     # Information for float8
     fp8_min,
@@ -794,7 +794,7 @@ def w8a8_block_fp8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],
sglang/srt/layers/quantization/fp8_utils.py

@@ -80,6 +80,12 @@ def cutlass_fp8_supported():
     return False


+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
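Note: the helper added above is what gates the new CUTLASS FP8 MoE path wired into Fp8MoEMethod earlier in this diff. A minimal sketch of the combined check follows; it assumes get_bool_env_var from sglang.srt.utils (the same check the fp8.py hunks use), and the wrapper name _cutlass_moe_enabled is hypothetical, not part of the package.

# Sketch only: mirrors the gating condition used in Fp8MoEMethod above.
from sglang.srt.layers.quantization.fp8_utils import (
    cutlass_fp8_supported,
    is_sm100_supported,
)
from sglang.srt.utils import get_bool_env_var


def _cutlass_moe_enabled() -> bool:
    # Requires explicit opt-in via the CUTLASS_MOE env var, a CUTLASS-capable
    # FP8 build, and an SM100 (compute capability 10.x) GPU on CUDA >= 12.8.
    return (
        get_bool_env_var("CUTLASS_MOE")
        and cutlass_fp8_supported()
        and is_sm100_supported()
    )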
sglang/srt/layers/quantization/gptq.py

@@ -1,21 +1,28 @@
 import logging
 from fractions import Fraction
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch

-from sglang.srt.layers.linear import LinearBase
-from sglang.srt.layers.quantization.base_config import
+from sglang.srt.layers.linear import LinearBase, set_weight_attrs
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.utils import replace_parameter
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

 try:
-    from vllm
+    from vllm import _custom_ops as ops
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
         GPTQMarlinLinearMethod,
-
+        marlin_moe_permute_scales,
     )
     from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
     from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -27,7 +34,9 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False

-    GPTQLinearMethod = MarlinLinearMethod =
+    GPTQLinearMethod = MarlinLinearMethod = Any
+
+    FusedMoEMethodBase = QuantizeMethodBase

     class scalar_types:
         uint4b8 = "uint4b8"
@@ -437,3 +446,286 @@ class MarlinConfig(QuantizationConfig):
         ):
             return MarlinLinearMethod(self)
         return None
+
+
+class GPTQMarlinMoEMethod(FusedMoEMethodBase):
+    """MoE Marlin method with quantization."""
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        intermediate_size = extra_weight_attrs.pop("intermediate_size")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size
+        )
+
+        if self.quant_config.group_size != -1:
+            scales_size13 = hidden_size // self.quant_config.group_size
+            w2_scales_size = (
+                intermediate_size
+                if self.quant_config.desc_act
+                else intermediate_size_per_partition
+            )
+            scales_size2 = w2_scales_size // self.quant_config.group_size
+            strategy = FusedMoeWeightScaleSupported.GROUP.value
+        else:
+            scales_size13 = 1
+            scales_size2 = 1
+            strategy = FusedMoeWeightScaleSupported.CHANNEL.value
+
+        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.quant_config.pack_factor,
+                2 * intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition // self.quant_config.pack_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        # up_proj scales
+        w13_scales = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition,
+                dtype=torch.half,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+        # down_proj scales
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts, scales_size2, hidden_size, dtype=torch.half),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
+        # up_proj scales
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+        # down_proj scales
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size2,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        # Process act_order
+        if self.quant_config.desc_act:
+            # Get sorting based on g_idx
+            num_experts = layer.w13_g_idx.shape[0]
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
+                    torch.int32
+                )
+                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
+                    torch.int32
+                )
+                w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
+            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        else:
+            # Reset g_idx related tensors
+            num_experts = layer.w13_g_idx.shape[0]
+            device = layer.w13_g_idx.device
+            layer.w13_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+        # Repack weights
+        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w13_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w2_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+        # Repack scales
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.w2_scales.shape[1]
+            * (
+                self.quant_config.group_size
+                if self.quant_config.group_size != -1
+                else self.quant_config.pack_factor
+            ),
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        # The input must currently be float16
+        orig_dtype = x.dtype
+        x = x.half()
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        return torch.ops.vllm.fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            layer.w13_scales,
+            layer.w2_scales,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_g_idx,
+            g_idx2=layer.w2_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.quant_config.quant_type.size_bits,
+            is_k_full=self.is_k_full,
+        ).to(orig_dtype)
sglang/srt/layers/quantization/int8_kernel.py

@@ -22,9 +22,11 @@ def _per_token_quant_int8(
     x_ptr,
     xq_ptr,
     scale_ptr,
+    x_sum_ptr,
     stride_x,
     stride_xq,
     N,
+    CAL_SUM: tl.constexpr,
     BLOCK: tl.constexpr,
 ):
     # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
@@ -38,16 +40,23 @@ def _per_token_quant_int8(
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
     x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    if CAL_SUM:
+        x_sum = tl.sum(x, axis=0)
+        tl.store(x_sum_ptr + row_id, x_sum.to(x_sum_ptr.dtype.element_ty))

     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
-    tl.store(scale_ptr + row_id, scale_x)
+    tl.store(scale_ptr + row_id, scale_x.to(scale_ptr.dtype.element_ty))


-def per_token_quant_int8(x):
+def per_token_quant_int8(x, scale_dtype=torch.float32, cal_sum=False):
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
-    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=scale_dtype)
+    if cal_sum:
+        x_sum = torch.empty(x.shape[:-1], device=x.device, dtype=x.dtype)
+    else:
+        x_sum = None
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
@@ -57,15 +66,19 @@ def per_token_quant_int8(x):
         x,
         x_q,
         scales,
+        x_sum,
         stride_x=x.stride(-2),
         stride_xq=x_q.stride(-2),
         N=N,
+        CAL_SUM=cal_sum,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
-
-
+    if cal_sum:
+        return x_q, scales, x_sum
+    else:
+        return x_q, scales


 @triton.jit
@@ -76,7 +89,7 @@ def _per_token_group_quant_int8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    #
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -370,7 +383,7 @@ def w8a8_block_int8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],
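Taken together, the int8_kernel.py hunks extend per_token_quant_int8 with an optional per-token sum computed in the same kernel launch. A minimal usage sketch, assuming a CUDA device and the import path from the file list above:

import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)

# Default call keeps the original two-value return: int8 values plus scales.
x_q, scales = per_token_quant_int8(x)

# New in this release: also return the per-token sum of the unquantized input.
x_q, scales, x_sum = per_token_quant_int8(x, cal_sum=True)
assert x_q.dtype == torch.int8 and scales.shape == (4, 1) and x_sum.shape == (4,)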