sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py

@@ -16,40 +16,41 @@ import functools
 import json
 import logging
 import os
-from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple

 import torch
 import triton
 import triton.language as tl

+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     get_device_core_count,
     get_device_name,
-    get_device_sm,
     is_cuda,
     is_hip,
     supports_custom_op,
 )

-_enable_jit_deepgemm = False
-
 _is_hip = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-
 _is_cuda = is_cuda()
-if
-
-
+_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+if _is_hip:
+    fp8_max = 224.0
+else:
+    fp8_max = torch.finfo(_fp8_type).max
+fp8_min = -fp8_max

-
-
-
-
-
+if _is_cuda:
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )

+    from sglang.srt.layers.quantization.deep_gemm import (
+        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
+    )

 logger = logging.getLogger(__name__)

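For orientation, the import hunk above replaces per-call FP8 range computation with module-level constants. A minimal standalone sketch of what those constants evaluate to (not sglang code; the 224.0 cap on HIP reflects the smaller usable range of `float8_e4m3fnuz` on ROCm):

```python
import torch

# Standalone sketch of the module-level constants introduced above.
is_hip = torch.version.hip is not None  # assumption: detect a ROCm build via torch.version

fp8_dtype = torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn
fp8_max = 224.0 if is_hip else torch.finfo(fp8_dtype).max  # 448.0 for float8_e4m3fn
fp8_min = -fp8_max

print(fp8_dtype, fp8_min, fp8_max)
```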
@@ -62,10 +63,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-
-        N, _ = B.shape
-        with _log_jit_build(M, N, K):
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)

     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -179,7 +177,6 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,7 +189,6 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -202,15 +198,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
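As context for the hunks above and below: per-token-group quantization splits each token's last dimension into groups of `group_size`, derives one scale per group as `absmax / fp8_max`, and clamps the scaled values into the FP8 range (the same math appears in the Triton kernel added later in this diff). A hedged, pure-PyTorch reference, not the sglang kernel:

```python
import torch

def per_token_group_quant_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    """Reference per-token-group FP8 quantization (illustration only)."""
    assert x.shape[-1] % group_size == 0
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    groups = x.reshape(*x.shape[:-1], -1, group_size).float()
    # One scale per (token, group): absmax / fp8_max, kept away from zero by eps.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / fp8_max
    q = (groups / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return q.reshape(x.shape), scale.squeeze(-1)

x = torch.randn(4, 256)
x_q, x_s = per_token_group_quant_ref(x, group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```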
@@ -276,26 +264,36 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-
-
-
-
-
-
-
-
-
-
-
-
-
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )

     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

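The new `column_major_scales` / `scale_tma_aligned` arguments only change how the scale tensor `x_s` is laid out in memory; its logical shape stays (tokens, num_groups). A small sketch of the three layouts for a 2-D input, mirroring the allocation code in the hunk above (shapes and strides only):

```python
import torch

x = torch.randn(7, 512)               # 7 tokens, hidden size 512
group_size = 128
n_groups = x.shape[-1] // group_size  # 4

# Default (row-major): one contiguous scale row per token.
s_row = torch.empty(x.shape[:-1] + (n_groups,), dtype=torch.float32)

# Column-major: allocate (n_groups, tokens) and view it transposed.
s_col = torch.empty((n_groups,) + x.shape[:-1], dtype=torch.float32).permute(-1, -2)

# TMA-aligned column-major: pad the token dimension up to a multiple of 4 floats.
aligned = (x.shape[-2] + 3) // 4 * 4  # 8
s_tma = torch.empty(
    x.shape[:-2] + (n_groups, aligned), dtype=torch.float32
).permute(-1, -2)[: x.shape[-2], :]

print(s_row.shape, s_col.shape, s_col.stride(), s_tma.shape, s_tma.stride())
# torch.Size([7, 4]) torch.Size([7, 4]) (1, 7) torch.Size([7, 4]) (1, 8)
```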
@@ -304,7 +302,7 @@ def sglang_per_token_group_quant_fp8(

 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype =
+    dtype: torch.dtype = _fp8_type,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"

@@ -368,7 +366,6 @@ def static_quant_fp8(
     x: torch.Tensor,
     x_s: torch.Tensor,
     repeat_scale: bool = False,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform static quantization using the given scale on an input tensor `x`.

@@ -386,15 +383,8 @@ def static_quant_fp8(
     """
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max

-
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -714,25 +704,6 @@ def get_w8a8_block_fp8_configs(
     return None


-@contextmanager
-def _log_jit_build(M: int, N: int, K: int):
-    from deep_gemm.jit.runtime import RuntimeCache
-
-    origin_func = RuntimeCache.__getitem__
-
-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            logger.warning(
-                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
-            )
-        return ret
-
-    RuntimeCache.__getitem__ = __patched_func
-    yield
-    RuntimeCache.__getitem__ = origin_func
-
-
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
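The removed `_log_jit_build` helper used a common pattern: temporarily monkey-patch a cache lookup inside a context manager so that a cache miss (which triggers a JIT build) gets logged. A generic, hedged sketch of that pattern; the `Cache` class below is a stand-in, not the deep_gemm API:

```python
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)

class Cache(dict):
    """Stand-in for a JIT runtime cache; a miss means code must be compiled."""
    def lookup(self, key):
        return self.get(key)

@contextmanager
def log_cache_misses(cache_cls, m, n, k):
    original = cache_cls.lookup

    def patched(self, key):
        result = original(self, key)
        if result is None:
            logger.warning("JIT build needed: M=%d, N=%d, K=%d. Please wait.", m, n, k)
        return result

    cache_cls.lookup = patched
    try:
        yield
    finally:
        cache_cls.lookup = original  # always restore (the removed helper did not use finally)
```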
@@ -803,12 +774,11 @@ def w8a8_block_fp8_matmul(
     )

     # deepgemm only support bf16
-    if C.dtype == torch.bfloat16 and
+    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
-
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
@@ -896,22 +866,20 @@ def _per_tensor_quant_mla_fp8_stage2(


 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
     and specialized for mla absorbed case.
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )

-
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
-    x_q = x.new_empty(x.size(), dtype=dtype)
-    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
+    x_q = x.new_empty(x.size(), dtype=_fp8_type)

     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
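With the new signature (continued in the next hunk), the caller owns the per-tensor scale buffer and passes it in as `x_s_out`; the function fills it and returns the same buffer. A hedged usage sketch; the motivation (stable output buffers, e.g. for reuse under CUDA graphs) is an assumption, not stated in the diff:

```python
import torch
from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8

x = torch.randn(16, 8, 512, device="cuda")  # (num_head, num_seq, head_size), per the kernel
x_s_out = torch.zeros(1, dtype=torch.float32, device=x.device)  # caller-owned scale, shape (1,)

x_q, x_s = per_tensor_quant_mla_fp8(x, x_s_out)  # x_s is the same buffer, now filled in
```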
@@ -919,7 +887,7 @@ def per_tensor_quant_mla_fp8(

     _per_tensor_quant_mla_fp8_stage1[grid](
         x,
-
+        x_s_out,
         head_size,
         x.stride(0),
         x.stride(1),
@@ -929,15 +897,172 @@ def per_tensor_quant_mla_fp8(
     )
     _per_tensor_quant_mla_fp8_stage2[grid](
         x,
-
+        x_s_out,
         x_q,
         num_seq,
         head_size,
         x.stride(0),
         x.stride(1),
+        fp8_min,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s_out
+
+
+@triton.jit
+def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    masked_m_ptr,
+    group_size,
+    y_stride_b,
+    y_stride_t,
+    y_q_stride_b,
+    y_q_stride_t,
+    y_s_stride_b,
+    y_s_stride_g,
+    eps,
+    fp8_min,
+    fp8_max,
+    NUM_GROUP: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor for deep_gemm grouped_gemm_masked.
+    This function converts the tensor values into float8 values.
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_tensor_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with per-token-group-quantization
+    for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
         -fp8_max,
         fp8_max,
+        num_tiles_k,
         BLOCK_SIZE,
     )

-    return x_q, x_s
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 (8-bit floating point) format.
+
+    Args:
+        input (torch.Tensor): Input tensor to be quantized
+        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+            If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+            determines the quantization granularity:
+            - True: compute scale per token
+            - False: compute single scale per tensor
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - quantized_tensor: The FP8 quantized version of input
+            - scale_tensor: The scaling factors used for quantization
+
+    Raises:
+        AssertionError: If input is not 2D or if static scale's numel != 1
+    """
+    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+    shape = input.shape
+    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        # Dynamic scaling
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+            sgl_per_token_quant_fp8(input, output, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=False
+            )  # False for dynamic
+    else:
+        # Static scaling
+        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+        sgl_per_tensor_quant_fp8(
+            input, output, scale, is_static=True
+        )  # True for static
+
+    return output, scale