sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
+        packed_modules_mapping: Dict[str, List[str]] = {},
     ):
         super().__init__()
         self.ignore = ignore
@@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.sparsity_scheme_map = sparsity_scheme_map
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
+        self.packed_modules_mapping = packed_modules_mapping

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config
         )
+        packed_modules_mapping = config.get("packed_modules_mapping", {})

         return cls(
             target_scheme_map=target_scheme_map,
@@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             sparsity_scheme_map=sparsity_scheme_map,
             sparsity_ignore_list=sparsity_ignore_list,
             config=config,
+            packed_modules_mapping=packed_modules_mapping,
         )

     @classmethod
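The four hunks above plumb an optional `packed_modules_mapping` from the checkpoint's quantization config into `CompressedTensorsConfig`. A rough sketch of the shape of that field, using hypothetical module names (illustrative only, not values taken from any particular checkpoint) and the same `dict.get` lookup that `from_config` now performs:

# Hypothetical quantization_config fragment; the projection names below are
# examples only, not values mandated by sglang or compressed-tensors.
quantization_config = {
    "packed_modules_mapping": {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    },
    # ... remaining compressed-tensors fields elided ...
}

# Mirrors the lookup added in CompressedTensorsConfig.from_config: a missing key
# falls back to an empty mapping, so older checkpoints keep working unchanged.
packed_modules_mapping = quantization_config.get("packed_modules_mapping", {})
print(packed_modules_mapping.get("qkv_proj", []))  # ['q_proj', 'k_proj', 'v_proj']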
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -103,16 +103,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             "input_activations"
         )

-        if not (
-            self.weight_quant.strategy == QuantizationStrategy.TENSOR
-            and self.input_quant.strategy == QuantizationStrategy.TENSOR
-        ):
-            raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
-                f"{self.weight_quant}, {self.input_quant}"
-            )
-
         self.static_input_scales = not self.input_quant.dynamic

     def create_weights(
@@ -154,27 +144,50 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         set_weight_attrs(w2_weight, extra_weight_attrs)

         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        # per-tensor quantization
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * intermediate_size_per_partition,
+                    1,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+                requires_grad=False,
+            )
+            weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
+        else:
+            raise ValueError(
+                f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
+            )

-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
+        extra_weight_attrs.update({"quant_method": weight_quant_method})
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

         # INPUT_SCALES
         if self.static_input_scales:
+            assert (
+                self.input_quant.strategy == QuantizationStrategy.TENSOR
+            ), "Only per-tensor quantization is supported for static input scales"
             w13_input_scale = torch.nn.Parameter(
                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
             )
@@ -241,31 +254,37 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            layer.w2_input_scale = torch.nn.Parameter(
                w2_input_scale, requires_grad=False
            )
-
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start : start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id],
-                )
-
-                if _is_cuda:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    )
-                else:
-                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
-                        vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max then dequant and requant each expert.
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id],
                     )
-                start += shard_size

-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+                    if _is_cuda:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    else:
+                        (
+                            layer.w13_weight[expert_id][start : start + shard_size, :],
+                            _,
+                        ) = vllm_ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id]
+                        )
+                    start += shard_size
+
+            layer.w13_weight_scale = torch.nn.Parameter(
+                max_w13_scales, requires_grad=False
+            )

     def apply(
         self,
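In the per-tensor branch above, the two w13 shards arrive with separate scales while the FP8 MoE kernel wants a single scale per expert, so each shard is dequantized with its own scale and requantized with the shared maximum. A minimal sketch of that scale-merging arithmetic on plain float tensors (it stands in for `per_tensor_dequantize` and `sgl_scaled_fp8_quant`; no real FP8 storage or sglang kernels are involved):

import torch

# Toy stand-in for one expert's w1/w3 shards, each "quantized" per tensor.
# Shapes and values are illustrative only.
w1_q = torch.randint(-8, 8, (4, 16)).float()
w3_q = torch.randint(-8, 8, (4, 16)).float()
scales = torch.tensor([0.02, 0.05])            # one scale per shard

# Merge: dequantize each shard with its own scale, then requantize both
# with the shared max scale, mirroring the dequant/requant loop above.
max_scale = scales.max()
merged = []
for shard_q, s in zip((w1_q, w3_q), scales):
    dequant = shard_q * s                      # per-tensor dequantize
    requant = torch.round(dequant / max_scale) # re-quantize with the shared scale
    merged.append(requant)
w13_q = torch.cat(merged, dim=0)               # one scale now covers both shards
print(w13_q.shape, float(max_scale))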
@@ -285,6 +304,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
         inplace: bool = True,
         no_combine: bool = False,
+        apply_router_weight_on_input: bool = False,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -310,10 +330,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy
+            == QuantizationStrategy.CHANNEL,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
         )

sglang/srt/layers/quantization/fp8.py

@@ -71,7 +71,8 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 _is_hip = is_hip()

 if _is_hip:
-    from aiter
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight

 _is_cuda = is_cuda()
@@ -487,7 +488,7 @@ class Fp8MoEMethod:

         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
@@ -822,12 +823,14 @@ class Fp8MoEMethod:
             # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
             # Weight Permutation
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -860,19 +863,21 @@ class Fp8MoEMethod:
             layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
             layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]

-    def process_weights_hip_scale_padding(self, layer: Module
+    def process_weights_hip_scale_padding(self, layer: Module):
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
             padding_size,  # Avoid circular import
         )

         if get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -905,6 +910,7 @@ class Fp8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -927,7 +933,7 @@ class Fp8MoEMethod:
         if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
-            return
+            return ck_moe_2stages_win4(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -935,15 +941,17 @@ class Fp8MoEMethod:
                 topk_ids,
                 layer.w13_weight_scale1,
                 layer.w2_weight_scale1,
-                activation=
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
+                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -956,7 +964,7 @@ class Fp8MoEMethod:
                     expert_mask=None,
                 )
             else:
-                return
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
@@ -964,6 +972,11 @@ class Fp8MoEMethod:
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
         else:
             # Expert fusion with FP8 quantization
@@ -975,6 +988,7 @@ class Fp8MoEMethod:
             topk_ids=topk_ids,
             inplace=inplace and not no_combine,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=True,
             w1_scale=(
                 layer.w13_weight_scale_inv
sglang/srt/layers/quantization/fp8_kernel.py

@@ -16,6 +16,7 @@ import functools
 import json
 import logging
 import os
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple

 import torch
@@ -40,11 +41,13 @@ fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

 _is_cuda = is_cuda()
 if _is_cuda:
-    import deep_gemm
+    import deep_gemm
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

     sm_version = get_device_sm()
-    if sm_version
+    if sm_version == 90 and get_bool_env_var(
+        "SGL_ENABLE_JIT_DEEPGEMM", default="false"
+    ):
         _enable_jit_deepgemm = True

@@ -59,7 +62,10 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        M, K = A.shape
+        N, _ = B.shape
+        with _log_jit_build(M, N, K):
+            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)

     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -708,6 +714,25 @@ def get_w8a8_block_fp8_configs(
     return None


+@contextmanager
+def _log_jit_build(M: int, N: int, K: int):
+    from deep_gemm.jit.runtime import RuntimeCache
+
+    origin_func = RuntimeCache.__getitem__
+
+    def __patched_func(self, *args, **kwargs):
+        ret = origin_func(self, *args, **kwargs)
+        if ret is None:
+            logger.warning(
+                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
+            )
+        return ret
+
+    RuntimeCache.__getitem__ = __patched_func
+    yield
+    RuntimeCache.__getitem__ = origin_func
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
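`_log_jit_build` above temporarily swaps out `RuntimeCache.__getitem__` so that a cache miss (a `None` return) is logged before DeepGEMM kicks off a JIT build, then restores the original accessor. A self-contained toy of the same patch-and-restore pattern, using a stand-in cache class instead of deep_gemm's real `RuntimeCache` (the toy also restores inside `finally`, a small hardening this sketch adds on its own):

from contextlib import contextmanager

class ToyCache:                        # stand-in for deep_gemm's RuntimeCache
    def __init__(self):
        self._store = {}
    def __getitem__(self, key):
        return self._store.get(key)    # None signals "kernel not built yet"

@contextmanager
def log_misses(tag):
    original = ToyCache.__getitem__
    def patched(self, key):
        ret = original(self, key)
        if ret is None:
            print(f"[{tag}] cache miss for {key!r}; building...")
        return ret
    ToyCache.__getitem__ = patched     # patch for the duration of the block
    try:
        yield
    finally:
        ToyCache.__getitem__ = original  # always restore the original accessor

cache = ToyCache()
with log_misses("jit"):
    cache["gemm_128x128"]              # prints a miss warning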
@@ -782,7 +807,8 @@ def w8a8_block_fp8_matmul(
        if supports_custom_op():
            torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
        else:
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+            with _log_jit_build(M, N, K):
+                deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
    else:
        kernel = (
            _w8a8_block_fp8_matmul_unrolledx4
@@ -815,3 +841,103 @@ def w8a8_block_fp8_matmul(
     )

     return C
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage1(
+    x_ptr,
+    x_s_ptr,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    eps,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
+
+    tl.atomic_max(x_s_ptr, _absmax / fp8_max)
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage2(
+    x_ptr,
+    x_s_ptr,
+    x_q_ptr,
+    num_seq,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    fp8_min,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_s = tl.load(x_s_ptr)
+    x_s_inv = 1.0 / x_s
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x_q_ptr += head_id * num_seq * head_size + seq_id * head_size
+
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    x_q = tl.clamp(x * x_s_inv, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
+    tl.store(x_q_ptr + offset, x_q, mask=mask)
+
+
+def per_tensor_quant_mla_fp8(
+    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with tensor-wise quantization
+    and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    x_q = x.new_empty(x.size(), dtype=dtype)
+    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
+
+    num_head, num_seq, head_size = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(head_size)
+    grid = (num_seq, num_head)
+
+    _per_tensor_quant_mla_fp8_stage1[grid](
+        x,
+        x_s,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        eps,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+    _per_tensor_quant_mla_fp8_stage2[grid](
+        x,
+        x_s,
+        x_q,
+        num_seq,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        -fp8_max,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s
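Stage 1 above reduces each (head, sequence) slice to an absolute maximum and folds it into one scalar scale via `tl.atomic_max`; stage 2 then rescales, saturates, and casts every element with that single scale. A rough eager-mode reference for the same semantics, assuming a PyTorch build that exposes `torch.float8_e4m3fn`; it only illustrates what the kernels compute, not their performance, rounding, or the HIP dtype handling:

import torch

def per_tensor_quant_mla_fp8_ref(x: torch.Tensor, eps: float = 1e-12):
    """Eager reference: one scale for the whole (num_head, num_seq, head_size) tensor."""
    assert x.dim() == 3
    dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(dtype).max

    # Stage-1 equivalent: global absmax -> single scalar scale.
    scale = (x.abs().amax().clamp_min(eps) / fp8_max).to(torch.float32)

    # Stage-2 equivalent: divide by the scale, saturate, and cast to FP8 storage.
    x_q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(dtype)
    return x_q, scale.reshape(1)

x = torch.randn(16, 4, 64)             # (num_head, num_seq, head_size), illustrative
x_q, x_s = per_tensor_quant_mla_fp8_ref(x)
print(x_q.dtype, x_s.item())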
sglang/srt/layers/quantization/fp8_utils.py

@@ -168,12 +168,13 @@ def input_to_float8(
     """This function quantizes input values to float8 values with tensor-wise quantization."""
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
+        dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0
     scale = fp8_max / amax
-    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
+    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

@@ -212,7 +213,24 @@ def block_quant_to_tensor_quant(
         for j in range(n_tiles):
             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

-    x_q_tensor, scale =
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_block)
+        if _is_cuda
+        else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+    )
+    return x_q_tensor, scale
+
+
+def channel_quant_to_tensor_quant(
+    x_q_channel: torch.Tensor,
+    x_s: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    x_dq_channel = x_q_channel.to(torch.float32) * x_s
+    x_q_tensor, scale = (
+        sgl_scaled_fp8_quant(x_dq_channel)
+        if _is_cuda
+        else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
+    )
     return x_q_tensor, scale

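The new `channel_quant_to_tensor_quant` above folds per-output-channel FP8 weight scales into a single per-tensor scale by dequantizing and requantizing, using `sgl_scaled_fp8_quant` on CUDA and `input_to_float8` otherwise. A toy version of the same fold, written directly against `torch.float8_e4m3fn` so the shapes are visible (illustrative only; this is not the code path sglang runs):

import torch

def channel_to_tensor_requant_ref(w_q: torch.Tensor, w_s: torch.Tensor):
    """Toy fold of per-channel scales (shape (N, 1)) into one per-tensor scale."""
    w_dq = w_q.to(torch.float32) * w_s             # dequantize per output channel
    amax = w_dq.abs().amax().clamp_min(1e-12)      # tensor-wide absmax
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = amax / fp8_max
    w_q_tensor = (w_dq / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_q_tensor, scale

# Pretend channel-quantized weight plus one scale per output channel.
w_q = torch.randint(-16, 16, (8, 32)).float().to(torch.float8_e4m3fn)
w_s = torch.rand(8, 1) * 0.01
w_q_t, s = channel_to_tensor_requant_ref(w_q, w_s)
print(w_q_t.shape, s.item())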
@@ -242,9 +260,19 @@ def apply_fp8_linear(
        if _is_cuda:
            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
        else:
-            qinput, x_scale = per_token_group_quant_fp8(
-                input_2d, group_size=input_2d.shape[1]
-            )
+            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+            # final solution should be: 1. add support to per-tensor activation scaling.
+            # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+            if _is_hip and weight_scale.numel() == 1:
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+            else:
+                qinput, x_scale = per_token_group_quant_fp8(
+                    input_2d, group_size=input_2d.shape[1]
+                )

    if cutlass_fp8_supported:
        try: