sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/wave_ops/decode_attention.py
@@ -0,0 +1,186 @@
+"""
+Memory-efficient attention for decoding.
+It supports page size = 1.
+"""
+
+import functools
+import logging
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType
+from wave_lang.kernel.wave.templates.paged_decode_attention import (
+    get_paged_decode_attention_kernels,
+    get_paged_decode_intermediate_arrays_shapes,
+    paged_decode_attention_shape,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+logger = logging.getLogger(__name__)
+import os
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache(maxsize=4096)
+def get_wave_kernel(
+    shape: paged_decode_attention_shape,
+    max_kv_splits,
+    input_dtype,
+    output_dtype,
+    logit_cap,
+):
+    mha = (shape.num_query_heads // shape.num_kv_heads) == 1
+
+    # Get the kernels (either compile or load from cache).
+    if mha:
+        mfma_variant = (
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1),
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64),
+        )
+    else:
+        mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+
+    (
+        phase_0,
+        phase_1,
+        hyperparams_0,
+        hyperparams_1,
+        dynamic_symbols_0,
+        dynamic_symbols_1,
+    ) = get_paged_decode_attention_kernels(
+        shape,
+        mfma_variant,
+        max_kv_splits,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        logit_cap=logit_cap,
+    )
+    hyperparams_0.update(get_default_scheduling_params())
+    hyperparams_1.update(get_default_scheduling_params())
+
+    options = WaveCompileOptions(
+        subs=hyperparams_0,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        dynamic_symbols=dynamic_symbols_0,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_0 = wave_compile(options, phase_0)
+
+    options = WaveCompileOptions(
+        subs=hyperparams_1,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=False,
+        use_buffer_store_ops=False,
+        waves_per_eu=4,
+        dynamic_symbols=dynamic_symbols_1,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_1 = wave_compile(options, phase_1)
+
+    return phase_0, phase_1
+
+
+def decode_attention_intermediate_arrays_shapes(
+    num_seqs, head_size_kv, num_query_heads, max_kv_splits
+):
+    # Not all fields are used, but we need to pass them to the function
+    shape = paged_decode_attention_shape(
+        num_query_heads=num_query_heads,
+        num_kv_heads=0,
+        head_size=0,
+        head_size_kv=head_size_kv,
+        block_size=0,
+        num_seqs=num_seqs,
+    )
+    return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits)


+def decode_attention_wave(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap,
+):
+    num_seqs, num_query_heads, head_size = q.shape
+    _, num_kv_heads, _ = k_buffer.shape
+    _, _, head_size_kv = v_buffer.shape
+    block_size = 32
+    shape = paged_decode_attention_shape(
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        head_size_kv,
+        block_size,
+        num_seqs,
+    )
+
+    phase_0, phase_1 = get_wave_kernel(
+        shape, max_kv_splits, q.dtype, o.dtype, logit_cap
+    )
+
+    mb_qk = phase_0(
+        q,
+        k_buffer,
+        v_buffer,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+    )
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_qk.module_op.get_asm())
+
+    mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o)
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_sv.module_op.get_asm())
+
+
+def decode_attention_fwd(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap=0.0,
+):
+    decode_attention_wave(
+        q,
+        k_buffer,
+        v_buffer,
+        o,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+        num_kv_splits,
+        max_kv_splits,
+        sm_scale,
+        logit_cap,
+    )
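
The decode kernel added above is a two-phase split-KV design: phase_0 writes per-split partial results into attn_logits/attn_logits_max, and phase_1 reduces them into o. The following is a minimal calling sketch, assuming a Wave/ROCm-capable PyTorch build; the concrete sizes, dtypes, the req_to_token layout, and the two-tuple return of decode_attention_intermediate_arrays_shapes are illustrative assumptions, not values taken from this diff.

import torch

from sglang.srt.layers.attention.wave_ops.decode_attention import (
    decode_attention_fwd,
    decode_attention_intermediate_arrays_shapes,
)

# Illustrative sizes only (assumptions, not from the diff).
num_seqs, num_heads, num_kv_heads, head_dim = 4, 32, 8, 128
tokens_per_seq, max_kv_splits = 256, 8
total_tokens = num_seqs * tokens_per_seq
dev, dtype = "cuda", torch.float16  # ROCm builds of torch also expose the "cuda" device

q = torch.randn(num_seqs, num_heads, head_dim, dtype=dtype, device=dev)
k_buffer = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=dev)
v_buffer = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=dev)
o = torch.empty(num_seqs, num_heads, head_dim, dtype=dtype, device=dev)

# One request per sequence; req_to_token maps each request to its KV slots (assumed layout).
b_req_idx = torch.arange(num_seqs, dtype=torch.int32, device=dev)
req_to_token = torch.arange(total_tokens, dtype=torch.int32, device=dev).view(num_seqs, -1)

# Assumed: the helper returns the shapes of the two split-KV intermediates.
logits_shape, logits_max_shape = decode_attention_intermediate_arrays_shapes(
    num_seqs, head_dim, num_heads, max_kv_splits
)
attn_logits = torch.empty(logits_shape, dtype=torch.float32, device=dev)
attn_logits_max = torch.empty(logits_max_shape, dtype=torch.float32, device=dev)

decode_attention_fwd(
    q, k_buffer, v_buffer, o,
    b_req_idx, req_to_token,
    attn_logits, attn_logits_max,
    num_kv_splits=max_kv_splits,
    max_kv_splits=max_kv_splits,
    sm_scale=head_dim**-0.5,
    logit_cap=0.0,
)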
sglang/srt/layers/attention/wave_ops/extend_attention.py
@@ -0,0 +1,149 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import functools
+import os
+
+import torch
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache
+def get_wave_kernel(
+    shape: AttentionShape,
+    q_shape: tuple[int],
+    k_shape: tuple[int],
+    v_shape: tuple[int],
+    k_cache_shape: tuple[int],
+    v_cache_shape: tuple[int],
+    o_shape: tuple[int],
+    input_dtype: torch.dtype,
+    output_dtype: torch.dtype,
+    size_dtype: torch.dtype,
+    is_causal: bool,
+    logit_cap: float,
+    layer_scaling: float,
+):
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16)
+    (
+        extend_attention,
+        hyperparams,
+        dynamic_symbols,
+    ) = get_extend_attention_kernel(
+        shape,
+        mfma_variant,
+        q_shape,
+        k_shape,
+        v_shape,
+        k_cache_shape,
+        v_cache_shape,
+        o_shape,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        size_dtype=size_dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        schedule=SchedulingType.NONE,
+        use_scheduling_barriers=False,
+        dynamic_symbols=dynamic_symbols,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        denorm_fp_math_f32="preserve-sign",
+        gpu_native_math_precision=True,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    extend_attention = wave_compile(options, extend_attention)
+
+    return extend_attention
+
+
+def extend_attention_wave(
+    q_extend,
+    k_extend,
+    v_extend,
+    k_buffer,
+    v_buffer,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    custom_mask,
+    mask_indptr,
+    max_seq_len,
+    output,
+    is_causal=True,
+    layer_scaling=None,
+    logit_cap=0,
+):
+    shape = AttentionShape(
+        num_query_heads=q_extend.shape[1],
+        num_kv_heads=k_extend.shape[1],
+        head_size=q_extend.shape[2],
+        head_size_kv=k_extend.shape[2],
+        num_seqs=kv_indptr.shape[0] - 1,
+        max_seq_len=max_seq_len,
+    )
+
+    # Run the wave kernel.
+    extend_attention = get_wave_kernel(
+        shape,
+        q_extend.shape,
+        k_extend.shape,
+        v_extend.shape,
+        k_buffer.shape,
+        v_buffer.shape,
+        output.shape,
+        input_dtype=q_extend.dtype,
+        output_dtype=output.dtype,
+        size_dtype=qo_indptr.dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    mb = extend_attention(
+        q_extend,
+        k_extend,
+        v_extend,
+        k_buffer,
+        v_buffer,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        max_seq_len,
+        output,
+    )
+
+    if dump_generated_mlir:
+        shape_list = [
+            q_extend.shape[0],
+            q_extend.shape[1],
+            k_extend.shape[1],
+            q_extend.shape[2],
+            k_extend.shape[2],
+        ]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
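
A minimal calling sketch for extend_attention_wave, the prefill-with-KV-cache entry point above. The CSR-style qo_indptr/kv_indptr construction (length num_seqs + 1) matches num_seqs=kv_indptr.shape[0] - 1 in the code; everything else (sizes, dtypes, passing None for the optional mask arguments) is an illustrative assumption.

import torch

from sglang.srt.layers.attention.wave_ops.extend_attention import extend_attention_wave

# Illustrative: 2 sequences, each with 16 cached (prefix) tokens and 8 new (extend) tokens.
num_seqs, prefix_len, extend_len = 2, 16, 8
num_heads, num_kv_heads, head_dim = 16, 4, 128
dev, dtype = "cuda", torch.float16

total_extend = num_seqs * extend_len
total_cached = num_seqs * prefix_len

q_extend = torch.randn(total_extend, num_heads, head_dim, dtype=dtype, device=dev)
k_extend = torch.randn(total_extend, num_kv_heads, head_dim, dtype=dtype, device=dev)
v_extend = torch.randn(total_extend, num_kv_heads, head_dim, dtype=dtype, device=dev)
k_buffer = torch.randn(total_cached, num_kv_heads, head_dim, dtype=dtype, device=dev)
v_buffer = torch.randn(total_cached, num_kv_heads, head_dim, dtype=dtype, device=dev)
output = torch.empty(total_extend, num_heads, head_dim, dtype=dtype, device=dev)

# CSR-style offsets into the ragged extend/cache token dimensions (assumed convention).
qo_indptr = torch.arange(0, (num_seqs + 1) * extend_len, extend_len, dtype=torch.int32, device=dev)
kv_indptr = torch.arange(0, (num_seqs + 1) * prefix_len, prefix_len, dtype=torch.int32, device=dev)
kv_indices = torch.arange(total_cached, dtype=torch.int32, device=dev)

extend_attention_wave(
    q_extend, k_extend, v_extend, k_buffer, v_buffer,
    qo_indptr, kv_indptr, kv_indices,
    custom_mask=None, mask_indptr=None,
    max_seq_len=prefix_len + extend_len,
    output=output, is_causal=True,
)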
sglang/srt/layers/attention/wave_ops/prefill_attention.py
@@ -0,0 +1,79 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import math
+import os
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.prefill_attention import (
+    get_prefill_attention_kernel,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+def prefill_attention_wave(
+    q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True
+):
+
+    shape = AttentionShape(
+        num_query_heads=q.shape[1],
+        num_kv_heads=k.shape[1],
+        head_size=q.shape[2],
+        head_size_kv=k.shape[2],
+        num_seqs=b_seq_len.shape[0],
+        max_seq_len=max_seq_len,
+        total_seq_len=q.shape[0],
+    )
+
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv)
+    # Run the wave kernel.
+    mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+    (prefill, hyperparams) = get_prefill_attention_kernel(
+        shape,
+        mfma_variant,
+        q.shape,
+        k.shape,
+        v.shape,
+        output_shape,
+        input_dtype=q.dtype,
+        output_dtype=o.dtype,
+        size_dtype=b_seq_len.dtype,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+
+    log2e = 1.44269504089
+    dk_sqrt = math.sqrt(1.0 / shape.head_size)
+
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        use_scheduling_barriers=False,
+    )
+    options = set_default_run_config(options)
+    prefill = wave_compile(options, prefill)
+
+    mb = prefill(
+        q * dk_sqrt * log2e,
+        k,
+        v,
+        b_start_loc,
+        b_seq_len,
+        o,
+    )
+    if dump_generated_mlir:
+        shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
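
A minimal calling sketch for prefill_attention_wave, which takes packed (varlen) q/k/v plus per-sequence start offsets and lengths; the sizes and dtypes below are illustrative assumptions.

import torch

from sglang.srt.layers.attention.wave_ops.prefill_attention import prefill_attention_wave

# Illustrative: 3 packed sequences of 128 tokens each.
num_seqs, seq_len = 3, 128
num_heads, num_kv_heads, head_dim = 16, 4, 128
total_tokens = num_seqs * seq_len
dev, dtype = "cuda", torch.float16

q = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device=dev)
k = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=dev)
v = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=dev)
o = torch.empty(total_tokens, num_heads, head_dim, dtype=dtype, device=dev)

b_start_loc = torch.arange(0, total_tokens, seq_len, dtype=torch.int32, device=dev)
b_seq_len = torch.full((num_seqs,), seq_len, dtype=torch.int32, device=dev)

prefill_attention_wave(q, k, v, o, b_start_loc, b_seq_len, max_seq_len=seq_len, is_causal=True)

Note that the kernel applies the 1/sqrt(head_size) scaling internally (q * dk_sqrt * log2e), so callers pass the unscaled query.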
sglang/srt/layers/communicator.py
@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_global_dp_buffer,
+    get_local_dp_buffer,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
         context: CommunicateContext,
     ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-                forward_batch.gathered_buffer[
-                    : forward_batch.input_ids.shape[0]
-                ].clone(),
+                get_local_dp_buffer(),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -420,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
 
         # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
         use_layer_norm_before_gather = context.attn_tp_size == 1
-        if use_layer_norm_before_gather:
-            residual.copy_(hidden_states)
-            if hidden_states.shape[0] != 0:
-                hidden_states = layernorm(hidden_states)
-
+        if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
+            residual = hidden_states
+            hidden_states = layernorm(hidden_states)
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer,
+            get_global_dp_buffer(),
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -443,7 +441,7 @@
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <= 128
+            and hidden_states.shape[0] <= 2048
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
@@ -550,7 +548,7 @@ class CommunicateSummableTensorPairFn:
         allow_reduce_scatter: bool = False,
     ):
         hidden_states, global_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
        )
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -571,7 +569,7 @@
             hidden_states += residual
             residual = None
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
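
The recurring edit in communicator.py swaps direct slices of forward_batch.gathered_buffer for the buffer helpers newly imported from sglang.srt.layers.dp_attention. A rough sketch of the pattern follows; the helper names and import path come from the hunks above, while the commented "before" line and the variable names are illustrative, and the helpers only return meaningful buffers inside a running sglang forward pass.

from sglang.srt.layers.dp_attention import get_global_dp_buffer, get_local_dp_buffer

# 0.5.0rc0 pattern: slice the batch-owned gathered buffer to the local token count.
# hidden_states_buf = forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]

# 0.5.0rc2 pattern: ask the dp_attention module for the appropriately sized buffer.
local_buf = get_local_dp_buffer()    # sized for this rank's local tokens
global_buf = get_global_dp_buffer()  # sized for the DP-gathered (global) tokens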