sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/mxfp4_tensor.py (new file)
@@ -0,0 +1,133 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
+class MXFP4QuantizeUtil:
+    E2M1_max = 6.0
+
+    E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
+    E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
+
+    @classmethod
+    def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
+        """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
+        Args:
+            input (torch.Tensor): The input tensor to be quantized.
+            block_sizes (dict | None): The block sizes for quantization.
+        """
+
+        def cast_fp4(x):
+            sign = torch.sign(x)
+            sign_bit = (2 - sign) // 2
+            ord_ = torch.sum(
+                (x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
+            )
+            fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
+            return fp4_val
+
+        def fuse_uint4_to_uint8(x):
+            # If the last dimension is odd, pad with zeros
+            # If this behavior is not desired, please modify the code accordingly
+            left_side = x[..., 0::2]  # Even indices (0, 2, 4...)
+            right_side = x[..., 1::2]  # Odd indices (1, 3, 5...)
+            new_data = (
+                right_side.clone() << 4
+            )  # Put odd indices (higher addresses) in high bits
+            new_data[
+                ..., : left_side.shape[-1]
+            ] += left_side  # Put even indices in low bits
+            return new_data
+
+        if block_size is None:
+            block_size = 32
+
+        original_shape = input.shape
+        original_dtype = input.dtype
+        input = input.view(-1, block_size)
+        # get scales
+        input_amax = input.abs().max(dim=-1, keepdim=True).values
+        descale = input_amax / cls.E2M1_max
+        min_value = torch.tensor(-127.0, device=descale.device)
+        e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))
+
+        input = (input / torch.exp2(e8m0_scale)).view(original_shape)
+        input_q = cast_fp4(input)
+        input_q = fuse_uint4_to_uint8(input_q)
+        e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
+        return cls(original_shape, original_dtype, input_q), e8m0_scale
+
+    @classmethod
+    def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
+        """Dequantze MXFP4 packed tensor to a target dtype."""
+
+        def unfuse_uint8_to_uint4(x):
+            """Unfuse uint8 values back to uint4 values.
+            This is the inverse operation of fuse_uint4_to_uint8.
+            """
+            # Extract the lower 4 bits (even indices)
+            left_side = x & 0x0F
+
+            # Extract the upper 4 bits (odd indices)
+            right_side = (x >> 4) & 0x0F
+
+            # Create a new tensor with alternating values
+            shape = list(x.shape)
+            shape[-1] = shape[-1] * 2
+            result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
+
+            # Fill in the values - even indices get low bits, odd indices get high bits
+            result[..., 0::2] = left_side  # Even indices from low bits
+            result[..., 1::2] = right_side  # Odd indices from high bits
+
+            return result
+
+        e8m0_scale = scale
+        block_size = block_sizes[-1]
+
+        # Unfuse the uint8 values back to uint4
+        x_unfused = unfuse_uint8_to_uint4(quantized_data)
+        # Extract sign and magnitude
+        sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(
+            torch.float32
+        )  # Extract sign bit and convert to +1/-1
+        magnitude = x_unfused & 0b0111  # Extract magnitude bits
+        magnitude = magnitude.to(torch.long)
+
+        # Create a tensor with the E2M1 values
+        values = torch.tensor(cls.E2M1_values, device=quantized_data.device)
+
+        # Use gather to index the values tensor properly
+        # We need to reshape magnitude to match the dimensions we want to gather along
+        original_shape = magnitude.shape
+        x_float = values[magnitude.reshape(-1)].reshape(original_shape)
+
+        # Apply sign and scale
+        x_float = sign.float() * x_float
+
+        # Reshape to apply block-wise scaling
+        x_float = x_float.reshape(-1, block_size)
+
+        # Apply the E8M0 scale
+        scale_factor = torch.exp2(e8m0_scale.float() - 127)
+        scale_factor = scale_factor.reshape(-1, 1)  # Reshape for proper broadcasting
+
+        # Apply scaling and reshape back to original shape
+        x_float = x_float * scale_factor
+
+        # Reshape back to the original shape
+        return x_float.reshape(original_shape).to(dtype)
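
For orientation, the per-block scale in `quantize` above is an E8M0 (power-of-two) exponent chosen so that the scaled block fits the E2M1 range. Below is a standalone sketch of that arithmetic for a single hypothetical 32-element block; it is an illustration only, not code from the package.

    import torch

    E2M1_MAX = 6.0  # largest magnitude representable in E2M1

    block = torch.randn(32) * 4.0          # toy block of 32 values
    amax = block.abs().max()               # per-block absolute maximum
    descale = amax / E2M1_MAX              # shrink factor needed for this block
    e8m0_exp = torch.ceil(torch.maximum(torch.log2(descale), torch.tensor(-127.0)))
    scaled = block / torch.exp2(e8m0_exp)  # now lies within [-6, 6]
    stored_scale = (e8m0_exp + 127).to(torch.uint8)  # biased exponent, one byte per block

    assert scaled.abs().max() <= E2M1_MAX + 1e-4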
sglang/srt/layers/quantization/quark/schemes/__init__.py (new file)
@@ -0,0 +1,6 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from .quark_scheme import QuarkScheme
+from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
+
+__all__ = ["QuarkScheme", "QuarkW4A4MXFP4"]
sglang/srt/layers/quantization/quark/schemes/quark_scheme.py (new file)
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+__all__ = ["QuarkScheme"]
+
+
+class QuarkScheme(ABC):
+    """
+    Abstract class used to describe the weight creation and forward pass
+    of different quantization schemes supported by Quark.
+    """
+
+    @classmethod
+    @abstractmethod
+    def get_min_capability(cls) -> int:
+        """
+        Get minimum device capability.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def create_weights(self, *args, **kwargs):
+        """
+        Weight creation for the particular scheme. Inputs to this function
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_weights(
+        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
+    ):
+        """
+        Run the forward pass for the particular scheme. This is where
+        scheme-specific dequant/quant steps/kernels should be applied.
+
+        :param layer: torch.nn.Module with the registered weights and
+            other parameters relevant to the particular scheme.
+        :param x: input to the layer
+        :param bias: bias parameter
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        """
+        Called after weight loading is complete for any cleanup that
+        needs to occur.
+        """
+        raise NotImplementedError
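
To make the contract above concrete, here is a minimal, hypothetical subclass (not part of the package) that satisfies the QuarkScheme interface with a plain dense weight:

    from typing import Optional

    import torch

    from sglang.srt.layers.quantization.quark.schemes import QuarkScheme


    class PassthroughScheme(QuarkScheme):  # hypothetical name, for illustration only
        @classmethod
        def get_min_capability(cls) -> int:
            return 70  # minimum device capability this scheme claims to need

        def create_weights(self, layer: torch.nn.Module, out_features: int,
                           in_features: int, params_dtype: torch.dtype, **kwargs):
            # A real scheme would register packed/quantized tensors here.
            weight = torch.nn.Parameter(
                torch.empty(out_features, in_features, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter("weight", weight)

        def apply_weights(
            self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
        ):
            # No dequant step needed for an unquantized weight.
            return torch.nn.functional.linear(x, layer.weight, bias)

        def process_weights_after_loading(self, layer: torch.nn.Module):
            pass  # nothing to repack for a plain dense weight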
sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py (new file)
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Optional
+
+import aiter
+import torch
+import torch.nn.functional as F
+from aiter.ops.gemm_op_a4w4 import gemm_a4w4
+from aiter.ops.shuffle import shuffle_weight
+from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+from aiter.ops.triton.quant import dynamic_mxfp4_quant
+from aiter.utility import dtypes
+from aiter.utility.fp4_utils import e8m0_shuffle
+
+from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
+from sglang.srt.utils import get_bool_env_var
+
+__all__ = ["QuarkW4A4MXFP4"]
+
+OCP_MX_BLOCK_SIZE = 32
+
+
+class QuarkW4A4MXFP4(QuarkScheme):
+
+    def __init__(
+        self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
+    ):
+        self.out_dtype = torch.get_default_dtype()
+        self.qscheme = "per_group"
+        self.weight_quant_spec = weight_quant_spec
+        self.input_quant_spec = input_quant_spec
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        return
+
+        # for aiter implement
+        # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
+        # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
+
+        # layer.weight = torch.nn.Parameter(wshuffle,
+        #                                   requires_grad=False)
+        # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
+        #                                         requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: list[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = PackedvLLMParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            packed_dim=1,
+            packed_factor=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        weight_scale = GroupQuantScaleParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // OCP_MX_BLOCK_SIZE,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        out_dtype = x.dtype
+        # M = x.shape[0]
+        # N = layer.weight.shape[0]
+
+        # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
+        # x, x_scales_shuffle = quant_func(x, shuffle=True)
+
+        # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
+
+        # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
+
+        # return out[:M]
+
+        # triton implement
+        x_q, x_s = dynamic_mxfp4_quant(x)
+        y = torch.empty(
+            x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
+        )
+
+        out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
+
+        return out
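
As a rough storage check (not from the package; the partition sizes below are made up), the shapes registered in create_weights above pack two FP4 weight values per uint8 byte and one E8M0 scale byte per 32-element group along the input dimension:

    OCP_MX_BLOCK_SIZE = 32
    out_features, in_features = 4096, 14336  # hypothetical partition sizes

    weight_bytes = out_features * (in_features // 2)                 # packed_factor = 2
    scale_bytes = out_features * (in_features // OCP_MX_BLOCK_SIZE)  # one scale byte per group

    print(weight_bytes / 2**20)  # 28.0 MiB of packed FP4 weights
    print(scale_bytes / 2**20)   # 1.75 MiB of E8M0 group scales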
sglang/srt/layers/quantization/quark/utils.py (new file)
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from collections.abc import Iterable, Mapping
+from types import MappingProxyType
+from typing import Any, Optional
+
+
+def deep_compare(dict1: Any, dict2: Any) -> bool:
+    if type(dict1) is not type(dict2):
+        return False
+    if isinstance(dict1, dict):
+        if dict1.keys() != dict2.keys():
+            return False
+        return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
+    elif isinstance(dict1, list):
+        return set(dict1) == set(dict2)
+    else:
+        return dict1 == dict2
+
+
+def should_ignore_layer(
+    layer_name: Optional[str],
+    ignore: Iterable[str],
+    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
+) -> bool:
+    if layer_name is None:
+        return False
+
+    # layer_name = model.layers.0.self_attn.qkv_proj
+    # proj_name = qkv_proj
+    proj_name = layer_name.split(".")[-1]
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping:
+        shard_proj_names = fused_mapping[proj_name]
+
+        # Convert fused_name --> [shard_names]
+        shard_names = [
+            layer_name.replace(proj_name, shard_proj_name)
+            for shard_proj_name in shard_proj_names
+        ]
+
+        # Layer should be ignored if shards are ignored.
+        should_ignore_layer = None
+        for shard_name in shard_names:
+            should_ignore_shard = check_equal_or_regex_match(
+                layer_name=shard_name, targets=ignore
+            )
+
+            # If shard_idx=0, set layer ignore to match shard.
+            if should_ignore_layer is None:
+                should_ignore_layer = should_ignore_shard
+
+            # If shard_idx=1+ confirm scheme matches prior shards.
+            elif should_ignore_shard != should_ignore_layer:
+                raise ValueError(
+                    f"Found a different quantization schemes for "
+                    f"{shard_proj_names} in {layer_name}. vLLM "
+                    "requires all to use the same scheme."
+                )
+
+    # Unfused layers like down_proj and o_proj will match
+    # the safetensors checkpoint already.
+    else:
+        should_ignore_layer = check_equal_or_regex_match(
+            layer_name=layer_name, targets=ignore
+        )
+
+    assert should_ignore_layer is not None
+
+    return should_ignore_layer
+
+
+def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
+    """
+    Checks whether a layer_name is exactly equal or a regex match for
+    if target starts with 're:' to any target in list.
+    """
+    for target in targets:
+        if _is_equal_or_regex_match(layer_name, target):
+            return True
+    return False
+
+
+def _is_equal_or_regex_match(
+    value: str, target: str, check_contains: bool = False
+) -> bool:
+    """
+    Checks whether a value is exactly equal or a regex match for target
+    if target starts with 're:'. If check_contains is set to True,
+    additionally checks if the target string is contained within the value.
+    """
+
+    if target.startswith("re:"):
+        pattern = target[3:]
+        if re.match(pattern, value):
+            return True
+    elif check_contains:
+        if target.lower() in value.lower():
+            return True
+    elif target == value:
+        return True
+    return False
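
A short usage sketch of the helpers above (assuming they are imported from sglang.srt.layers.quantization.quark.utils): ignore entries may be exact layer names or "re:"-prefixed regular expressions, and fused layers are resolved shard by shard.

    from sglang.srt.layers.quantization.quark.utils import should_ignore_layer

    fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    ignore = [
        "re:.*lm_head.*",
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
    ]

    # All three shards of the fused qkv_proj are in `ignore`, so the fused layer is ignored.
    should_ignore_layer(
        "model.layers.0.self_attn.qkv_proj", ignore=ignore, fused_mapping=fused_mapping
    )  # -> True

    # An unfused layer is matched directly, here via the "re:" regex form.
    should_ignore_layer("lm_head", ignore=ignore, fused_mapping=fused_mapping)  # -> True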
sglang/srt/layers/quantization/unquant.py
@@ -129,14 +129,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def __init__(self, use_triton_kernels: bool = False):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
+        self.with_bias = False

         self.triton_kernel_moe_forward = None
+        self.triton_kernel_moe_with_bias_forward = None
         if torch.cuda.is_available() and has_triton_kernels:
             from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                 triton_kernel_moe_forward as _tk_forward,
             )
+            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
+            )

             self.triton_kernel_moe_forward = _tk_forward
+            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward

     def create_weights(
         self,
@@ -145,8 +151,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
+        with_bias: bool = False,
         **extra_weight_attrs,
     ):
+        self.with_bias = with_bias
+
         # Fused gate_up_proj (column parallel)
         w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
         if self.use_triton_kernels:
@@ -158,6 +167,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

+        if self.with_bias:
+            w13_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_bias", w13_weight_bias)
+            set_weight_attrs(w13_weight_bias, extra_weight_attrs)
+
         # down_proj (row parallel)
         w2_weight_n, w2_weight_k = (
             hidden_size,
@@ -172,6 +189,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+        if self.with_bias:
+            w2_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, hidden_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_weight_bias", w2_weight_bias)
+            set_weight_attrs(w2_weight_bias, extra_weight_attrs)
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _use_aiter:
             layer.w13_weight = torch.nn.Parameter(
@@ -202,7 +227,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
+        kwargs = {}
+        if activation_alpha is not None:
+            kwargs["activation_alpha"] = activation_alpha
+        if swiglu_limit is not None:
+            kwargs["swiglu_limit"] = swiglu_limit

         return self.forward(
             x=x,
@@ -213,6 +245,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             inplace=inplace,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            **kwargs,
         )

     def forward_cuda(
@@ -226,15 +259,32 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:

         if self.use_triton_kernels:
-            return self.triton_kernel_moe_forward(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_output=topk_output,
-            )
+            if self.with_bias:
+                return self.triton_kernel_moe_with_bias_forward(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    b1=layer.w13_weight_bias,
+                    b2=layer.w2_weight_bias,
+                    topk_output=topk_output,
+                    activation=activation,
+                    activation_alpha=activation_alpha,
+                    swiglu_limit=swiglu_limit,
+                    w1_pcg=None,
+                    w2_pcg=None,
+                )
+            else:
+                return self.triton_kernel_moe_forward(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_output=topk_output,
+                )
         else:
             if _use_aiter:
                 assert not no_combine, "unsupported"
@@ -272,12 +322,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     hidden_states=x,
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
+                    b1=getattr(layer, "w13_weight_bias", None),
+                    b2=getattr(layer, "w2_weight_bias", None),
                     topk_output=topk_output,
                     inplace=inplace and not no_combine,
                     activation=activation,
                     apply_router_weight_on_input=apply_router_weight_on_input,
                     no_combine=no_combine,
                     routed_scaling_factor=routed_scaling_factor,
+                    activation_alpha=activation_alpha,
+                    swiglu_limit=swiglu_limit,
                 )

     def forward_cpu(
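
One plausible reading of the kwargs indirection in apply() above: the new activation_alpha and swiglu_limit arguments are forwarded only when they are actually set, so dispatch targets that do not define them are not handed unexpected keyword arguments. A toy illustration of that pattern (hypothetical names, not package code):

    from typing import Optional


    def dispatch(forward, x, activation_alpha: Optional[float] = None):
        kwargs = {}
        if activation_alpha is not None:
            kwargs["activation_alpha"] = activation_alpha
        return forward(x, **kwargs)


    def legacy_forward(x):                      # older backend: no new parameter
        return x


    def new_forward(x, activation_alpha=None):  # newer backend: accepts it
        return x


    dispatch(legacy_forward, 1.0)                     # fine: kwarg omitted
    dispatch(new_forward, 1.0, activation_alpha=0.5)  # fine: kwarg forwarded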
sglang/srt/layers/quantization/w4afp8.py
@@ -116,6 +116,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
         assert "weight_loader" in extra_weight_attrs

         # Fused gate_up_proj (column parallel)
@@ -144,6 +146,9 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
+        )
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
@@ -274,29 +279,30 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
     def apply(
         self,
         layer: EPMoE,
-        hidden_states: torch.Tensor,
+        x: torch.Tensor,
         topk_output: TopKOutput,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
         **kwargs,
     ) -> torch.Tensor:

         # TODO(ch-wan): move it out of this class
         from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

-        topk_ids, topk_weights, _ = topk_output
+        topk_weights, topk_ids, _ = topk_output
         local_topk_ids = topk_ids
-        if layer.expert_map is not None:
-            "Translate info from expert_map to topk_ids"
-            local_topk_ids = torch.where(
-                layer.expert_map[topk_ids] != layer.num_experts,
-                layer.expert_map[topk_ids],
-                layer.num_experts,
-            )
-
-        return cutlass_w4a8_moe(
+        local_topk_ids = torch.where(
+            topk_ids == -1,
+            layer.num_experts,
+            topk_ids,
+        )
+
+        output = cutlass_w4a8_moe(
             layer.start_expert_id,
             layer.end_expert_id,
             layer.num_experts,
-            hidden_states,
+            x,
             layer.w13_weight,
             layer.w2_weight,
             layer.w13_weight_scale_inv,
@@ -318,3 +324,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_input_scale,
             layer.w2_input_scale,
         )
+        if routed_scaling_factor is not None:
+            output *= routed_scaling_factor
+        return output
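
The torch.where remap above replaces the expert_map-based translation: entries of topk_ids equal to -1 (slots with no routed expert) are mapped to the out-of-range id num_experts, presumably so the kernel can skip them. A standalone sketch of that remap (not package code):

    import torch

    num_experts = 8
    topk_ids = torch.tensor([[0, 3, -1], [7, -1, -1]])

    local_topk_ids = torch.where(topk_ids == -1, num_experts, topk_ids)
    print(local_topk_ids)  # tensor([[0, 3, 8], [7, 8, 8]])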