sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py

@@ -16,6 +16,7 @@ import functools
 import json
 import logging
 import os
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -33,20 +34,31 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
-_enable_jit_deepgemm = False
-
 _is_hip = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-
 _is_cuda = is_cuda()
+_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+if _is_hip:
+    fp8_max = 224.0
+else:
+    fp8_max = torch.finfo(_fp8_type).max
+fp8_min = -fp8_max
+
+_enable_jit_deepgemm = False
+_enable_jit_deepgemm_bmm = False
 if _is_cuda:
-    import deep_gemm
-    from sgl_kernel import
+    import deep_gemm
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )
 
     sm_version = get_device_sm()
-    if sm_version
-
-
+    if sm_version == 90:
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+            _enable_jit_deepgemm = True
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
+            _enable_jit_deepgemm_bmm = True
 
 logger = logging.getLogger(__name__)
 
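Note on the hunk above: 0.4.5.post2 hoists the FP8 range constants to module scope (`_fp8_type`, `fp8_max`, `fp8_min`) instead of recomputing them in each quantization helper, and gates DeepGEMM JIT behind the SGL_ENABLE_JIT_DEEPGEMM / SGL_ENABLE_JIT_DEEPGEMM_BMM environment variables on SM90 devices. A minimal standalone sketch of the range selection only (illustrative; the real module detects ROCm via sglang.srt.utils.is_hip(), and the torch.version.hip check here is an assumption):

import torch

# Assumption: detect a ROCm build of PyTorch via torch.version.hip.
is_hip = torch.version.hip is not None
fp8_dtype = torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn
# ROCm e4m3fnuz kernels use a reduced practical maximum of 224.0;
# on CUDA the full format maximum is used (448.0 for e4m3fn).
fp8_max = 224.0 if is_hip else torch.finfo(fp8_dtype).max
fp8_min = -fp8_max
print(fp8_dtype, fp8_min, fp8_max)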
@@ -59,7 +71,10 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        M, K = A.shape
+        N, _ = B.shape
+        with _log_jit_build(M, N, K):
+            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -173,7 +188,6 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -186,7 +200,6 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -196,15 +209,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
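Note on the hunk above: per_token_group_quant_fp8 no longer takes a `dtype` argument; the output buffer is always allocated with the module-level `_fp8_type`, and the FP8 range comes from the module-level `fp8_min`/`fp8_max`. For readers who want the math without the Triton kernel, here is a pure-PyTorch reference of per-token-group FP8 quantization (a sketch of the same grouping and scaling, not the kernel in this file):

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # One scale per contiguous group of `group_size` elements in the last dim.
    assert x.shape[-1] % group_size == 0
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    g = x.float().reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    amax = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = amax / fp8_max
    q = (g / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return q.reshape(x.shape), scale.squeeze(-1)

x = torch.randn(4, 256)
q, s = per_token_group_quant_fp8_ref(x, group_size=128)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])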
@@ -270,26 +275,36 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
 
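Note on the hunk above: sglang_per_token_group_quant_fp8 gains the same `column_major_scales` / `scale_tma_aligned` options as the Triton path, and the three branches differ only in the memory layout of the per-group scale tensor `x_s`. A small sketch that reruns the three allocation expressions from the hunk for a concrete (M, K) = (6, 256) input with group_size = 128 and prints the resulting shapes and strides:

import torch

x = torch.randn(6, 256)
group_size = 128
num_groups = x.shape[-1] // group_size  # 2

row_major = torch.empty(x.shape[:-1] + (num_groups,), dtype=torch.float32)
col_major = torch.empty((num_groups,) + x.shape[:-1], dtype=torch.float32).permute(-1, -2)
aligned_m = (x.shape[-2] + 3) // 4 * 4  # "aligned to 4 * sizeof(float)" per the diff
tma_aligned = torch.empty(
    x.shape[:-2] + (num_groups, aligned_m), dtype=torch.float32
).permute(-1, -2)[: x.shape[-2], :]

for name, t in (("row_major", row_major), ("col_major", col_major), ("tma_aligned", tma_aligned)):
    print(name, tuple(t.shape), t.stride())
# All three are logically (M, num_groups) = (6, 2); the latter two are column-major
# in memory, and the TMA-aligned variant pads M from 6 to 8 so the per-column
# stride is a multiple of 4 floats.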
@@ -298,7 +313,7 @@ def sglang_per_token_group_quant_fp8(
 
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype = fp8_type_,
+    dtype: torch.dtype = _fp8_type,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
 
|
|
362
377
|
x: torch.Tensor,
|
363
378
|
x_s: torch.Tensor,
|
364
379
|
repeat_scale: bool = False,
|
365
|
-
dtype: torch.dtype = fp8_type_,
|
366
380
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
367
381
|
"""Function to perform static quantization using the given scale on an input tensor `x`.
|
368
382
|
|
@@ -380,15 +394,8 @@ def static_quant_fp8(
     """
    assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
 
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -708,6 +715,25 @@ def get_w8a8_block_fp8_configs(
     return None
 
 
+@contextmanager
+def _log_jit_build(M: int, N: int, K: int):
+    from deep_gemm.jit.runtime import RuntimeCache
+
+    origin_func = RuntimeCache.__getitem__
+
+    def __patched_func(self, *args, **kwargs):
+        ret = origin_func(self, *args, **kwargs)
+        if ret is None:
+            logger.warning(
+                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
+            )
+        return ret
+
+    RuntimeCache.__getitem__ = __patched_func
+    yield
+    RuntimeCache.__getitem__ = origin_func
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
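Note on the hunk above: _log_jit_build temporarily replaces RuntimeCache.__getitem__ so that a DeepGEMM JIT cache miss (a None lookup) logs a warning before the slow code generation runs, then restores the original method after the block. A generic sketch of the same pattern on a stand-in class (the class and names below are illustrative, not deep_gemm's API):

from contextlib import contextmanager

class DemoCache:
    def __init__(self):
        self._store = {}

    def __getitem__(self, key):
        return self._store.get(key)  # None signals a cache miss

@contextmanager
def log_cache_miss(tag: str):
    origin = DemoCache.__getitem__

    def patched(self, key):
        ret = origin(self, key)
        if ret is None:
            print(f"[{tag}] cache miss for {key!r}; building, please wait.")
        return ret

    DemoCache.__getitem__ = patched
    try:
        yield
    finally:
        # try/finally keeps the patch exception-safe; the helper in the diff
        # restores the method after a plain yield.
        DemoCache.__getitem__ = origin

cache = DemoCache()
with log_cache_miss("demo"):
    cache["gemm_M64_N128_K256"]  # hits the warning path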
@@ -782,7 +808,8 @@ def w8a8_block_fp8_matmul(
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+            with _log_jit_build(M, N, K):
+                deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
@@ -815,3 +842,258 @@ def w8a8_block_fp8_matmul(
     )
 
     return C
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage1(
+    x_ptr,
+    x_s_ptr,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    eps,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
+
+    tl.atomic_max(x_s_ptr, _absmax / fp8_max)
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage2(
+    x_ptr,
+    x_s_ptr,
+    x_q_ptr,
+    num_seq,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    fp8_min,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_s = tl.load(x_s_ptr)
+    x_s_inv = 1.0 / x_s
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x_q_ptr += head_id * num_seq * head_size + seq_id * head_size
+
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    x_q = tl.clamp(x * x_s_inv, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
+    tl.store(x_q_ptr + offset, x_q, mask=mask)
+
+
+def per_tensor_quant_mla_fp8(
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with tensor-wise quantization
+    and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )
+
+    x_q = x.new_empty(x.size(), dtype=_fp8_type)
+
+    num_head, num_seq, head_size = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(head_size)
+    grid = (num_seq, num_head)
+
+    _per_tensor_quant_mla_fp8_stage1[grid](
+        x,
+        x_s_out,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        eps,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+    _per_tensor_quant_mla_fp8_stage2[grid](
+        x,
+        x_s_out,
+        x_q,
+        num_seq,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        fp8_min,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s_out
+
+
+@triton.jit
+def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    masked_m_ptr,
+    group_size,
+    y_stride_b,
+    y_stride_t,
+    y_q_stride_b,
+    y_q_stride_t,
+    y_s_stride_b,
+    y_s_stride_g,
+    eps,
+    fp8_min,
+    fp8_max,
+    NUM_GROUP: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor for deep_gemm grouped_gemm_masked.
+    This function converts the tensor values into float8 values.
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_tensor_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with per-token-group-quantization
+    for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
+        -fp8_max,
+        fp8_max,
+        num_tiles_k,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 (8-bit floating point) format.
+
+    Args:
+        input (torch.Tensor): Input tensor to be quantized
+        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+            If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+            determines the quantization granularity:
+            - True: compute scale per token
+            - False: compute single scale per tensor
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - quantized_tensor: The FP8 quantized version of input
+            - scale_tensor: The scaling factors used for quantization
+
+    Raises:
+        AssertionError: If input is not 2D or if static scale's numel != 1
+    """
+    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+    shape = input.shape
+    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        # Dynamic scaling
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+            sgl_per_token_quant_fp8(input, output, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=False
+            )  # False for dynamic
+    else:
+        # Static scaling
+        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+        sgl_per_tensor_quant_fp8(
+            input, output, scale, is_static=True
+        )  # True for static
+
+    return output, scale
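Note on the hunk above: the new scaled_fp8_quant entry point delegates the actual quantization to the sgl_kernel ops imported at the top of the file (sgl_per_token_quant_fp8 / sgl_per_tensor_quant_fp8), so it needs a build with sgl_kernel available. For readers without those kernels, a pure-PyTorch sketch of what the two dynamic modes compute, assuming the scale convention scale = amax / fp8_max used by the Triton kernels in this file (the exact numerics of the sgl_kernel ops may differ):

import torch

def scaled_fp8_quant_ref(x: torch.Tensor, per_token: bool = False):
    assert x.ndim == 2
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    if per_token:
        # one scale per row (token)
        scale = x.abs().amax(dim=1, keepdim=True).float().clamp_min(1e-12) / fp8_max
    else:
        # a single scale for the whole tensor
        scale = (x.abs().amax().float().clamp_min(1e-12) / fp8_max).reshape(1)
    q = (x / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return q, scale

x = torch.randn(8, 16)
q_t, s_t = scaled_fp8_quant_ref(x)                   # per-tensor: s_t.shape == (1,)
q_p, s_p = scaled_fp8_quant_ref(x, per_token=True)   # per-token:  s_p.shape == (8, 1)
print(q_t.dtype, tuple(s_t.shape), tuple(s_p.shape))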