sglang-0.4.5.post1-py3-none-any.whl → sglang-0.4.5.post3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -16,40 +16,41 @@ import functools
 import json
 import logging
 import os
-from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     get_device_core_count,
     get_device_name,
-    get_device_sm,
     is_cuda,
     is_hip,
     supports_custom_op,
 )
 
-_enable_jit_deepgemm = False
-
 _is_hip = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-
 _is_cuda = is_cuda()
-if _is_cuda:
-    import deep_gemm
-    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
+_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+if _is_hip:
+    fp8_max = 224.0
+else:
+    fp8_max = torch.finfo(_fp8_type).max
+fp8_min = -fp8_max
 
-    sm_version = get_device_sm()
-    if sm_version == 90 and get_bool_env_var(
-        "SGL_ENABLE_JIT_DEEPGEMM", default="false"
-    ):
-        _enable_jit_deepgemm = True
+if _is_cuda:
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )
 
+    from sglang.srt.layers.quantization.deep_gemm import (
+        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
+    )
 
 logger = logging.getLogger(__name__)
 
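Judging from the file list above, the hunks in this excerpt appear to come from sglang/srt/layers/quantization/fp8_kernel.py. This first hunk moves the DeepGEMM toggle into sglang.srt.layers.quantization.deep_gemm (imported as _ENABLE_JIT_DEEPGEMM) and hoists the FP8 range into module-level constants instead of recomputing torch.finfo inside each function. The snippet below is only an illustration of those constants, not code from the package; it assumes a recent PyTorch (2.2+) that ships both float8 dtypes.

```python
# Illustrative only: the module-level FP8 bounds introduced in the hunk above.
import torch

for is_hip in (False, True):
    fp8_type = torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn
    # The HIP path pins the bound to 224.0, as the removed per-function code
    # already did; the CUDA path uses the dtype's finfo max (448.0 for e4m3fn).
    fp8_max = 224.0 if is_hip else torch.finfo(fp8_type).max
    fp8_min = -fp8_max
    print(fp8_type, fp8_min, fp8_max)
```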
@@ -62,10 +63,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-        M, K = A.shape
-        N, _ = B.shape
-        with _log_jit_build(M, N, K):
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -179,7 +177,6 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,7 +189,6 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -202,15 +198,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
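per_token_group_quant_fp8 now always quantizes to the module-level _fp8_type rather than taking a dtype argument, and reuses the shared fp8_min/fp8_max bounds. For readers who want the math without the Triton kernel, here is a hedged pure-PyTorch reference of the per-token-group scheme (per-group absmax, scale = absmax / fp8_max, clamp and cast); the function name is illustrative and this is not the kernel sglang runs.

```python
# Reference (CPU-runnable) sketch of per-token-group FP8 quantization.
import torch


def reference_per_token_group_quant_fp8(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    assert x.shape[-1] % group_size == 0 and x.is_contiguous()
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    # View each token row as (num_groups, group_size) blocks.
    g = x.float().reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    # One float32 scale per (token, group): absmax / fp8_max, floored by eps.
    x_s = g.abs().amax(dim=-1).clamp(min=eps) / fp8_max
    x_q = (g / x_s.unsqueeze(-1)).clamp(-fp8_max, fp8_max).to(fp8).reshape(x.shape)
    return x_q, x_s  # x_s shape: x.shape[:-1] + (x.shape[-1] // group_size,)


x_q, x_s = reference_per_token_group_quant_fp8(torch.randn(4, 256), group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```

The x_s shape in the last comment matches the default (row-major) scale buffer the Triton path allocates; the next hunk adds alternative layouts for it.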
@@ -276,26 +264,36 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
 
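The new column_major_scales and scale_tma_aligned flags on sglang_per_token_group_quant_fp8 change only how the scale tensor x_s is laid out in memory; the quantization itself is untouched. A shape-and-stride-only sketch of the three layouts for a 2-D activation of shape (m, k) follows; the variable names are illustrative, not from the source.

```python
# Shape-only sketch of the three x_s layouts; no quantization happens here.
import torch

m, k, group_size = 10, 512, 128
n_groups = k // group_size

# Default: row-major, one scale per (token, group).
s_default = torch.empty(m, n_groups, dtype=torch.float32)

# column_major_scales=True: allocate (n_groups, m) and hand back a transposed
# view, so storage is column-major with respect to (m, n_groups).
s_colmajor = torch.empty(n_groups, m, dtype=torch.float32).permute(-1, -2)

# scale_tma_aligned=True additionally pads the token dimension up to a multiple
# of 4 floats (4 * sizeof(float) = 16 bytes) before taking the transposed view.
aligned_m = (m + 3) // 4 * 4
s_tma = torch.empty(n_groups, aligned_m, dtype=torch.float32).permute(-1, -2)[:m, :]

print(s_default.shape, s_colmajor.shape, s_tma.shape)           # all (10, 4)
print(s_default.stride(), s_colmajor.stride(), s_tma.stride())  # (4, 1) (1, 10) (1, 12)
```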
@@ -304,7 +302,7 @@ def sglang_per_token_group_quant_fp8(
 
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype = fp8_type_,
+    dtype: torch.dtype = _fp8_type,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
 
@@ -368,7 +366,6 @@ def static_quant_fp8(
     x: torch.Tensor,
     x_s: torch.Tensor,
     repeat_scale: bool = False,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform static quantization using the given scale on an input tensor `x`.
 
@@ -386,15 +383,8 @@ def static_quant_fp8(
     """
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
 
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -714,25 +704,6 @@ def get_w8a8_block_fp8_configs(
     return None
 
 
-@contextmanager
-def _log_jit_build(M: int, N: int, K: int):
-    from deep_gemm.jit.runtime import RuntimeCache
-
-    origin_func = RuntimeCache.__getitem__
-
-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            logger.warning(
-                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
-            )
-        return ret
-
-    RuntimeCache.__getitem__ = __patched_func
-    yield
-    RuntimeCache.__getitem__ = origin_func
-
-
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -803,12 +774,11 @@ def w8a8_block_fp8_matmul(
     )
 
     # deepgemm only support bf16
-    if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
-            with _log_jit_build(M, N, K):
-                deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
@@ -896,22 +866,20 @@ def _per_tensor_quant_mla_fp8_stage2(
 
 
 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
     and specialized for mla absorbed case.
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
-    x_q = x.new_empty(x.size(), dtype=dtype)
-    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
+    x_q = x.new_empty(x.size(), dtype=_fp8_type)
 
     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
@@ -919,7 +887,7 @@ def per_tensor_quant_mla_fp8(
 
     _per_tensor_quant_mla_fp8_stage1[grid](
         x,
-        x_s,
+        x_s_out,
         head_size,
         x.stride(0),
         x.stride(1),
@@ -929,15 +897,172 @@ def per_tensor_quant_mla_fp8(
     )
     _per_tensor_quant_mla_fp8_stage2[grid](
         x,
-        x_s,
+        x_s_out,
         x_q,
         num_seq,
         head_size,
         x.stride(0),
         x.stride(1),
+        fp8_min,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s_out
+
+
+@triton.jit
+def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    masked_m_ptr,
+    group_size,
+    y_stride_b,
+    y_stride_t,
+    y_q_stride_b,
+    y_q_stride_t,
+    y_s_stride_b,
+    y_s_stride_g,
+    eps,
+    fp8_min,
+    fp8_max,
+    NUM_GROUP: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor for deep_gemm grouped_gemm_masked.
+    This function converts the tensor values into float8 values.
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_tensor_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with per-token-group-quantization
+    for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
         -fp8_max,
         fp8_max,
+        num_tiles_k,
         BLOCK_SIZE,
     )
 
-    return x_q, x_s
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 (8-bit floating point) format.
+
+    Args:
+        input (torch.Tensor): Input tensor to be quantized
+        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+            If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+            determines the quantization granularity:
+            - True: compute scale per token
+            - False: compute single scale per tensor
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - quantized_tensor: The FP8 quantized version of input
+            - scale_tensor: The scaling factors used for quantization
+
+    Raises:
+        AssertionError: If input is not 2D or if static scale's numel != 1
+    """
+    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+    shape = input.shape
+    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        # Dynamic scaling
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+            sgl_per_token_quant_fp8(input, output, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=False
+            )  # False for dynamic
+    else:
+        # Static scaling
+        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+        sgl_per_tensor_quant_fp8(
+            input, output, scale, is_static=True
+        )  # True for static
+
+    return output, scale
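The newly added scaled_fp8_quant wraps the sgl_kernel per-token and per-tensor FP8 ops behind a single entry point. As a hedged, CPU-runnable reference for the two dynamic modes (this reproduces only the scaling math, not the sgl_kernel implementation; the helper name and the 1e-12 floor are mine):

```python
# Reference sketch of dynamic FP8 quantization: scale = absmax / fp8_max,
# then divide, clamp, and cast. Assumes PyTorch >= 2.1 for float8_e4m3fn.
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0


def reference_scaled_fp8_quant(x: torch.Tensor, per_token: bool):
    assert x.ndim == 2
    if per_token:
        # One scale per row (token), shape (num_tokens, 1),
        # like use_per_token_if_dynamic=True.
        amax = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
    else:
        # A single tensor-wide scale, shape (1,), like use_per_token_if_dynamic=False.
        amax = x.abs().amax().float().clamp(min=1e-12).reshape(1)
    scale = amax / FP8_MAX
    x_q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_q, scale


x = torch.randn(4, 8)
x_q, s = reference_scaled_fp8_quant(x, per_token=True)
print(x_q.dtype, x_q.shape, s.shape)             # torch.float8_e4m3fn (4, 8) (4, 1)
print((x_q.float() * s - x).abs().max().item())  # small round-off error
```

The (num_tokens, 1) and (1,) scale shapes mirror the buffers scaled_fp8_quant allocates before dispatching to sgl_per_token_quant_fp8 and sgl_per_tensor_quant_fp8, respectively.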