sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +0 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +62 -6
- sglang/srt/disaggregation/mini_lb.py +5 -1
- sglang/srt/disaggregation/mooncake/conn.py +32 -62
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/prefill.py +40 -4
- sglang/srt/disaggregation/utils.py +15 -0
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +114 -71
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -57
- sglang/srt/layers/quantization/fp8_utils.py +187 -262
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +3 -2
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +2 -4
- sglang/srt/managers/scheduler.py +12 -71
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +7 -2
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +20 -27
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +289 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +29 -201
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +34 -32
- sglang/srt/speculative/eagle_worker.py +4 -7
- sglang/srt/utils.py +16 -1
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py

@@ -34,22 +34,31 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
-_enable_jit_deepgemm = False
-
 _is_hip = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-
 _is_cuda = is_cuda()
+_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+if _is_hip:
+    fp8_max = 224.0
+else:
+    fp8_max = torch.finfo(_fp8_type).max
+fp8_min = -fp8_max
+
+_enable_jit_deepgemm = False
+_enable_jit_deepgemm_bmm = False
 if _is_cuda:
     import deep_gemm
-    from sgl_kernel import
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )
 
     sm_version = get_device_sm()
-    if sm_version == 90
-        "SGL_ENABLE_JIT_DEEPGEMM", default="false"
-
-
-
+    if sm_version == 90:
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+            _enable_jit_deepgemm = True
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
+            _enable_jit_deepgemm_bmm = True
 
 logger = logging.getLogger(__name__)
 
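For reference, the FP8 clamp range is now computed once at module level: `torch.finfo(_fp8_type).max` on CUDA (448.0 for `float8_e4m3fn`) and a hard-coded 224.0 on HIP, with the DeepGEMM JIT paths gated on SM90 by `SGL_ENABLE_JIT_DEEPGEMM` and the new `SGL_ENABLE_JIT_DEEPGEMM_BMM` flag. A minimal standalone sketch of the same range computation (plain PyTorch, outside sglang; `torch.version.hip` stands in for sglang's `is_hip()`):

```python
# Minimal sketch of the module-level FP8 range computation above (plain PyTorch,
# independent of sglang). torch.version.hip is a stand-in for sglang's is_hip().
import torch

_is_hip = torch.version.hip is not None
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
fp8_max = 224.0 if _is_hip else torch.finfo(_fp8_type).max  # 448.0 for e4m3fn
fp8_min = -fp8_max
print(_fp8_type, fp8_min, fp8_max)
```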
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.

@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
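The `dtype` parameter is gone from `per_token_group_quant_fp8`; outputs are always allocated as the module-level `_fp8_type`. A hedged usage sketch (assumes a CUDA device with the Triton kernels available, and that the default scale layout matches the `x.shape[:-1] + (x.shape[-1] // group_size,)` allocation shown for the sibling `sglang_per_token_group_quant_fp8` below):

```python
# Hedged usage sketch: no dtype argument anymore; the FP8 dtype comes from _fp8_type.
# The asserted scale shape is an assumption based on the allocation code, not a
# documented contract.
import torch
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
assert x_q.shape == x.shape      # FP8 tensor, same shape as the input
assert x_s.shape == (4, 2)       # one float32 scale per 128-wide group per token
```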
@@ -276,26 +275,36 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
 
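The new `column_major_scales` and `scale_tma_aligned` flags only change how the `x_s` buffer is laid out before the `sgl_per_token_group_quant_fp8` kernel fills it. A pure-PyTorch sketch of the three allocation branches for a 2-D input (shapes mirror the hunk above and run without sglang):

```python
# Pure-PyTorch sketch of the three x_s layouts allocated above (2-D input case).
import torch

x = torch.randn(16, 512)            # (tokens, hidden); group_size divides hidden
group_size = 128
groups = x.shape[-1] // group_size  # 4

# default: row-major (tokens, groups)
x_s_default = torch.empty(x.shape[:-1] + (groups,), dtype=torch.float32)

# column_major_scales=True: same logical shape, column-major strides
x_s_col = torch.empty((groups,) + x.shape[:-1], dtype=torch.float32).permute(-1, -2)

# scale_tma_aligned=True: token dimension padded up to a multiple of 4 floats
aligned = (x.shape[-2] + 3) // 4 * 4
x_s_tma = torch.empty(
    x.shape[:-2] + (groups, aligned), dtype=torch.float32
).permute(-1, -2)[: x.shape[-2], :]

print(x_s_default.shape, x_s_col.shape, x_s_col.stride(), x_s_tma.shape)
```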
@@ -304,7 +313,7 @@ def sglang_per_token_group_quant_fp8(
 
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype =
+    dtype: torch.dtype = _fp8_type,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
 

@@ -368,7 +377,6 @@ def static_quant_fp8(
     x: torch.Tensor,
     x_s: torch.Tensor,
     repeat_scale: bool = False,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform static quantization using the given scale on an input tensor `x`.
 

@@ -386,15 +394,8 @@ def static_quant_fp8(
     """
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
 
-
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
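`static_quant_fp8` likewise drops its `dtype` parameter and writes into `_fp8_type`. A rough pure-PyTorch approximation of the per-element math (for illustration only; the real computation is the Triton kernel, and the HIP path clamps to 224.0 instead of the finfo maximum):

```python
# Approximate per-element math of static per-tensor FP8 quantization
# (illustrative only; the actual work happens in the Triton kernel).
import torch

fp8 = torch.float8_e4m3fn
fp8_max = float(torch.finfo(fp8).max)     # HIP path uses 224.0 instead
x = torch.randn(4, 64)
x_s = torch.tensor(0.05)                  # pre-computed per-tensor scale
x_q = (x / x_s).clamp(-fp8_max, fp8_max).to(fp8)
x_dequant = x_q.to(torch.float32) * x_s   # round-trip check
```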
@@ -896,22 +897,20 @@ def _per_tensor_quant_mla_fp8_stage2(
 
 
 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
     and specialized for mla absorbed case.
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )
 
-
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
-    x_q = x.new_empty(x.size(), dtype=dtype)
-    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
+    x_q = x.new_empty(x.size(), dtype=_fp8_type)
 
     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
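`per_tensor_quant_mla_fp8` now takes a preallocated `x_s_out` scale buffer from the caller instead of creating one internally. A hedged sketch of the new calling convention (tensor sizes are placeholders; a CUDA device is assumed):

```python
# Hedged sketch of the new calling convention: the caller pre-allocates the
# (1,)-shaped float32 scale buffer that the function previously created itself.
import torch
from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8

x = torch.randn(16, 8, 512, device="cuda", dtype=torch.bfloat16)  # (num_head, num_seq, head_size)
x_s_out = torch.zeros((1,), dtype=torch.float32, device=x.device)
x_q, x_s = per_tensor_quant_mla_fp8(x, x_s_out)   # x_s is the same buffer, filled in
```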
@@ -919,7 +918,7 @@ def per_tensor_quant_mla_fp8(
 
     _per_tensor_quant_mla_fp8_stage1[grid](
         x,
-
+        x_s_out,
        head_size,
         x.stride(0),
         x.stride(1),

@@ -929,15 +928,172 @@ def per_tensor_quant_mla_fp8(
     )
     _per_tensor_quant_mla_fp8_stage2[grid](
         x,
-
+        x_s_out,
         x_q,
         num_seq,
         head_size,
         x.stride(0),
         x.stride(1),
+        fp8_min,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s_out
+
+
+@triton.jit
+def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    masked_m_ptr,
+    group_size,
+    y_stride_b,
+    y_stride_t,
+    y_q_stride_b,
+    y_q_stride_t,
+    y_s_stride_b,
+    y_s_stride_g,
+    eps,
+    fp8_min,
+    fp8_max,
+    NUM_GROUP: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor for deep_gemm grouped_gemm_masked.
+    This function converts the tensor values into float8 values.
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_tensor_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with per-token-group-quantization
+    for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
         -fp8_max,
         fp8_max,
+        num_tiles_k,
         BLOCK_SIZE,
     )
 
-    return x_q, x_s
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 (8-bit floating point) format.
+
+    Args:
+        input (torch.Tensor): Input tensor to be quantized
+        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+            If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+            determines the quantization granularity:
+            - True: compute scale per token
+            - False: compute single scale per tensor
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - quantized_tensor: The FP8 quantized version of input
+            - scale_tensor: The scaling factors used for quantization
+
+    Raises:
+        AssertionError: If input is not 2D or if static scale's numel != 1
+    """
+    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+    shape = input.shape
+    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        # Dynamic scaling
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+            sgl_per_token_quant_fp8(input, output, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=False
+            )  # False for dynamic
+    else:
+        # Static scaling
+        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+        sgl_per_tensor_quant_fp8(
+            input, output, scale, is_static=True
+        )  # True for static
+
+    return output, scale
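The new `scaled_fp8_quant` wrapper exposes the three modes spelled out in its docstring: dynamic per-tensor, dynamic per-token, and static per-tensor. A hedged usage sketch (CUDA device and the `sgl_kernel` extension assumed; the example tensor sizes are arbitrary):

```python
# Hedged usage sketch of the three scaled_fp8_quant modes described in its docstring.
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

x = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16)

# 1) dynamic per-tensor: one scale for the whole tensor
q1, s1 = scaled_fp8_quant(x)                                   # s1.numel() == 1
# 2) dynamic per-token: one scale per row
q2, s2 = scaled_fp8_quant(x, use_per_token_if_dynamic=True)    # s2.shape == (32, 1)
# 3) static: reuse a pre-computed scalar scale
s_static = torch.tensor([0.02], device=x.device, dtype=torch.float32)
q3, s3 = scaled_fp8_quant(x, scale=s_static)
```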