sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py
@@ -16,6 +16,7 @@ import functools
 import json
 import logging
 import os
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -33,20 +34,31 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
-_enable_jit_deepgemm = False
-
 _is_hip = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-
 _is_cuda = is_cuda()
+_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+if _is_hip:
+    fp8_max = 224.0
+else:
+    fp8_max = torch.finfo(_fp8_type).max
+fp8_min = -fp8_max
+
+_enable_jit_deepgemm = False
+_enable_jit_deepgemm_bmm = False
 if _is_cuda:
-    import deep_gemm  # `pip install "sgl-kernel>=0.0.4.post3"`
-    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
+    import deep_gemm
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )
 
     sm_version = get_device_sm()
-    if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
-        _enable_jit_deepgemm = True
-
+    if sm_version == 90:
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+            _enable_jit_deepgemm = True
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
+            _enable_jit_deepgemm_bmm = True
 
 logger = logging.getLogger(__name__)
 
@@ -59,7 +71,10 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        M, K = A.shape
+        N, _ = B.shape
+        with _log_jit_build(M, N, K):
+            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -173,7 +188,6 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -186,7 +200,6 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -196,15 +209,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
@@ -270,26 +275,36 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
 
@@ -298,7 +313,7 @@ def sglang_per_token_group_quant_fp8(
 
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype = fp8_type_,
+    dtype: torch.dtype = _fp8_type,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
 
@@ -362,7 +377,6 @@ def static_quant_fp8(
     x: torch.Tensor,
     x_s: torch.Tensor,
     repeat_scale: bool = False,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform static quantization using the given scale on an input tensor `x`.
 
@@ -380,15 +394,8 @@ def static_quant_fp8(
     """
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
 
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -708,6 +715,25 @@ def get_w8a8_block_fp8_configs(
     return None
 
 
+@contextmanager
+def _log_jit_build(M: int, N: int, K: int):
+    from deep_gemm.jit.runtime import RuntimeCache
+
+    origin_func = RuntimeCache.__getitem__
+
+    def __patched_func(self, *args, **kwargs):
+        ret = origin_func(self, *args, **kwargs)
+        if ret is None:
+            logger.warning(
+                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
+            )
+        return ret
+
+    RuntimeCache.__getitem__ = __patched_func
+    yield
+    RuntimeCache.__getitem__ = origin_func
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -782,7 +808,8 @@ def w8a8_block_fp8_matmul(
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+            with _log_jit_build(M, N, K):
+                deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
@@ -815,3 +842,258 @@ def w8a8_block_fp8_matmul(
     )
 
     return C
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage1(
+    x_ptr,
+    x_s_ptr,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    eps,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
+
+    tl.atomic_max(x_s_ptr, _absmax / fp8_max)
+
+
+@triton.jit
+def _per_tensor_quant_mla_fp8_stage2(
+    x_ptr,
+    x_s_ptr,
+    x_q_ptr,
+    num_seq,
+    head_size,
+    x_stride_h,
+    x_stride_s,
+    fp8_min,
+    fp8_max,
+    BLOCK_SIZE: tl.constexpr,
+):
+    seq_id = tl.program_id(0)
+    head_id = tl.program_id(1)
+    offset = tl.arange(0, BLOCK_SIZE)
+    mask = offset < head_size
+
+    x_s = tl.load(x_s_ptr)
+    x_s_inv = 1.0 / x_s
+
+    x_ptr += head_id * x_stride_h + seq_id * x_stride_s
+    x_q_ptr += head_id * num_seq * head_size + seq_id * head_size
+
+    x = tl.load(x_ptr + offset, mask=mask, other=0.0).to(tl.float32)
+    x_q = tl.clamp(x * x_s_inv, fp8_min, fp8_max).to(x_q_ptr.dtype.element_ty)
+    tl.store(x_q_ptr + offset, x_q, mask=mask)
+
+
+def per_tensor_quant_mla_fp8(
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with tensor-wise quantization
+    and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )
+
+    x_q = x.new_empty(x.size(), dtype=_fp8_type)
+
+    num_head, num_seq, head_size = x.shape
+    BLOCK_SIZE = triton.next_power_of_2(head_size)
+    grid = (num_seq, num_head)
+
+    _per_tensor_quant_mla_fp8_stage1[grid](
+        x,
+        x_s_out,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        eps,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+    _per_tensor_quant_mla_fp8_stage2[grid](
+        x,
+        x_s_out,
+        x_q,
+        num_seq,
+        head_size,
+        x.stride(0),
+        x.stride(1),
+        fp8_min,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s_out
+
+
+@triton.jit
+def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    masked_m_ptr,
+    group_size,
+    y_stride_b,
+    y_stride_t,
+    y_q_stride_b,
+    y_q_stride_t,
+    y_s_stride_b,
+    y_s_stride_g,
+    eps,
+    fp8_min,
+    fp8_max,
+    NUM_GROUP: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor for deep_gemm grouped_gemm_masked.
+    This function converts the tensor values into float8 values.
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_tensor_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with per-token-group-quantization
+    for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
+        -fp8_max,
+        fp8_max,
+        num_tiles_k,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 (8-bit floating point) format.
+
+    Args:
+        input (torch.Tensor): Input tensor to be quantized
+        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+            If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+            determines the quantization granularity:
+            - True: compute scale per token
+            - False: compute single scale per tensor
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - quantized_tensor: The FP8 quantized version of input
+            - scale_tensor: The scaling factors used for quantization
+
+    Raises:
+        AssertionError: If input is not 2D or if static scale's numel != 1
+    """
+    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+    shape = input.shape
+    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        # Dynamic scaling
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+            sgl_per_token_quant_fp8(input, output, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=False
+            )  # False for dynamic
+    else:
+        # Static scaling
+        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+        sgl_per_tensor_quant_fp8(
+            input, output, scale, is_static=True
+        )  # True for static
+
+    return output, scale
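
A note on the reworked sglang_per_token_group_quant_fp8 hunk above: the dtype parameter is dropped in favor of the module-level _fp8_type, and the layout of the scale buffer is now selectable via column_major_scales and scale_tma_aligned. The sketch below only mirrors the allocation logic to show the resulting scale shapes; it never calls the kernel, and the sizes are made up for illustration.

    import torch

    group_size = 128
    x = torch.empty(8, 512)  # (tokens, hidden); hypothetical sizes

    # Default layout: one scale per (token, group), row-major.
    x_s_row = torch.empty(x.shape[:-1] + (x.shape[-1] // group_size,))
    assert x_s_row.shape == (8, 4)

    # column_major_scales=True: allocated as (groups, tokens), returned as a transposed view.
    x_s_col = torch.empty((x.shape[-1] // group_size,) + x.shape[:-1]).permute(-1, -2)
    assert x_s_col.shape == (8, 4) and not x_s_col.is_contiguous()

    # scale_tma_aligned=True additionally pads the token dimension up to a multiple of
    # 4 floats (16 bytes), matching the "aligned to 4 * sizeof(float)" comment in the diff.
    aligned_size = (x.shape[-2] + 3) // 4 * 4
    x_s_tma = torch.empty(
        x.shape[:-2] + (x.shape[-1] // group_size, aligned_size)
    ).permute(-1, -2)[: x.shape[-2], :]
    assert x_s_tma.shape == (8, 4)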
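
The new _log_jit_build context manager temporarily wraps RuntimeCache.__getitem__ so that a cache miss, which is what triggers a DeepGEMM JIT build, logs a warning with the GEMM shape before the potentially long compilation starts. Below is a generic, self-contained sketch of the same patch-and-restore pattern with hypothetical names; unlike the diff, it restores the original method in a finally block so the patch cannot leak if the wrapped call raises.

    from contextlib import contextmanager

    @contextmanager
    def patched_method(cls, name, make_wrapper):
        # Temporarily replace cls.<name> with a wrapped version, then restore it.
        original = getattr(cls, name)
        setattr(cls, name, make_wrapper(original))
        try:
            yield
        finally:
            setattr(cls, name, original)

    class Cache:
        def lookup(self, key):
            return None  # always a miss, for demonstration

    def warn_on_miss(original):
        def wrapper(self, *args, **kwargs):
            result = original(self, *args, **kwargs)
            if result is None:
                print("cache miss: JIT build would start here")
            return result
        return wrapper

    with patched_method(Cache, "lookup", warn_on_miss):
        Cache().lookup("gemm_fp8_fp8_bf16_nt")  # prints the miss message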
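
The _per_tensor_quant_mla_fp8_stage1/_per_tensor_quant_mla_fp8_stage2 pair computes one tensor-wide scale (an atomic max over per-(head, seq) absolute maxima divided by the FP8 maximum) and then scales, clamps, and casts the input. The plain-PyTorch rendition below reproduces that arithmetic for readability only; it ignores the HIP-specific 224.0 cap and the preallocated x_s_out buffer, and the shapes are invented.

    import torch

    def per_tensor_quant_reference(x: torch.Tensor, eps: float = 1e-12):
        fp8_dtype = torch.float8_e4m3fn
        fp8_max = torch.finfo(fp8_dtype).max
        # Stage 1 equivalent: a single scale from the global absolute maximum.
        scale = torch.clamp(x.abs().amax(), min=eps) / fp8_max
        # Stage 2 equivalent: scale, clamp into the representable range, cast to FP8.
        x_q = torch.clamp(x / scale, -fp8_max, fp8_max).to(fp8_dtype)
        return x_q, scale.reshape(1).to(torch.float32)

    x = torch.randn(16, 4, 512)  # (num_head, num_seq, head_size); made-up sizes
    x_q, x_s = per_tensor_quant_reference(x)
    assert x_q.shape == x.shape and x_s.numel() == 1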
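
A usage sketch for the new scaled_fp8_quant entry point added at the end of the diff. It assumes the hunks above belong to sglang/srt/layers/quantization/fp8_kernel.py (entry 66 in the file list, +328 -46) and that a CUDA device and the sgl-kernel extension are installed; the shapes, dtype, and scale value are illustrative.

    import torch
    from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

    x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

    # Dynamic per-tensor scaling: a single scale for the whole tensor.
    x_q, s = scaled_fp8_quant(x)

    # Dynamic per-token scaling: one scale per row of the 2D input.
    x_q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

    # Static scaling with a precomputed scalar scale.
    static_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
    x_q_static, _ = scaled_fp8_quant(x, scale=static_scale)

    # num_token_padding pads the first (token) dimension of the output.
    x_q_pad, _ = scaled_fp8_quant(x, num_token_padding=32)
    assert x_q_pad.shape[0] == 32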