sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py

@@ -113,7 +113,7 @@ if supports_custom_op():


 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -125,8 +125,8 @@ def _per_token_group_quant_fp8(
     # Avoid to divide zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
@@ -147,16 +147,16 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     y_s_inv = 1.0 / y_s
-    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y * y_s_inv, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)

     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


 @triton.jit
-def _per_token_group_quant_fp8_colmajor(
+def _per_token_group_quant_8bit_colmajor(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -169,8 +169,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Avoid to divide zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
     SCALE_UE8M0: tl.constexpr,
@@ -197,19 +197,20 @@ def _per_token_group_quant_fp8_colmajor(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     if SCALE_UE8M0:
         y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y / y_s, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)

     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


-def per_token_group_quant_fp8(
+def _per_token_group_quant_8bit_raw(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
+    dtype: torch.dtype = fp8_dtype,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
@@ -223,6 +224,7 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
+        dtype: The dype of output tensor.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -232,7 +234,21 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    if _is_hip:
+        if dtype == torch.int8:
+            bit8_max = 127.0
+        else:
+            bit8_max = 224.0
+        bit8_min = -bit8_max  # TODO incorrect for int8
+    else:
+        if dtype == torch.int8:
+            info = torch.iinfo(dtype)
+        else:
+            info = torch.finfo(dtype)
+        bit8_max = info.max
+        bit8_min = info.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
         x_shape=x.shape,
         device=x.device,
@@ -250,7 +266,7 @@ def per_token_group_quant_fp8(
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
     if column_major_scales:
-        _per_token_group_quant_fp8_colmajor[(M,)](
+        _per_token_group_quant_8bit_colmajor[(M,)](
             x,
             x_q,
             x_s,
@@ -258,8 +274,8 @@ def per_token_group_quant_fp8(
             x.shape[1],
             x_s.stride(1),
             eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -267,15 +283,15 @@ def per_token_group_quant_fp8(
         )
     else:
         assert not scale_ue8m0
-        _per_token_group_quant_fp8[(M,)](
+        _per_token_group_quant_8bit[(M,)](
             x,
             x_q,
             x_s,
             group_size,
             N,
             eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -297,6 +313,117 @@ def per_token_group_quant_fp8(
     return x_q, x_s


+# backward compatibility
+per_token_group_quant_fp8 = _per_token_group_quant_8bit_raw
+
+
+def _per_token_group_quant_8bit_fuse_silu_and_mul(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
+    masked_m: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Another way to implement (can be used in e.g. comparison tests)
+    # from sgl_kernel import silu_and_mul
+    # x_after_silu_and_mul = silu_and_mul(x)
+    # return per_token_group_quant_fp8(
+    #     x_after_silu_and_mul,
+    #     group_size=group_size,
+    #     eps=eps,
+    #     column_major_scales=column_major_scales,
+    #     scale_tma_aligned=scale_tma_aligned,
+    #     scale_ue8m0=scale_ue8m0,
+    # )
+
+    from deep_gemm.utils.layout import transform_sf_into_required_layout
+
+    from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
+
+    assert column_major_scales
+    assert scale_tma_aligned
+    assert scale_ue8m0
+
+    needs_unsqueeze = x.dim() == 2
+    if needs_unsqueeze:
+        num_tokens, _ = x.shape
+        x = x.unsqueeze(0)
+        assert masked_m is None
+        masked_m = torch.tensor([num_tokens], device=x.device, dtype=torch.int32)
+
+    # Use `zeros` for easier testing
+    output = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2),
+        device=x.device,
+        dtype=dst_dtype,
+    )
+    # Use `zeros` for easier testing
+    output_scale_for_kernel = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2 // group_size),
+        device=x.device,
+        dtype=torch.float32,
+    )
+    silu_and_mul_masked_post_quant_fwd(
+        input=x,
+        output=output,
+        output_scale=output_scale_for_kernel,
+        quant_group_size=group_size,
+        masked_m=masked_m,
+        scale_ue8m0=scale_ue8m0,
+    )
+
+    assert group_size == 128
+    output_scale = transform_sf_into_required_layout(
+        output_scale_for_kernel,
+        num_groups=output.shape[0],
+        mn=output.shape[-2],
+        k=output.shape[-1],
+        recipe=(1, group_size, group_size),
+        is_sfa=True,
+    )
+
+    if needs_unsqueeze:
+        output = output.squeeze(0)
+        output_scale = output_scale.squeeze(0)
+
+    return output, output_scale
+
+
+def per_token_group_quant_8bit(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if fuse_silu_and_mul:
+        return _per_token_group_quant_8bit_fuse_silu_and_mul(
+            x=x,
+            group_size=group_size,
+            dst_dtype=dst_dtype,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            masked_m=masked_m,
+        )
+    else:
+        return _per_token_group_quant_8bit_raw(
+            x=x,
+            group_size=group_size,
+            eps=eps,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            dtype=dst_dtype,
+        )
+
+
 def create_per_token_group_quant_fp8_output_scale(
     x_shape,
     device,
@@ -307,16 +434,16 @@ def create_per_token_group_quant_fp8_output_scale(
 ):
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k = x_shape
+        *x_batch, x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-        return torch.zeros(
-            (aligned_k // 4, aligned_mn),
+        return torch.empty(
+            (*x_batch, aligned_k // 4, aligned_mn),
             device=device,
             dtype=torch.int,
-        ).transpose(0, 1)[:x_s_mn, :]
+        ).transpose(-1, -2)[..., :x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
@@ -348,15 +475,19 @@ def sglang_per_token_group_quant_fp8(
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    out_shape = (*x.shape[:-1], x.shape[-1] // (2 if fuse_silu_and_mul else 1))
+
+    x_q = torch.empty(out_shape, device=x.device, dtype=fp8_dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
-        x_shape=x.shape,
+        x_shape=out_shape,
         device=x.device,
         group_size=group_size,
         column_major_scales=column_major_scales,
@@ -372,6 +503,46 @@ def sglang_per_token_group_quant_fp8(
     return x_q, x_s


+# TODO maybe unify int8 and fp8 code later
+def sglang_per_token_group_quant_8bit(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
+):
+    from sglang.srt.layers.quantization.int8_kernel import (
+        sglang_per_token_group_quant_int8,
+    )
+
+    if dst_dtype == torch.int8:
+        assert not column_major_scales
+        assert not scale_tma_aligned
+        assert not fuse_silu_and_mul
+        assert masked_m is None
+        return sglang_per_token_group_quant_int8(
+            x=x,
+            group_size=group_size,
+            eps=eps,
+            dtype=dst_dtype,
+        )
+
+    return sglang_per_token_group_quant_fp8(
+        x=x,
+        group_size=group_size,
+        eps=eps,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=scale_ue8m0,
+        fuse_silu_and_mul=fuse_silu_and_mul,
+        masked_m=masked_m,
+    )
+
+
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
     dtype: torch.dtype = fp8_dtype,
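For orientation, here is a minimal PyTorch reference (not part of the package) of what the renamed `_per_token_group_quant_8bit*` kernels compute: take the absmax of each `group_size`-wide slice of a row, derive the scale as `absmax / bit8_max`, and clamp the scaled values to `[bit8_min, bit8_max]`. The function name, the fp8 default, and the rounding behaviour below are illustrative assumptions, not sglang API.

```python
import torch

def per_token_group_quant_reference(
    x: torch.Tensor,
    group_size: int,
    dtype: torch.dtype = torch.float8_e4m3fn,
    eps: float = 1e-10,
):
    # Eager reference of the grouped quantization in the kernels above.
    assert x.shape[-1] % group_size == 0
    info = torch.iinfo(dtype) if dtype == torch.int8 else torch.finfo(dtype)
    bit8_max, bit8_min = float(info.max), float(info.min)

    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)  # _absmax
    scale = absmax / bit8_max                                        # y_s
    q = (groups / scale).clamp(bit8_min, bit8_max)                   # y_q before cast
    return q.to(dtype).reshape(x.shape), scale.squeeze(-1)

x = torch.randn(4, 256)
x_q, x_s = per_token_group_quant_reference(x, group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```

The point of the generalization in this release is that `dtype` may now be `torch.int8` as well as the fp8 dtype, with the HIP path using the hard-coded bounds (127 / 224) visible in the diff.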
sglang/srt/layers/quantization/fp8_utils.py

@@ -53,6 +53,7 @@ if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm

     use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
+    use_triton_w8a8_fp8_kernel = get_bool_env_var("USE_TRITON_W8A8_FP8_KERNEL")

 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -113,6 +114,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
     return weight, weight_scale, input_scale


+# TODO(ch-wan): define these backends in --moe-runner-backend
 def cutlass_block_fp8_supported() -> bool:
     if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
@@ -555,7 +557,10 @@ def apply_fp8_linear(
     # We also don't pad when using torch.compile,
     # as it breaks with dynamic shapes.
     if pad_output is None:
-        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+        pad_output = (
+            not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+            and not cutlass_fp8_supported
+        )
     output_padding = 17 if pad_output else None

     # View input as 2D matrix for fp8 methods
@@ -591,7 +596,7 @@ def apply_fp8_linear(
             cutlass_compatible_b = (
                 weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
             )
-            if not cutlass_compatible_b:
+            if not cutlass_compatible_b or use_triton_w8a8_fp8_kernel:
                 # Massage the input to be 2D
                 qinput = qinput.view(-1, qinput.shape[-1])
                 output = triton_scaled_mm(
@@ -734,14 +739,25 @@ def apply_fp8_linear(
             assert (
                 weight_scale.numel() == weight.shape[1]
             ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-            output = fp8_scaled_mm(
-                qinput,
-                weight,
-                x_scale,
-                weight_scale,
-                out_dtype=input.dtype,
-                bias=bias,
+
+            cutlass_compatible_b = (
+                weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0
             )
+            if not cutlass_compatible_b or use_triton_w8a8_fp8_kernel:
+                # Massage the input to be 2D
+                qinput = qinput.view(-1, qinput.shape[-1])
+                output = triton_scaled_mm(
+                    qinput, weight, x_scale, weight_scale, input.dtype, bias
+                )
+            else:
+                output = fp8_scaled_mm(
+                    qinput,
+                    weight,
+                    x_scale,
+                    weight_scale,
+                    out_dtype=input.dtype,
+                    bias=bias,
+                )
             return output.view(*output_shape)
         except (ImportError, NameError, AttributeError):
             pass
@@ -788,3 +804,12 @@ apply_fp8_linear
         bias,
         input.dtype,
     )
+
+
+def can_auto_enable_marlin_fp8() -> bool:
+    try:
+        major, minor = get_device_capability()
+        sm = major * 10 + minor
+        return 80 <= sm < 89
+    except Exception:
+        return False
sglang/srt/layers/quantization/fpgemm_fp8.py (new file)

@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import logging
+from typing import Any, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
+from sglang.srt.layers.quantization.fp8_utils import (
+    apply_fp8_linear,
+    can_auto_enable_marlin_fp8,
+    cutlass_fp8_supported,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+from sglang.srt.layers.quantization.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
+from sglang.srt.utils import get_bool_env_var, is_cuda
+
+_is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
+
+logger = logging.getLogger(__name__)
+
+
+class FBGEMMFp8Config(QuantizationConfig):
+    """Config class for FBGEMM Fp8."""
+
+    def __init__(self, ignore_list: list[str], input_scale_ub: float):
+        super().__init__()
+        self.ignore_list = ignore_list if ignore_list else []
+        self.input_scale_ub = input_scale_ub
+
+        # For GPUs that lack FP8 hardware suspport, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        # self.use_marlin = not marlin_fp8_supported()
+        self.use_marlin = False
+        if _is_cuda:
+            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+            auto_enable = can_auto_enable_marlin_fp8()
+            self.use_marlin = force_marlin or auto_enable
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "fbgemm_fp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.bfloat16, torch.float16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> FBGEMMFp8Config:
+        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
+        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
+        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(
+                prefix=prefix,
+                ignored_layers=self.ignore_list,
+                fused_mapping=self.packed_modules_mapping,
+            ):
+                return UnquantizedLinearMethod()
+            return FBGEMMFp8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class FBGEMMFp8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quant_config: FBGEMMFp8Config):
+        self.quant_config = quant_config
+        # self.fp8_linear = Fp8LinearOp(
+        #     act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
+        self.out_dtype = torch.get_default_dtype()
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        # maybe_create_device_identity()
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        del input_size, output_size
+        output_size_per_partition = sum(output_partition_sizes)
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        weight_scale[:] = torch.finfo(torch.float32).min
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE UPPER BOUND
+        input_scale_ub = torch.nn.Parameter(
+            torch.tensor((self.quant_config.input_scale_ub), dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.input_scale_ub = input_scale_ub
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        # required by torch.compile
+        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
+
+        weight = layer.weight
+
+        if _is_fp8_fnuz:
+            weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                weight=weight, weight_scale=layer.weight_scale, input_scale=None
+            )
+            if input_scale is not None:
+                layer.input_scale = Parameter(input_scale, requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+        if self.quant_config.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale_ub
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        if self.quant_config.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=None,
+            input_scale_ub=layer.input_scale_ub,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False,
+        )
sglang/srt/layers/quantization/gptq.py

@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import (
 )

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput

 from sglang.srt.utils import is_cuda
@@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # Delay the import to avoid circular dependency

-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."

         # The input must currently be float16
         orig_dtype = x.dtype
sglang/srt/layers/quantization/marlin_utils.py

@@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda

 if TYPE_CHECKING:
     from sglang.srt.layers.linear import LinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

 try:
     from vllm import _custom_ops as ops
@@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
     )[0]


-def check_moe_marlin_supports_layer(layer: "FusedMoE", group_size: int) -> bool:
+def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
     hidden_size = layer.hidden_size
     intermediate_size_per_partition = layer.intermediate_size_per_partition
     # apply_router_weight_on_input is not supported for moe marlin
-    supports_router_weight = not layer.apply_router_weight_on_input
+    supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
     # moe marlin requires the activation to be silu
-    supports_activation = layer.activation == "silu"
+    supports_activation = layer.moe_runner_config.activation == "silu"

     # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
     # down: (n, k) = (hidden_size, intermediate_size_per_partition)
@@ -305,6 +306,13 @@ def marlin_permute_scales(
     return s


+def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
+    origin_shape = s.shape
+    _, scale_perm_single = get_scale_perms()
+    s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    return s.reshape(*origin_shape).contiguous()
+
+
 def marlin_moe_permute_scales(
     s: torch.Tensor,
     size_k: int,