sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -26,12 +26,20 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather,
+    dp_scatter,
+    get_attention_dp_rank,
+    get_attention_dp_size,
+)
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
+from sglang.srt.utils import dump_to_file
 
 logger = logging.getLogger(__name__)
 
@@ -51,13 +59,19 @@ class LogitsProcessorOutput:
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None
+    # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
+    next_token_token_ids_logprobs_val: Optional[List] = None
+    next_token_token_ids_logprobs_idx: Optional[List] = None
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logprobs of input tokens. shape: [#token]
-    input_token_logprobs: torch.Tensor = None
+    input_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
+    # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
+    input_token_ids_logprobs_val: Optional[List] = None
+    input_token_ids_logprobs_idx: Optional[List] = None
 
 
 @dataclasses.dataclass
@@ -67,43 +81,114 @@ class LogitsMetadata:
 
     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False
+    extend_token_ids_logprob: bool = False
     extend_seq_lens: Optional[torch.Tensor] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
     top_logprobs_nums: Optional[List[int]] = None
+    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None
+
+    # logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    temperature: torch.Tensor = None
+    top_p_normalized_logprobs: bool = False
+    top_p: torch.Tensor = None
+
+    # DP attention metadata. Not needed when DP attention is not used.
+    # Number of tokens in the request.
+    global_num_tokens_gpu: Optional[torch.Tensor] = None
+    # The start position of local hidden states.
+    dp_local_start_pos: Optional[torch.Tensor] = None
+    dp_local_num_tokens: Optional[torch.Tensor] = None
+    gathered_buffer: Optional[torch.Tensor] = None
+    # Buffer to gather logits from all ranks.
+    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+    # Number of tokens to sample per DP rank
+    global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+
+    # for padding
+    padded_static_len: int = -1
 
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
-            extend_return_logprob = True
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.return_logprob
+            and not forward_batch.forward_mode.is_target_verify()
+        ):
             extend_return_top_logprob = any(
                 x > 0 for x in forward_batch.top_logprobs_nums
             )
-            extend_logprob_pruned_lens_cpu = [
-                extend_len - start_len
-                for extend_len, start_len in zip(
-                    forward_batch.extend_seq_lens_cpu,
-                    forward_batch.extend_logprob_start_lens_cpu,
-                )
-            ]
+            extend_token_ids_logprob = any(
+                x is not None for x in forward_batch.token_ids_logprobs
+            )
+            extend_return_logprob = False
+            extend_logprob_pruned_lens_cpu = []
+            for extend_len, start_len in zip(
+                forward_batch.extend_seq_lens_cpu,
+                forward_batch.extend_logprob_start_lens_cpu,
+            ):
+                if extend_len - start_len > 0:
+                    extend_return_logprob = True
+                extend_logprob_pruned_lens_cpu.append(extend_len - start_len)
         else:
             extend_return_logprob = extend_return_top_logprob = (
-                extend_logprob_pruned_lens_cpu
-            ) = False
+                extend_token_ids_logprob
+            ) = extend_logprob_pruned_lens_cpu = False
 
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
+            extend_token_ids_logprob=extend_token_ids_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
+            token_ids_logprobs=forward_batch.token_ids_logprobs,
+            extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
+            padded_static_len=forward_batch.padded_static_len,
+            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
+            dp_local_start_pos=forward_batch.dp_local_start_pos,
+            dp_local_num_tokens=forward_batch.dp_local_num_tokens,
+            gathered_buffer=forward_batch.gathered_buffer,
+            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
+            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+        )
+
+    def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
+        if self.global_num_tokens_for_logprob_cpu is None:
+            # we are capturing cuda graph
+            return
+
+        cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
+        dp_rank = get_attention_dp_rank()
+        if dp_rank == 0:
+            dp_local_start_pos = torch.zeros_like(
+                self.global_num_tokens_for_logprob_gpu[0]
+            )
+        else:
+            dp_local_start_pos = cumtokens[dp_rank - 1]
+        dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
+        gathered_buffer = torch.zeros(
+            (
+                sum(self.global_num_tokens_for_logprob_cpu),
+                hidden_states.shape[1],
+            ),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
         )
 
+        self.dp_local_start_pos = dp_local_start_pos
+        self.dp_local_num_tokens = dp_local_num_tokens
+        self.gathered_buffer = gathered_buffer
+
 
 class LogitsProcessor(nn.Module):
     def __init__(
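Note (illustrative sketch, not taken from this diff): compute_dp_attention_metadata turns the per-DP-rank token counts into an offset/length pair via a cumulative sum. The rank value and token counts below are made up purely to show the bookkeeping:

import torch

# Hypothetical per-DP-rank token counts: 3 ranks holding 3, 5, and 2 tokens.
global_num_tokens = torch.tensor([3, 5, 2])
dp_rank = 1  # assumed rank for this sketch

cumtokens = torch.cumsum(global_num_tokens, dim=0)          # tensor([3, 8, 10])
start = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])  # rank 1 starts at offset 3
length = int(global_num_tokens[dp_rank])                    # rank 1 owns 5 tokens

# The gathered buffer holds every rank's hidden states, so rank 1's slice is [3:8].
assert (start, length) == (3, 5)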
@@ -115,6 +200,9 @@ class LogitsProcessor(nn.Module):
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
+        self.do_tensor_parallel_all_gather_dp_attn = (
+            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+        )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
@@ -124,6 +212,10 @@
         ):
             self.final_logit_softcapping = None
 
+        self.debug_tensor_dump_output_folder = global_server_args_dict.get(
+            "debug_tensor_dump_output_folder", None
+        )
+
     def forward(
         self,
         input_ids,
@@ -141,30 +233,74 @@
         ):
             pruned_states = hidden_states
             sample_indices = None
+            input_logprob_indices = None
         elif (
             logits_metadata.forward_mode.is_extend()
             and not logits_metadata.extend_return_logprob
         ):
             # Prefill without input logprobs.
-            last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
+            if logits_metadata.padded_static_len < 0:
+                last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
+            else:
+                # If padding_static length is 5 and extended_seq_lens is [2, 3],
+                # then our batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p]
+                # and this retrieves t01 and t12, which are the valid last tokens
+                idx = torch.arange(
+                    len(logits_metadata.extend_seq_lens),
+                    device=logits_metadata.extend_seq_lens.device,
+                )
+                last_index = (
+                    idx * logits_metadata.padded_static_len
+                    + logits_metadata.extend_seq_lens
+                    - 1
+                )
             pruned_states = hidden_states[last_index]
             sample_indices = None
+            input_logprob_indices = None
         else:
-            # Slice the requested tokens to compute logprob
+            # Input logprobs are required.
+            # Find 3 different indices.
+            # 1. pruned_states: hidden states that we want logprobs from.
+            # 2. sample_indices: Indices that have sampled tokens.
+            # 3. input_logprob_indices: Indices that have input logprob tokens.
             sample_index_pt = -1
             sample_indices = []
-            pt, pruned_states, pruned_input_ids = 0, [], []
-            for start_len, extend_len in zip(
+            input_logprob_indices_pt = 0
+            input_logprob_indices = []
+            pt, pruned_states = 0, []
+            for extend_logprob_start_len, extend_len in zip(
                 logits_metadata.extend_logprob_start_lens_cpu,
                 logits_metadata.extend_seq_lens_cpu,
             ):
+                # It can happen in chunked prefill. We still need to sample 1 token,
+                # But we don't want to include it in input logprob.
+                if extend_len == extend_logprob_start_len:
+                    start_len = extend_logprob_start_len - 1
+                else:
+                    start_len = extend_logprob_start_len
+
+                # We always need at least 1 token to sample because that's required
+                # by a caller.
+                assert extend_len > start_len
                 pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                pt += extend_len
                 sample_index_pt += extend_len - start_len
                 sample_indices.append(sample_index_pt)
-                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                pt += extend_len
+                input_logprob_indices.extend(
+                    [
+                        input_logprob_indices_pt + i
+                        for i in range(extend_len - extend_logprob_start_len)
+                    ]
+                )
+                input_logprob_indices_pt += extend_len - start_len
 
             pruned_states = torch.cat(pruned_states)
+            sample_indices = torch.tensor(
+                sample_indices, device=pruned_states.device, dtype=torch.int64
+            )
+            input_logprob_indices = torch.tensor(
+                input_logprob_indices, device=pruned_states.device, dtype=torch.int64
+            )
 
         # Compute logits for both input and sampled tokens.
         logits = self._get_logits(pruned_states, lm_head, logits_metadata)
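Note (illustrative sketch, not taken from this diff): the padded_static_len branch above picks the last real token of each padded segment. With a padded length of 5 and extend_seq_lens [2, 3], the flattened batch is [t00, t01, p, p, p, t10, t11, t12, p, p] and the computed indices are 1 and 7:

import torch

padded_static_len = 5
extend_seq_lens = torch.tensor([2, 3])

idx = torch.arange(len(extend_seq_lens))
last_index = idx * padded_static_len + extend_seq_lens - 1
print(last_index)  # tensor([1, 7]) -> positions of t01 and t12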
@@ -172,28 +308,51 @@
             logits[sample_indices] if sample_indices is not None else logits
         )
 
-        if (
-            not logits_metadata.extend_return_logprob
-            or logits_metadata.capture_hidden_mode.need_capture()
-        ):
+        if self.debug_tensor_dump_output_folder:
+            assert (
+                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+            ), "dp attention + sharded lm_head doesn't support full logits"
+            full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
+            dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
+
+        hidden_states_to_store: Optional[torch.Tensor] = None
+        if logits_metadata.capture_hidden_mode.need_capture():
+            if logits_metadata.capture_hidden_mode.is_full():
+                hidden_states_to_store = hidden_states
+            elif logits_metadata.capture_hidden_mode.is_last():
+                # Get the last token hidden states. If sample_indices is None,
+                # pruned states only contain the last tokens already.
+                hidden_states_to_store = (
+                    pruned_states[sample_indices] if sample_indices else pruned_states
+                )
+            else:
+                assert False, "Should never reach"
+
+        if not logits_metadata.extend_return_logprob:
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=sampled_logits,
-                hidden_states=(
-                    hidden_states
-                    if logits_metadata.capture_hidden_mode.is_full()
-                    else (
-                        pruned_states
-                        if logits_metadata.capture_hidden_mode.is_last()
-                        else None
-                    )
-                ),
+                hidden_states=hidden_states_to_store,
             )
         else:
-            input_logprobs = logits
+            input_logprobs = logits[input_logprob_indices]
             del hidden_states, logits
 
             # Normalize the logprob w/o temperature, top-p
+            pruned_lens = torch.tensor(
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+                device=input_logprobs.device,
+            )
+            if logits_metadata.temp_scaled_logprobs:
+                logits_metadata.temperature = torch.repeat_interleave(
+                    logits_metadata.temperature.view(-1),
+                    pruned_lens,
+                ).view(-1, 1)
+            if logits_metadata.top_p_normalized_logprobs:
+                logits_metadata.top_p = torch.repeat_interleave(
+                    logits_metadata.top_p,
+                    pruned_lens,
+                )
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                 input_logprobs, logits_metadata
             )
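Note (illustrative sketch, not taken from this diff): repeat_interleave expands the per-request temperature and top_p to one entry per pruned input position, so the later element-wise scaling lines up with the rows of input_logprobs. With made-up values:

import torch

temperature = torch.tensor([0.7, 1.0])  # one value per request
pruned_lens = torch.tensor([3, 2])      # pruned input positions per request

per_token_temp = torch.repeat_interleave(temperature, pruned_lens).view(-1, 1)
# tensor([[0.7], [0.7], [0.7], [1.0], [1.0]]) -> matches 5 rows of input logprobs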
@@ -207,14 +366,18 @@
             else:
                 input_top_logprobs_val = input_top_logprobs_idx = None
 
+            # Get the logprob of given token id
+            if logits_metadata.extend_token_ids_logprob:
+                (
+                    input_token_ids_logprobs_val,
+                    input_token_ids_logprobs_idx,
+                ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
+            else:
+                input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
+
             input_token_logprobs = input_logprobs[
                 torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
-                torch.cat(
-                    [
-                        torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device=input_logprobs.device),
-                    ]
-                ),
+                logits_metadata.extend_input_logprob_token_ids_gpu,
             ]
 
             return LogitsProcessorOutput(
@@ -222,6 +385,9 @@
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
+                hidden_states=hidden_states_to_store,
+                input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+                input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
             )
 
     def _get_logits(
  def _get_logits(
@@ -231,13 +397,27 @@ class LogitsProcessor(nn.Module):
231
397
  logits_metadata: LogitsMetadata,
232
398
  embedding_bias: Optional[torch.Tensor] = None,
233
399
  ) -> torch.Tensor:
234
- """Get logits from hidden_states."""
400
+ """Get logits from hidden_states.
401
+
402
+ If sampled_logits_only is True, it means hidden_states only contain the
403
+ last position (e.g., extend without input logprobs). The caller should
404
+ guarantee the given hidden_states follow this constraint.
405
+ """
406
+ if self.do_tensor_parallel_all_gather_dp_attn:
407
+ logits_metadata.compute_dp_attention_metadata(hidden_states)
408
+ hidden_states, local_hidden_states = (
409
+ logits_metadata.gathered_buffer,
410
+ hidden_states.clone(),
411
+ )
412
+ dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
235
413
 
236
414
  if hasattr(lm_head, "weight"):
237
- logits = torch.matmul(hidden_states, lm_head.weight.T)
415
+ logits = torch.matmul(
416
+ hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
417
+ )
238
418
  else:
239
419
  # GGUF models
240
- logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
420
+ logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
241
421
 
242
422
  if self.logit_scale is not None:
243
423
  logits.mul_(self.logit_scale)
@@ -245,6 +425,17 @@ class LogitsProcessor(nn.Module):
245
425
  if self.do_tensor_parallel_all_gather:
246
426
  logits = tensor_model_parallel_all_gather(logits)
247
427
 
428
+ if self.do_tensor_parallel_all_gather_dp_attn:
429
+ logits, global_logits = (
430
+ torch.empty(
431
+ (local_hidden_states.shape[0], logits.shape[1]),
432
+ device=logits.device,
433
+ dtype=logits.dtype,
434
+ ),
435
+ logits,
436
+ )
437
+ dp_scatter(logits, global_logits, logits_metadata)
438
+
248
439
  logits = logits[:, : self.config.vocab_size].float()
249
440
 
250
441
  if self.final_logit_softcapping:
@@ -272,21 +463,66 @@
                 continue
 
             input_top_logprobs_val.append(
-                [values[pt + j][:k] for j in range(pruned_len - 1)]
+                [values[pt + j][:k] for j in range(pruned_len)]
             )
             input_top_logprobs_idx.append(
-                [indices[pt + j][:k] for j in range(pruned_len - 1)]
+                [indices[pt + j][:k] for j in range(pruned_len)]
             )
             pt += pruned_len
 
         return input_top_logprobs_val, input_top_logprobs_idx
 
+    @staticmethod
+    def get_token_ids_logprobs(
+        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+    ):
+        input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
+        pt = 0
+        for token_ids, pruned_len in zip(
+            logits_metadata.token_ids_logprobs,
+            logits_metadata.extend_logprob_pruned_lens_cpu,
+        ):
+            if pruned_len <= 0:
+                input_token_ids_logprobs_val.append([])
+                input_token_ids_logprobs_idx.append([])
+                continue
+
+            input_token_ids_logprobs_val.append(
+                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
+            )
+            input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
+            pt += pruned_len
+
+        return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
+
     @staticmethod
     def compute_temp_top_p_normalized_logprobs(
         last_logits: torch.Tensor, logits_metadata: LogitsMetadata
     ) -> torch.Tensor:
-        # TODO: Implement the temp and top-p normalization
-        return torch.nn.functional.log_softmax(last_logits, dim=-1)
+        """
+        compute logprobs for the output token from the given logits.
+
+        Returns:
+            torch.Tensor: logprobs from logits
+        """
+        # Scale logits if temperature scaling is enabled
+        if logits_metadata.temp_scaled_logprobs:
+            last_logits = last_logits / logits_metadata.temperature
+
+        # Normalize logprobs if top_p normalization is enabled
+        # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
+        if (
+            logits_metadata.top_p_normalized_logprobs
+            and (logits_metadata.top_p != 1.0).any()
+        ):
+            from sglang.srt.layers.sampler import top_p_normalize_probs_torch
+
+            probs = torch.softmax(last_logits, dim=-1)
+            del last_logits
+            probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
+            return torch.log(probs)
+        else:
+            return torch.nn.functional.log_softmax(last_logits, dim=-1)
 
 
 @triton.jit
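Note (illustrative sketch, not taken from this diff): top_p_normalize_probs_torch is imported from sglang.srt.layers.sampler and is not shown here. A rough, hypothetical version of the idea it names, keeping only the smallest nucleus of tokens whose cumulative probability reaches top_p and renormalizing within it, could look like this:

import torch

def top_p_normalize_probs_sketch(probs: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    # probs: [n, vocab], top_p: [n]; illustrative only, not the sglang implementation.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumprobs = sorted_probs.cumsum(dim=-1)
    # Drop tokens that lie entirely outside the top-p nucleus.
    mask = (cumprobs - sorted_probs) > top_p.unsqueeze(-1)
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    return torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)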
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -1,10 +1,17 @@
 import logging
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+if _is_cuda:
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )
 logger = logging.getLogger(__name__)
 
 
@@ -137,6 +144,73 @@ def silu_and_mul_triton_kernel(
     tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
 
 
+@triton.jit
+def tanh(x):
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
+@triton.jit
+def gelu_and_mul_triton_kernel(
+    gateup_output,
+    down_input,
+    hidden_size,
+    reorder_topk_ids,
+    scales,
+    start_expert_id,
+    end_expert_id,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = gateup_output.dtype.element_ty
+    OutDtype = down_input.dtype.element_ty
+
+    half_hidden_size = hidden_size // 2
+
+    pid = tl.program_id(0)
+    expert_id = tl.load(reorder_topk_ids + pid)
+    if expert_id >= start_expert_id and expert_id <= end_expert_id:
+        gateup_output_ptr = gateup_output + pid * hidden_size
+        gate_output_ptr = gateup_output_ptr
+        up_output_ptr = gateup_output_ptr + half_hidden_size
+        down_input_ptr = down_input + pid * half_hidden_size
+
+        if scales is not None:
+            scale = tl.load(scales + expert_id - start_expert_id)
+            scale = (1 / scale).to(InDtype)
+        else:
+            scale = 1
+
+        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
+            offset = start_offset + tl.arange(0, BLOCK_SIZE)
+            mask = offset < half_hidden_size
+
+            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
+            up_output = tl.load(up_output_ptr + offset, mask=mask)
+
+            # gelu & mul & quantize
+            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
+            # sqrt(2/pi)
+            kAlpha = 0.7978845608028654
+            gate_output = (
+                0.5
+                * gate_output
+                * (
+                    1
+                    + tanh(
+                        kAlpha
+                        * (
+                            gate_output
+                            + 0.044715 * gate_output * gate_output * gate_output
+                        )
+                    )
+                )
+            )
+            gate_output = gate_output.to(InDtype)
+
+            gelu_mul_output = gate_output * up_output * scale
+            gelu_mul_output = gelu_mul_output.to(OutDtype)
+            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
+
+
 @triton.jit
 def post_reorder_triton_kernel(
     down_output_ptr,
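Note (illustrative sketch, not taken from this diff): the new kernel uses the tanh approximation of GELU, with tanh expressed through tl.sigmoid via the identity tanh(x) = 2*sigmoid(2x) - 1. An equivalent reference in plain PyTorch for checking the math (function name is ours):

import math
import torch

def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # Same formula as the kernel: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    k_alpha = math.sqrt(2.0 / math.pi)  # 0.7978845608028654
    return 0.5 * x * (1.0 + torch.tanh(k_alpha * (x + 0.044715 * x * x * x)))

x = torch.randn(8, dtype=torch.float32)
# torch's built-in tanh-approximated GELU should agree closely.
assert torch.allclose(
    gelu_tanh_reference(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-5
)
# The sigmoid identity used by the kernel: tanh(x) == 2*sigmoid(2x) - 1
assert torch.allclose(torch.tanh(x), 2 * torch.sigmoid(2 * x) - 1, atol=1e-5)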
@@ -218,12 +292,19 @@ def grouped_gemm_triton_kernel(
     seg_indptr,
     weight_indices,
     m_num_tiles_indptr,
-    use_fp8_w8a8,
     scale_a,
     scale_b,
+    use_fp8_w8a8: tl.constexpr,
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
     a_stride_0: tl.constexpr,
     b_stride_0: tl.constexpr,
     b_stride_1: tl.constexpr,
+    as_stride_0: tl.constexpr,
+    as_stride_1: tl.constexpr,
+    bs_stride_0: tl.constexpr,
+    bs_stride_2: tl.constexpr,
+    bs_stride_1: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -260,6 +341,12 @@
         + (n_range_start + offs_bn[:, None]) * b_stride_1
         + offs_k[None, :]
     )
+
+    if group_k > 0 and group_n > 0:
+        a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
+        offs_bsn = (n_range_start + offs_bn) // group_n
+        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
+
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         a_tile = tl.load(
@@ -268,14 +355,23 @@
         b_tile = tl.load(
             b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
         )
-        accumulator = tl.dot(a_tile, b_tile.T, accumulator)
+
+        if group_k > 0 and group_n > 0:
+            k_start = k * BLOCK_SIZE_K
+            offs_ks = k_start // group_k
+            a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
+            b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
+            accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
+        else:
+            accumulator = tl.dot(a_tile, b_tile.T, accumulator)
         a_ptr += BLOCK_SIZE_K
         b_ptr += BLOCK_SIZE_K
 
-    if use_fp8_w8a8:
+    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
         scale_a_value = tl.load(scale_a + expert_id)
         scale_b_value = tl.load(scale_b + expert_id)
         accumulator *= scale_a_value * scale_b_value
+
     c_tile = accumulator.to(c_dtype)
 
     offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
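Note (illustrative sketch, not taken from this diff): with block_shape=[128, 128], scale_b carries one scale per 128x128 weight block and scale_a one scale per token per 128-wide K group; each K step of the kernel multiplies its partial product by the matching pair of scales before accumulating. A small fp32 reference of that dequantization, ignoring the expert dimension and using assumed shapes and a name of our own:

import torch

def blockwise_scaled_matmul_reference(a_q, a_scale, b_q, b_scale, group_k=128, group_n=128):
    # a_q: [M, K] quantized activations, a_scale: [M, K // group_k] per-token-group scales
    # b_q: [N, K] quantized weights,     b_scale: [N // group_n, K // group_k] per-block scales
    M, K = a_q.shape
    N = b_q.shape[0]
    out = torch.zeros(M, N, dtype=torch.float32)
    for ks in range(0, K, group_k):
        partial = a_q[:, ks : ks + group_k].float() @ b_q[:, ks : ks + group_k].float().T
        sa = a_scale[:, ks // group_k].unsqueeze(1)                           # one scale per row
        sb = b_scale[torch.arange(N) // group_n, ks // group_k].unsqueeze(0)  # one per column block
        out += partial * sa * sb
    return out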
@@ -307,14 +403,29 @@
     use_fp8_w8a8: bool = False,
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
+    block_shape: Optional[List[int]] = None,
 ):
     assert weight_column_major == True  # TODO: more
-    if use_fp8_w8a8:
+    if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
 
+    if block_shape is not None:
+        assert len(block_shape) == 2
+        block_n, block_k = block_shape[0], block_shape[1]
+        if _is_cuda:
+            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
+        else:
+            a, scale_a = per_token_group_quant_fp8(a, block_k)
+
+        assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
+        assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
+        assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
+
+    # TODO: adjust config or tune kernel
+    # Reduce block size to prevent L40 shared memory overflow.
     config = {
-        "BLOCK_SIZE_M": 128,
-        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 128,
     }
 
@@ -338,12 +449,19 @@
         seg_indptr,
         weight_indices,
         m_num_tiles_indptr,
-        use_fp8_w8a8,
         scale_a,
         scale_b,
+        use_fp8_w8a8,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
         a.stride(0),
         b.stride(0),
         b.stride(1),
+        scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
+        scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
+        scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
+        scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
+        scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
         **config,
     )
     return c