sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -0,0 +1,658 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+import enum
+import logging
+from enum import Enum
+from typing import Callable, List, Optional
+
+import torch
+from compressed_tensors import CompressionFormat
+from compressed_tensors.quantization import QuantizationStrategy
+
+from sglang.srt.layers.moe.fused_moe_triton import (
+    FusedMoE,
+    FusedMoEMethodBase,
+    FusedMoeWeightScaleSupported,
+)
+from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    is_cuda,
+    is_fp8_fnuz,
+    per_tensor_dequantize,
+    replace_parameter,
+)
+from sglang.srt.utils import set_weight_attrs
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+
+class GPTQMarlinState(Enum):
+    REPACK = enum.auto()
+    READY = enum.auto()
+
+
+__all__ = [
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
+]
+
+
+class CompressedTensorsMoEMethod(FusedMoEMethodBase):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+    ) -> "CompressedTensorsMoEMethod":
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
+        input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
+
+        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
+            if not VLLM_AVAILABLE:
+                raise ImportError(
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
+                )
+            return CompressedTensorsWNA16MoEMethod(quant_config)
+        elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
+            )
+
+
+class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "input_activations"
+        )
+
+        if not (
+            self.weight_quant.strategy == QuantizationStrategy.TENSOR
+            and self.input_quant.strategy == QuantizationStrategy.TENSOR
+        ):
+            raise ValueError(
+                "For FP8 Fused MoE layers, only per-tensor scales "
+                "for weights and activations are supported. Found "
+                f"{self.weight_quant}, {self.input_quant}"
+            )
+
+        self.static_input_scales = not self.input_quant.dynamic
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.static_input_scales:
+            w13_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+            w2_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+        else:
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Fp8 moe kernels require a single activation scale.
+        # We take the max of all the scales in case they differ.
+        if self.static_input_scales:
+            if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None."
+                )
+            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
+                layer.w2_input_scale
+            ):
+                logger.warning(
+                    "Found input_scales that are not equal for "
+                    "fp8 MoE layer. Using the maximum across experts "
+                    "for each layer."
+                )
+            layer.w13_input_scale = torch.nn.Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+            layer.w2_input_scale = torch.nn.Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+        if is_fp8_fnuz():
+            # Normalize the weights and scales
+            w13_weight, w13_weight_scale, w13_input_scale = (
+                normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
+                )
+            )
+            w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
+            )
+            # Reset the parameter
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(
+                w13_weight_scale, requires_grad=False
+            )
+            if w13_input_scale is not None:
+                layer.w13_input_scale = torch.nn.Parameter(
+                    w13_input_scale, requires_grad=False
+                )
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            layer.w2_weight_scale = torch.nn.Parameter(
+                w2_weight_scale, requires_grad=False
+            )
+            if w2_input_scale is not None:
+                layer.w2_input_scale = torch.nn.Parameter(
+                    w2_input_scale, requires_grad=False
+                )
+
+        # Fp8 moe kernel needs single weight scale for w13 per expert.
+        # We take the max then dequant and requant each expert.
+        assert layer.w13_weight_scale is not None
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        for expert_id in range(layer.local_num_experts):
+            start = 0
+            for shard_id in range(2):
+                dq_weight = per_tensor_dequantize(
+                    layer.w13_weight[expert_id][start : start + shard_size, :],
+                    layer.w13_weight_scale[expert_id][shard_id],
+                )
+
+                if _is_cuda:
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                else:
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            use_fp8_w8a8=True,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+        )
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
+        self.strategy = config.strategy
+        self.group_size = config.group_size
+        self.actorder = config.actorder
+        assert config.symmetric, "Only symmetric quantization is supported for MoE"
+
+        if not (
+            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
+            and self.num_bits in WNA16_SUPPORTED_BITS
+        ):
+            raise ValueError(
+                "For Fused MoE layers, only ",
+                f"{CompressionFormat.pack_quantized.value} ",
+                "is supported for the following bits: ",
+                f"{WNA16_SUPPORTED_BITS}",
+            )
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        assert (
+            params_dtype == torch.float16
+        ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
+
+        intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
+        extra_weight_attrs.update(
+            {"is_transposed": True, "quant_method": self.strategy}
+        )
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.packed_factor,
+                2 * intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition // self.packed_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # In the case where we have actorder/g_idx,
+        # we do not partition the w2 scales
+        load_full_w2 = self.actorder and self.group_size != -1
+        w2_scales_size = (
+            intermediate_size_full if load_full_w2 else intermediate_size_per_partition
+        )
+
+        self.is_k_full = (not self.actorder) or (
+            intermediate_size_per_partition == intermediate_size_full
+        )
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                num_groups_w13,
+                2 * intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(
+            torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
+
+        w2_weight_shape = torch.nn.Parameter(
+            torch.empty(num_experts, 2), requires_grad=False
+        )
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(
+            torch.empty(num_experts, 2), requires_grad=False
+        )
+
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+        layer.a13_scale = None
+        layer.a2_scale = None
+        layer.marlin_state = GPTQMarlinState.REPACK
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        def replace_tensor(name, new_t):
+            # It is important to use resize_() here since it ensures
+            # the same buffer is reused
+            getattr(layer, name).resize_(new_t.shape)
+            getattr(layer, name).copy_(new_t)
+            del new_t
+
+        def get_scale_perms(num_bits: int):
+            scale_perm: List[int] = []
+            for i in range(8):
+                scale_perm.extend([i + 8 * j for j in range(8)])
+            scale_perm_single: List[int] = []
+            for i in range(4):
+                scale_perm_single.extend(
+                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
+                )
+            return scale_perm, scale_perm_single
+
+        def marlin_permute_scales(
+            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+        ):
+            scale_perm, scale_perm_single = get_scale_perms(num_bits)
+            if group_size < size_k and group_size != -1:
+                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
+            else:
+                s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+            s = s.reshape((-1, size_n)).contiguous()
+            return s
+
+        def marlin_moe_permute_scales(
+            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+        ):
+            num_experts = s.shape[0]
+            output = torch.empty(
+                (num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype
+            )
+            for e in range(num_experts):
+                output[e] = marlin_permute_scales(
+                    s[e], size_k, size_n, group_size, num_bits
+                )
+            return output
+
+        size_k2 = layer.w2_weight_packed.shape[2]
+        size_k13 = layer.w13_weight_packed.shape[2]
+
+        num_experts = layer.w13_weight_g_idx.shape[0]
+        device = layer.w13_weight_g_idx.device
+
+        # when running models with grouped act order,
+        # resort to g_idx values provided in checkpoint
+        if self.actorder == "group":
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)
+
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to(
+                    torch.int32
+                )
+                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to(
+                    torch.int32
+                )
+                w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
+                    w13_g_idx_sort_indices[e]
+                ]
+                w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]]
+
+            replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+
+        else:
+            layer.w13_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+
+        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+            layer.w13_weight_packed,
+            layer.w13_g_idx_sort_indices,
+            layer.w13_weight_packed.shape[1] * self.packed_factor,
+            layer.w13_weight_packed.shape[2],
+            self.num_bits,
+        )
+        replace_tensor("w13_weight_packed", marlin_w13_qweight)
+        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+            layer.w2_weight_packed,
+            layer.w2_g_idx_sort_indices,
+            layer.w2_weight_packed.shape[1] * self.packed_factor,
+            layer.w2_weight_packed.shape[2],
+            self.num_bits,
+        )
+        replace_tensor("w2_weight_packed", marlin_w2_qweight)
+        # Repack scales
+        marlin_w13_scales = marlin_moe_permute_scales(
+            layer.w13_weight_scale,
+            size_k13,
+            layer.w13_weight_scale.shape[2],
+            self.group_size,
+            self.num_bits,
+        )
+        replace_tensor("w13_weight_scale", marlin_w13_scales)
+        marlin_w2_scales = marlin_moe_permute_scales(
+            layer.w2_weight_scale,
+            layer.w2_weight_scale.shape[1]
+            * (self.group_size if self.group_size != -1 else self.packed_factor),
+            size_k2,
+            self.group_size,
+            self.num_bits,
+        )
+        replace_tensor("w2_weight_scale", marlin_w2_scales)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+        if not VLLM_AVAILABLE:
+            raise ImportError(
+                "vllm is not installed, to use fused_marlin_moe, please install vllm"
+            )
+        if expert_map is not None:
+            raise NotImplementedError(
+                "Expert Parallelism is not supported for " "fused Marlin MoE method."
+            )
+
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            correction_bias=correction_bias,
+        )
+
+        return torch.ops.vllm.fused_marlin_moe(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            layer.w13_weight_scale,
+            layer.w2_weight_scale,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_weight_g_idx,
+            g_idx2=layer.w2_weight_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.num_bits,
+            is_k_full=self.is_k_full,
+        )
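
During loading, the W8A8 FP8 method above keeps two per-tensor weight scales per expert for w13 (one for w1, one for w3) and collapses them afterwards, since the fused FP8 MoE kernel expects a single per-expert scale. A minimal standalone sketch of that max-scale requantization step (PyTorch only; the tensor sizes and the explicit ±448 e4m3fn clamp are illustrative assumptions, not taken from the package):

    import torch

    # Hypothetical sizes, for illustration only.
    num_experts, shard, hidden = 4, 8, 16

    # Per-expert FP8 w13 weights (w1 and w3 stacked) with one per-tensor
    # scale per shard, mirroring w13_weight / w13_weight_scale above.
    w13 = torch.randn(num_experts, 2 * shard, hidden).to(torch.float8_e4m3fn)
    w13_scale = torch.rand(num_experts, 2) + 0.5

    # Collapse to one scale per expert: take the max over the two shards,
    # dequantize each shard with its original scale, requantize with the max.
    max_scales = w13_scale.max(dim=1).values
    for e in range(num_experts):
        start = 0
        for s in range(2):
            block = w13[e, start : start + shard].to(torch.float32) * w13_scale[e, s]
            requant = (block / max_scales[e]).clamp(-448.0, 448.0)  # e4m3fn range
            w13[e, start : start + shard] = requant.to(torch.float8_e4m3fn)
            start += shard

In the method itself, this dequantize/requantize pair is per_tensor_dequantize plus sgl_scaled_fp8_quant (or vllm's _custom_ops.scaled_fp8_quant on non-CUDA builds).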
sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from .compressed_tensors_scheme import CompressedTensorsScheme
+from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+
+__all__ = [
+    "CompressedTensorsScheme",
+    "CompressedTensorsW8A8Fp8",
+]
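
The hunk above adds the __init__.py of the new schemes subpackage (file 73 in the list), which simply re-exports the two scheme classes introduced in this release. Assuming sglang 0.4.4.post3 is installed, a hedged usage sketch:

    # Import the re-exported scheme classes from the new subpackage.
    from sglang.srt.layers.quantization.compressed_tensors.schemes import (
        CompressedTensorsScheme,
        CompressedTensorsW8A8Fp8,
    )

    # CompressedTensorsScheme comes from the new compressed_tensors_scheme.py
    # and CompressedTensorsW8A8Fp8 (the per-layer FP8 W8A8 scheme) from the
    # new compressed_tensors_w8a8_fp8.py, both listed in the files above.
    print(CompressedTensorsScheme, CompressedTensorsW8A8Fp8)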