sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/mxfp4_tensor.py (new file)
@@ -0,0 +1,133 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
+class MXFP4QuantizeUtil:
+    E2M1_max = 6.0
+
+    E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
+    E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
+
+    @classmethod
+    def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
+        """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
+        Args:
+            input (torch.Tensor): The input tensor to be quantized.
+            block_sizes (dict | None): The block sizes for quantization.
+        """
+
+        def cast_fp4(x):
+            sign = torch.sign(x)
+            sign_bit = (2 - sign) // 2
+            ord_ = torch.sum(
+                (x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
+            )
+            fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
+            return fp4_val
+
+        def fuse_uint4_to_uint8(x):
+            # If the last dimension is odd, pad with zeros
+            # If this behavior is not desired, please modify the code accordingly
+            left_side = x[..., 0::2]  # Even indices (0, 2, 4...)
+            right_side = x[..., 1::2]  # Odd indices (1, 3, 5...)
+            new_data = (
+                right_side.clone() << 4
+            )  # Put odd indices (higher addresses) in high bits
+            new_data[
+                ..., : left_side.shape[-1]
+            ] += left_side  # Put even indices in low bits
+            return new_data
+
+        if block_size is None:
+            block_size = 32
+
+        original_shape = input.shape
+        original_dtype = input.dtype
+        input = input.view(-1, block_size)
+        # get scales
+        input_amax = input.abs().max(dim=-1, keepdim=True).values
+        descale = input_amax / cls.E2M1_max
+        min_value = torch.tensor(-127.0, device=descale.device)
+        e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))
+
+        input = (input / torch.exp2(e8m0_scale)).view(original_shape)
+        input_q = cast_fp4(input)
+        input_q = fuse_uint4_to_uint8(input_q)
+        e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
+        return cls(original_shape, original_dtype, input_q), e8m0_scale
+
+    @classmethod
+    def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
+        """Dequantze MXFP4 packed tensor to a target dtype."""
+
+        def unfuse_uint8_to_uint4(x):
+            """Unfuse uint8 values back to uint4 values.
+            This is the inverse operation of fuse_uint4_to_uint8.
+            """
+            # Extract the lower 4 bits (even indices)
+            left_side = x & 0x0F
+
+            # Extract the upper 4 bits (odd indices)
+            right_side = (x >> 4) & 0x0F
+
+            # Create a new tensor with alternating values
+            shape = list(x.shape)
+            shape[-1] = shape[-1] * 2
+            result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
+
+            # Fill in the values - even indices get low bits, odd indices get high bits
+            result[..., 0::2] = left_side  # Even indices from low bits
+            result[..., 1::2] = right_side  # Odd indices from high bits
+
+            return result
+
+        e8m0_scale = scale
+        block_size = block_sizes[-1]
+
+        # Unfuse the uint8 values back to uint4
+        x_unfused = unfuse_uint8_to_uint4(quantized_data)
+        # Extract sign and magnitude
+        sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(
+            torch.float32
+        )  # Extract sign bit and convert to +1/-1
+        magnitude = x_unfused & 0b0111  # Extract magnitude bits
+        magnitude = magnitude.to(torch.long)
+
+        # Create a tensor with the E2M1 values
+        values = torch.tensor(cls.E2M1_values, device=quantized_data.device)
+
+        # Use gather to index the values tensor properly
+        # We need to reshape magnitude to match the dimensions we want to gather along
+        original_shape = magnitude.shape
+        x_float = values[magnitude.reshape(-1)].reshape(original_shape)
+
+        # Apply sign and scale
+        x_float = sign.float() * x_float
+
+        # Reshape to apply block-wise scaling
+        x_float = x_float.reshape(-1, block_size)
+
+        # Apply the E8M0 scale
+        scale_factor = torch.exp2(e8m0_scale.float() - 127)
+        scale_factor = scale_factor.reshape(-1, 1)  # Reshape for proper broadcasting
+
+        # Apply scaling and reshape back to original shape
+        x_float = x_float * scale_factor
+
+        # Reshape back to the original shape
+        return x_float.reshape(original_shape).to(dtype)
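For orientation: the file above packs two 4-bit FP4 codes per byte, with the even-indexed code in the low nibble and the odd-indexed code in the high nibble. Below is a minimal, self-contained sketch of that packing round trip (not part of the diff; the example values are illustrative only):

    import torch

    # Illustrative 4-bit codes (all values < 16), shape (1, 4).
    codes = torch.tensor([[1, 9, 3, 7]], dtype=torch.uint8)

    # Pack: odd positions go to the high nibble, even positions to the low nibble,
    # mirroring fuse_uint4_to_uint8 above.
    packed = (codes[..., 1::2] << 4) + codes[..., 0::2]

    # Unpack: the inverse, mirroring unfuse_uint8_to_uint4 above.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    restored = torch.stack([low, high], dim=-1).reshape(codes.shape)

    assert torch.equal(restored, codes)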
sglang/srt/layers/quantization/quark/schemes/quark_scheme.py (new file)
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+__all__ = ["QuarkScheme"]
+
+
+class QuarkScheme(ABC):
+    """
+    Abstract class used to describe the weight creation and forward pass
+    of different quantization schemes supported by Quark.
+    """
+
+    @classmethod
+    @abstractmethod
+    def get_min_capability(cls) -> int:
+        """
+        Get minimum device capability.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def create_weights(self, *args, **kwargs):
+        """
+        Weight creation for the particular scheme. Inputs to this function
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_weights(
+        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
+    ):
+        """
+        Run the forward pass for the particular scheme. This is where
+        scheme-specific dequant/quant steps/kernels should be applied.
+
+        :param layer: torch.nn.Module with the registered weights and
+            other parameters relevant to the particular scheme.
+        :param x: input to the layer
+        :param bias: bias parameter
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        """
+        Called after weight loading is complete for any cleanup that
+        needs to occur.
+        """
+        raise NotImplementedError
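As a rough illustration of the contract these abstract methods define (not part of the diff; PassthroughScheme and its parameter names are hypothetical), a trivial unquantized scheme might look like this sketch:

    from typing import Optional

    import torch

    from sglang.srt.layers.quantization.quark.schemes import QuarkScheme


    class PassthroughScheme(QuarkScheme):
        """Hypothetical scheme that stores weights as-is and runs a plain linear."""

        @classmethod
        def get_min_capability(cls) -> int:
            return 70

        def create_weights(self, layer: torch.nn.Module, out_features: int,
                           in_features: int, params_dtype: torch.dtype, **kwargs):
            # Register an unquantized weight on the layer.
            weight = torch.nn.Parameter(
                torch.empty(out_features, in_features, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("weight", weight)

        def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
            # No dequant step needed; just run the matmul.
            return torch.nn.functional.linear(x, layer.weight, bias)

        def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
            pass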
sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py (new file)
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Optional
+
+import aiter
+import torch
+import torch.nn.functional as F
+from aiter.ops.gemm_op_a4w4 import gemm_a4w4
+from aiter.ops.shuffle import shuffle_weight
+from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+from aiter.ops.triton.quant import dynamic_mxfp4_quant
+from aiter.utility import dtypes
+from aiter.utility.fp4_utils import e8m0_shuffle
+
+from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
+from sglang.srt.utils import get_bool_env_var
+
+__all__ = ["QuarkW4A4MXFP4"]
+
+OCP_MX_BLOCK_SIZE = 32
+
+
+class QuarkW4A4MXFP4(QuarkScheme):
+
+    def __init__(
+        self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
+    ):
+        self.out_dtype = torch.get_default_dtype()
+        self.qscheme = "per_group"
+        self.weight_quant_spec = weight_quant_spec
+        self.input_quant_spec = input_quant_spec
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        return
+
+        # for aiter implement
+        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
+        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
+
+        # layer.weight = torch.nn.Parameter(wshuffle,
+        # requires_grad=False)
+        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
+        # requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: list[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = PackedvLLMParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            packed_dim=1,
+            packed_factor=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        weight_scale = GroupQuantScaleParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // OCP_MX_BLOCK_SIZE,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        out_dtype = x.dtype
+        # M = x.shape[0]
+        # N = layer.weight.shape[0]
+
+        # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
+        # x, x_scales_shuffle = quant_func(x, shuffle=True)
+
+        # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
+
+        # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
+
+        # return out[:M]
+
+        # triton implement
+        x_q, x_s = dynamic_mxfp4_quant(x)
+        y = torch.empty(
+            x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
+        )
+
+        out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
+
+        return out
sglang/srt/layers/quantization/quark/utils.py (new file)
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from collections.abc import Iterable, Mapping
+from types import MappingProxyType
+from typing import Any, Optional
+
+
+def deep_compare(dict1: Any, dict2: Any) -> bool:
+    if type(dict1) is not type(dict2):
+        return False
+    if isinstance(dict1, dict):
+        if dict1.keys() != dict2.keys():
+            return False
+        return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
+    elif isinstance(dict1, list):
+        return set(dict1) == set(dict2)
+    else:
+        return dict1 == dict2
+
+
+def should_ignore_layer(
+    layer_name: Optional[str],
+    ignore: Iterable[str],
+    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
+) -> bool:
+    if layer_name is None:
+        return False
+
+    # layer_name = model.layers.0.self_attn.qkv_proj
+    # proj_name = qkv_proj
+    proj_name = layer_name.split(".")[-1]
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping:
+        shard_proj_names = fused_mapping[proj_name]
+
+        # Convert fused_name --> [shard_names]
+        shard_names = [
+            layer_name.replace(proj_name, shard_proj_name)
+            for shard_proj_name in shard_proj_names
+        ]
+
+        # Layer should be ignored if shards are ignored.
+        should_ignore_layer = None
+        for shard_name in shard_names:
+            should_ignore_shard = check_equal_or_regex_match(
+                layer_name=shard_name, targets=ignore
+            )
+
+            # If shard_idx=0, set layer ignore to match shard.
+            if should_ignore_layer is None:
+                should_ignore_layer = should_ignore_shard
+
+            # If shard_idx=1+ confirm scheme matches prior shards.
+            elif should_ignore_shard != should_ignore_layer:
+                raise ValueError(
+                    f"Found a different quantization schemes for "
+                    f"{shard_proj_names} in {layer_name}. vLLM "
+                    "requires all to use the same scheme."
+                )
+
+    # Unfused layers like down_proj and o_proj will match
+    # the safetensors checkpoint already.
+    else:
+        should_ignore_layer = check_equal_or_regex_match(
+            layer_name=layer_name, targets=ignore
+        )
+
+    assert should_ignore_layer is not None
+
+    return should_ignore_layer
+
+
+def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
+    """
+    Checks whether a layer_name is exactly equal or a regex match for
+    if target starts with 're:' to any target in list.
+    """
+    for target in targets:
+        if _is_equal_or_regex_match(layer_name, target):
+            return True
+    return False
+
+
+def _is_equal_or_regex_match(
+    value: str, target: str, check_contains: bool = False
+) -> bool:
+    """
+    Checks whether a value is exactly equal or a regex match for target
+    if target starts with 're:'. If check_contains is set to True,
+    additionally checks if the target string is contained within the value.
+    """
+
+    if target.startswith("re:"):
+        pattern = target[3:]
+        if re.match(pattern, value):
+            return True
+    elif check_contains:
+        if target.lower() in value.lower():
+            return True
+    elif target == value:
+        return True
+    return False
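A brief usage sketch for the matching helpers above (not part of the diff; the layer names and ignore list are made up): entries prefixed with "re:" are treated as regular expressions matched from the start of the name, everything else must match exactly.

    from sglang.srt.layers.quantization.quark.utils import check_equal_or_regex_match

    ignore = ["re:.*lm_head", "model.layers.0.mlp.down_proj"]

    assert check_equal_or_regex_match("model.lm_head", ignore)                  # regex match
    assert check_equal_or_regex_match("model.layers.0.mlp.down_proj", ignore)   # exact match
    assert not check_equal_or_regex_match("model.layers.1.mlp.down_proj", ignore)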
sglang/srt/layers/quantization/unquant.py (deleted line contents are not shown in this rendering)
@@ -129,14 +129,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def __init__(self, use_triton_kernels: bool = False):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
+        self.with_bias = False

         self.triton_kernel_moe_forward = None
+        self.triton_kernel_moe_with_bias_forward = None
         if torch.cuda.is_available() and has_triton_kernels:
             from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                 triton_kernel_moe_forward as _tk_forward,
             )
+            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
+            )

             self.triton_kernel_moe_forward = _tk_forward
+            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward

     def create_weights(
         self,
@@ -145,8 +151,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
+        with_bias: bool = False,
         **extra_weight_attrs,
     ):
+        self.with_bias = with_bias
+
         # Fused gate_up_proj (column parallel)
         w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
         if self.use_triton_kernels:
@@ -158,6 +167,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

+        if self.with_bias:
+            w13_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_bias", w13_weight_bias)
+            set_weight_attrs(w13_weight_bias, extra_weight_attrs)
+
         # down_proj (row parallel)
         w2_weight_n, w2_weight_k = (
             hidden_size,
@@ -172,6 +189,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+        if self.with_bias:
+            w2_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, hidden_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_weight_bias", w2_weight_bias)
+            set_weight_attrs(w2_weight_bias, extra_weight_attrs)
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _use_aiter:
             layer.w13_weight = torch.nn.Parameter(
@@ -202,7 +227,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
+        kwargs = {}
+        if activation_alpha is not None:
+            kwargs["activation_alpha"] = activation_alpha
+        if swiglu_limit is not None:
+            kwargs["swiglu_limit"] = swiglu_limit

         return self.forward(
             x=x,
@@ -213,6 +245,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             inplace=inplace,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            **kwargs,
         )

     def forward_cuda(
@@ -226,15 +259,32 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:

         if self.use_triton_kernels:
-
-
-
-
-
-
+            if self.with_bias:
+                return self.triton_kernel_moe_with_bias_forward(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    b1=layer.w13_weight_bias,
+                    b2=layer.w2_weight_bias,
+                    topk_output=topk_output,
+                    activation=activation,
+                    activation_alpha=activation_alpha,
+                    swiglu_limit=swiglu_limit,
+                    w1_pcg=None,
+                    w2_pcg=None,
+                )
+            else:
+                return self.triton_kernel_moe_forward(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_output=topk_output,
+                )
         else:
             if _use_aiter:
                 assert not no_combine, "unsupported"
@@ -272,12 +322,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
+                b1=getattr(layer, "w13_weight_bias", None),
+                b2=getattr(layer, "w2_weight_bias", None),
                 topk_output=topk_output,
                 inplace=inplace and not no_combine,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input,
                 no_combine=no_combine,
                 routed_scaling_factor=routed_scaling_factor,
+                activation_alpha=activation_alpha,
+                swiglu_limit=swiglu_limit,
             )

     def forward_cpu(
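The apply/forward_cuda changes above thread the new activation_alpha and swiglu_limit arguments through only when the caller actually sets them, so existing call sites keep working unchanged. A standalone sketch of that optional-kwargs forwarding pattern (the names here are illustrative, not sglang APIs):

    from typing import Optional


    def forward(x: float, activation_alpha: float = 1.0, swiglu_limit: float = 10.0) -> float:
        # Stand-in for the real fused-MoE forward; just shows which values arrive.
        return min(x * activation_alpha, swiglu_limit)


    def apply(x: float, activation_alpha: Optional[float] = None,
              swiglu_limit: Optional[float] = None) -> float:
        # Forward only the arguments the caller provided, as the diff does.
        kwargs = {}
        if activation_alpha is not None:
            kwargs["activation_alpha"] = activation_alpha
        if swiglu_limit is not None:
            kwargs["swiglu_limit"] = swiglu_limit
        return forward(x, **kwargs)


    assert apply(3.0) == 3.0                        # defaults untouched
    assert apply(3.0, activation_alpha=2.0) == 6.0  # only the provided kwarg is forwarded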
sglang/srt/layers/quantization/w4afp8.py (deleted line contents are not shown in this rendering)
@@ -116,6 +116,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
         assert "weight_loader" in extra_weight_attrs

         # Fused gate_up_proj (column parallel)
@@ -144,6 +146,9 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
+        )
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
@@ -274,29 +279,30 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: EPMoE,
-
+        x: torch.Tensor,
         topk_output: TopKOutput,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
         **kwargs,
     ) -> torch.Tensor:

         # TODO(ch-wan): move it out of this class
         from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

-
+        topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-
-
-
-
-
-
-
-
-        return cutlass_w4a8_moe(
+        local_topk_ids = torch.where(
+            topk_ids == -1,
+            layer.num_experts,
+            topk_ids,
+        )
+
+        output = cutlass_w4a8_moe(
             layer.start_expert_id,
             layer.end_expert_id,
             layer.num_experts,
-
+            x,
             layer.w13_weight,
             layer.w2_weight,
             layer.w13_weight_scale_inv,
@@ -318,3 +324,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output