sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,186 @@
+ """
+ Memory-efficient attention for decoding.
+ It supports page size = 1.
+ """
+
+ import functools
+ import logging
+
+ from wave_lang.kernel.lang.global_symbols import *
+ from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+ from wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType
+ from wave_lang.kernel.wave.templates.paged_decode_attention import (
+     get_paged_decode_attention_kernels,
+     get_paged_decode_intermediate_arrays_shapes,
+     paged_decode_attention_shape,
+ )
+ from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+ from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+ logger = logging.getLogger(__name__)
+ import os
+
+ dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+ @functools.lru_cache(maxsize=4096)
+ def get_wave_kernel(
+     shape: paged_decode_attention_shape,
+     max_kv_splits,
+     input_dtype,
+     output_dtype,
+     logit_cap,
+ ):
+     mha = (shape.num_query_heads // shape.num_kv_heads) == 1
+
+     # Get the kernels (either compile or load from cache).
+     if mha:
+         mfma_variant = (
+             GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1),
+             GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64),
+         )
+     else:
+         mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+
+     (
+         phase_0,
+         phase_1,
+         hyperparams_0,
+         hyperparams_1,
+         dynamic_symbols_0,
+         dynamic_symbols_1,
+     ) = get_paged_decode_attention_kernels(
+         shape,
+         mfma_variant,
+         max_kv_splits,
+         input_dtype=input_dtype,
+         output_dtype=output_dtype,
+         logit_cap=logit_cap,
+     )
+     hyperparams_0.update(get_default_scheduling_params())
+     hyperparams_1.update(get_default_scheduling_params())
+
+     options = WaveCompileOptions(
+         subs=hyperparams_0,
+         canonicalize=True,
+         run_bench=False,
+         use_buffer_load_ops=True,
+         use_buffer_store_ops=True,
+         waves_per_eu=2,
+         dynamic_symbols=dynamic_symbols_0,
+         wave_runtime=True,
+     )
+     options = set_default_run_config(options)
+     phase_0 = wave_compile(options, phase_0)
+
+     options = WaveCompileOptions(
+         subs=hyperparams_1,
+         canonicalize=True,
+         run_bench=False,
+         use_buffer_load_ops=False,
+         use_buffer_store_ops=False,
+         waves_per_eu=4,
+         dynamic_symbols=dynamic_symbols_1,
+         wave_runtime=True,
+     )
+     options = set_default_run_config(options)
+     phase_1 = wave_compile(options, phase_1)
+
+     return phase_0, phase_1
+
+
+ def decode_attention_intermediate_arrays_shapes(
+     num_seqs, head_size_kv, num_query_heads, max_kv_splits
+ ):
+     # Not all fields are used, but we need to pass them to the function
+     shape = paged_decode_attention_shape(
+         num_query_heads=num_query_heads,
+         num_kv_heads=0,
+         head_size=0,
+         head_size_kv=head_size_kv,
+         block_size=0,
+         num_seqs=num_seqs,
+     )
+     return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits)
+
+
+ def decode_attention_wave(
+     q,
+     k_buffer,
+     v_buffer,
+     o,
+     b_req_idx,
+     req_to_token,
+     attn_logits,
+     attn_logits_max,
+     num_kv_splits,
+     max_kv_splits,
+     sm_scale,
+     logit_cap,
+ ):
+     num_seqs, num_query_heads, head_size = q.shape
+     _, num_kv_heads, _ = k_buffer.shape
+     _, _, head_size_kv = v_buffer.shape
+     block_size = 32
+     shape = paged_decode_attention_shape(
+         num_query_heads,
+         num_kv_heads,
+         head_size,
+         head_size_kv,
+         block_size,
+         num_seqs,
+     )
+
+     phase_0, phase_1 = get_wave_kernel(
+         shape, max_kv_splits, q.dtype, o.dtype, logit_cap
+     )
+
+     mb_qk = phase_0(
+         q,
+         k_buffer,
+         v_buffer,
+         b_req_idx,
+         req_to_token,
+         attn_logits,
+         attn_logits_max,
+     )
+     if dump_generated_mlir:
+         filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir"
+         with open(filename, "w") as f:
+             f.write(mb_qk.module_op.get_asm())
+
+     mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o)
+     if dump_generated_mlir:
+         filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir"
+         with open(filename, "w") as f:
+             f.write(mb_sv.module_op.get_asm())
+
+
+ def decode_attention_fwd(
+     q,
+     k_buffer,
+     v_buffer,
+     o,
+     b_req_idx,
+     req_to_token,
+     attn_logits,
+     attn_logits_max,
+     num_kv_splits,
+     max_kv_splits,
+     sm_scale,
+     logit_cap=0.0,
+ ):
+     decode_attention_wave(
+         q,
+         k_buffer,
+         v_buffer,
+         o,
+         b_req_idx,
+         req_to_token,
+         attn_logits,
+         attn_logits_max,
+         num_kv_splits,
+         max_kv_splits,
+         sm_scale,
+         logit_cap,
+     )
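Note on the debugging hook shared by the new wave_ops modules: dump_generated_mlir is read from the WAVE_DUMP_MLIR environment variable at module import time, so the variable must be set before the module is imported. A minimal sketch of enabling it (the snippet below is illustrative and not part of the diff):

import os

# Must be set before the wave_ops modules are imported, because the flag is
# captured once at import time.
os.environ["WAVE_DUMP_MLIR"] = "1"

# With the flag set, the decode/extend/prefill paths write the generated MLIR to
# files such as wave_decode_attention_phase0_<shape>.mlir in the working directory.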
@@ -0,0 +1,149 @@
+ """
+ Memory-efficient attention for prefill.
+ It support page size = 1.
+ """
+
+ import functools
+ import os
+
+ import torch
+ from wave_lang.kernel.lang.global_symbols import *
+ from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+ from wave_lang.kernel.wave.constraints import MMAType
+ from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
+ from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+ from wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel
+ from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+ from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+ dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+ @functools.lru_cache
+ def get_wave_kernel(
+     shape: AttentionShape,
+     q_shape: tuple[int],
+     k_shape: tuple[int],
+     v_shape: tuple[int],
+     k_cache_shape: tuple[int],
+     v_cache_shape: tuple[int],
+     o_shape: tuple[int],
+     input_dtype: torch.dtype,
+     output_dtype: torch.dtype,
+     size_dtype: torch.dtype,
+     is_causal: bool,
+     logit_cap: float,
+     layer_scaling: float,
+ ):
+     assert shape.num_query_heads % shape.num_kv_heads == 0
+
+     mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16)
+     (
+         extend_attention,
+         hyperparams,
+         dynamic_symbols,
+     ) = get_extend_attention_kernel(
+         shape,
+         mfma_variant,
+         q_shape,
+         k_shape,
+         v_shape,
+         k_cache_shape,
+         v_cache_shape,
+         o_shape,
+         input_dtype=input_dtype,
+         output_dtype=output_dtype,
+         size_dtype=size_dtype,
+         is_causal=is_causal,
+         layer_scaling=layer_scaling,
+         logit_cap=logit_cap,
+     )
+
+     hyperparams.update(get_default_scheduling_params())
+     options = WaveCompileOptions(
+         subs=hyperparams,
+         canonicalize=True,
+         run_bench=False,
+         schedule=SchedulingType.NONE,
+         use_scheduling_barriers=False,
+         dynamic_symbols=dynamic_symbols,
+         use_buffer_load_ops=True,
+         use_buffer_store_ops=True,
+         waves_per_eu=2,
+         denorm_fp_math_f32="preserve-sign",
+         gpu_native_math_precision=True,
+         wave_runtime=True,
+     )
+     options = set_default_run_config(options)
+     extend_attention = wave_compile(options, extend_attention)
+
+     return extend_attention
+
+
+ def extend_attention_wave(
+     q_extend,
+     k_extend,
+     v_extend,
+     k_buffer,
+     v_buffer,
+     qo_indptr,
+     kv_indptr,
+     kv_indices,
+     custom_mask,
+     mask_indptr,
+     max_seq_len,
+     output,
+     is_causal=True,
+     layer_scaling=None,
+     logit_cap=0,
+ ):
+     shape = AttentionShape(
+         num_query_heads=q_extend.shape[1],
+         num_kv_heads=k_extend.shape[1],
+         head_size=q_extend.shape[2],
+         head_size_kv=k_extend.shape[2],
+         num_seqs=kv_indptr.shape[0] - 1,
+         max_seq_len=max_seq_len,
+     )
+
+     # Run the wave kernel.
+     extend_attention = get_wave_kernel(
+         shape,
+         q_extend.shape,
+         k_extend.shape,
+         v_extend.shape,
+         k_buffer.shape,
+         v_buffer.shape,
+         output.shape,
+         input_dtype=q_extend.dtype,
+         output_dtype=output.dtype,
+         size_dtype=qo_indptr.dtype,
+         is_causal=is_causal,
+         layer_scaling=layer_scaling,
+         logit_cap=logit_cap,
+     )
+
+     mb = extend_attention(
+         q_extend,
+         k_extend,
+         v_extend,
+         k_buffer,
+         v_buffer,
+         qo_indptr,
+         kv_indptr,
+         kv_indices,
+         max_seq_len,
+         output,
+     )
+
+     if dump_generated_mlir:
+         shape_list = [
+             q_extend.shape[0],
+             q_extend.shape[1],
+             k_extend.shape[1],
+             q_extend.shape[2],
+             k_extend.shape[2],
+         ]
+         filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+         with open(filename, "w") as f:
+             f.write(mb.module_op.get_asm())
@@ -0,0 +1,79 @@
+ """
+ Memory-efficient attention for prefill.
+ It support page size = 1.
+ """
+
+ import math
+ import os
+
+ from wave_lang.kernel.lang.global_symbols import *
+ from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+ from wave_lang.kernel.wave.constraints import MMAType
+ from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+ from wave_lang.kernel.wave.templates.prefill_attention import (
+     get_prefill_attention_kernel,
+ )
+ from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+ from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+ dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+ def prefill_attention_wave(
+     q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True
+ ):
+
+     shape = AttentionShape(
+         num_query_heads=q.shape[1],
+         num_kv_heads=k.shape[1],
+         head_size=q.shape[2],
+         head_size_kv=k.shape[2],
+         num_seqs=b_seq_len.shape[0],
+         max_seq_len=max_seq_len,
+         total_seq_len=q.shape[0],
+     )
+
+     assert shape.num_query_heads % shape.num_kv_heads == 0
+
+     output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv)
+     # Run the wave kernel.
+     mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+     (prefill, hyperparams) = get_prefill_attention_kernel(
+         shape,
+         mfma_variant,
+         q.shape,
+         k.shape,
+         v.shape,
+         output_shape,
+         input_dtype=q.dtype,
+         output_dtype=o.dtype,
+         size_dtype=b_seq_len.dtype,
+     )
+
+     hyperparams.update(get_default_scheduling_params())
+
+     log2e = 1.44269504089
+     dk_sqrt = math.sqrt(1.0 / shape.head_size)
+
+     options = WaveCompileOptions(
+         subs=hyperparams,
+         canonicalize=True,
+         run_bench=False,
+         use_scheduling_barriers=False,
+     )
+     options = set_default_run_config(options)
+     prefill = wave_compile(options, prefill)
+
+     mb = prefill(
+         q * dk_sqrt * log2e,
+         k,
+         v,
+         b_start_loc,
+         b_seq_len,
+         o,
+     )
+     if dump_generated_mlir:
+         shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]]
+         filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+         with open(filename, "w") as f:
+             f.write(mb.module_op.get_asm())
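The query pre-scaling in prefill_attention_wave (q * dk_sqrt * log2e) folds the 1/sqrt(head_size) attention scale and log2(e) into the query so the kernel can evaluate the softmax exponent with exp2 rather than exp. A small sketch of the identity being relied on (illustrative only, not from the diff):

import math

# exp(x) == 2 ** (x * log2(e)); pre-multiplying the attention scores by log2e
# lets a GPU kernel use the cheaper exp2 instruction.
x = 0.73
log2e = 1.44269504089
assert abs(math.exp(x) - 2 ** (x * log2e)) < 1e-9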
@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
      attn_tp_all_gather_into_tensor,
      attn_tp_reduce_scatter_tensor,
      dp_gather_partial,
+     dp_reduce_scatter_tensor,
      dp_scatter,
      get_attention_dp_size,
      get_attention_tp_rank,
@@ -149,10 +150,13 @@ class LayerCommunicator:
          layer_scatter_modes: LayerScatterModes,
          input_layernorm: torch.nn.Module,
          post_attention_layernorm: torch.nn.Module,
+         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
+         allow_reduce_scatter: bool = False,
      ):
          self.layer_scatter_modes = layer_scatter_modes
          self.input_layernorm = input_layernorm
          self.post_attention_layernorm = post_attention_layernorm
+         self.allow_reduce_scatter = allow_reduce_scatter

          self._context = CommunicateContext.init_new()
          self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -239,6 +243,15 @@ class LayerCommunicator:
              residual=residual,
              forward_batch=forward_batch,
              context=self._context,
+             allow_reduce_scatter=self.allow_reduce_scatter,
+         )
+
+     def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
+         return (
+             self.allow_reduce_scatter
+             and self._communicate_summable_tensor_pair_fn
+             is CommunicateSummableTensorPairFn._scatter_hidden_states
+             and forward_batch.dp_padding_mode.is_max_len()
          )


@@ -395,9 +408,9 @@ class CommunicateWithAllReduceAndLayerNormFn:
      ):
          if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
              residual, local_residual = (
-                 forward_batch.gathered_buffer[
-                     : forward_batch.input_ids.shape[0]
-                 ].clone(),
+                 torch.empty_like(
+                     forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
+                 ),
                  residual,
              )
              attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -407,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:

          # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
          use_layer_norm_before_gather = context.attn_tp_size == 1
-         if use_layer_norm_before_gather:
-             residual.copy_(hidden_states)
-             if hidden_states.shape[0] != 0:
-                 hidden_states = layernorm(hidden_states)
-
+         if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
+             residual = hidden_states
+             hidden_states = layernorm(hidden_states)
          hidden_states, local_hidden_states = (
-             forward_batch.gathered_buffer,
+             torch.empty_like(forward_batch.gathered_buffer),
              hidden_states,
          )
          dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -430,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
              and _is_flashinfer_available
              and hasattr(layernorm, "forward_with_allreduce_fusion")
              and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-             and hidden_states.shape[0] <= 128
+             and hidden_states.shape[0] <= 2048
          ):
              hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                  hidden_states, residual
@@ -524,6 +535,7 @@ class CommunicateSummableTensorPairFn:
          residual: torch.Tensor,
          forward_batch: ForwardBatch,
          context: CommunicateContext,
+         **kwargs,
      ):
          return hidden_states, residual

@@ -533,15 +545,17 @@ class CommunicateSummableTensorPairFn:
          residual: torch.Tensor,
          forward_batch: ForwardBatch,
          context: CommunicateContext,
+         allow_reduce_scatter: bool = False,
      ):
-         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-         # important: forward batch.gathered_buffer is used both after scatter and after gather.
-         # be careful about this!
          hidden_states, global_hidden_states = (
              forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
              hidden_states,
          )
-         dp_scatter(hidden_states, global_hidden_states, forward_batch)
+         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
+             # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
+             dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
+         else:
+             dp_scatter(hidden_states, global_hidden_states, forward_batch)
          return hidden_states, residual

      @staticmethod
@@ -550,6 +564,7 @@ class CommunicateSummableTensorPairFn:
          residual: torch.Tensor,
          forward_batch: ForwardBatch,
          context: CommunicateContext,
+         **kwargs,
      ):
          hidden_states += residual
          residual = None
@@ -12,6 +12,7 @@ import triton.language as tl

  from sglang.srt.distributed import (
      GroupCoordinator,
+     get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      get_tp_group,
      tensor_model_parallel_all_reduce,
@@ -355,6 +356,17 @@ def dp_scatter(
  )


+ def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+     if get_tensor_model_parallel_world_size() == get_attention_dp_size():
+         get_tp_group().reduce_scatter_tensor(output, input)
+     else:
+         scattered_local_tokens = input.tensor_split(
+             get_tensor_model_parallel_world_size()
+         )[get_tensor_model_parallel_rank()]
+         get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
+         get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)
+
+
  def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
      return get_attention_tp_group().reduce_scatter_tensor(output, input)

@@ -1,5 +1,5 @@
  import logging
- from typing import Tuple
+ from typing import Optional, Tuple

  import torch
  import torch.distributed as dist
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


  def ensure_workspace_initialized(
-     max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+     max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
  ):
      """Ensure workspace is initialized"""
      if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
      residual: torch.Tensor,
      weight: torch.Tensor,
      eps: float = 1e-6,
-     max_token_num: int = 128,
-     use_oneshot: bool = True,
+     max_token_num: int = 2048,
+     use_oneshot: Optional[bool] = None,
      trigger_completion_at_end: bool = False,
      fp32_acc: bool = False,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
                  else self.weight_loader
              ),
          )
-         if not reduce_results and (bias and not skip_bias_add):
-             raise ValueError(
-                 "When not reduce the results, adding bias to the "
-                 "results can lead to incorrect results"
-             )

          if bias:
              self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
@@ -1282,7 +1277,7 @@ class RowParallelLinear(LinearBase):
          # It does not support additional parameters.
          param.load_row_parallel_weight(loaded_weight)

-     def forward(self, input_, can_fuse_mlp_allreduce=False):
+     def forward(self, input_, skip_all_reduce=False):
          if self.input_is_parallel:
              input_parallel = input_
          else:
@@ -1299,7 +1294,8 @@ class RowParallelLinear(LinearBase):
          with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
              output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
              sm.tag(output_parallel)
-         if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
+
+         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
              output = tensor_model_parallel_all_reduce(output_parallel)
          else:
              output = output_parallel
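Taken together, the communicator, dp_attention, and linear changes above let a model replace the tensor-parallel all-reduce after its MLP/MoE with a reduce-scatter performed later by LayerCommunicator. A hypothetical sketch of how a model might wire this up (attribute names such as layer_communicator, gate_up_proj, act_fn, and down_proj are illustrative, not taken from the diff):

def mlp_forward(self, hidden_states, forward_batch):
    # Ask the communicator whether it will reduce-scatter the gathered buffer later;
    # if so, the all-reduce inside the row-parallel down projection can be skipped.
    use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(forward_batch)
    gate_up, _ = self.gate_up_proj(hidden_states)
    x = self.act_fn(gate_up)
    out, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
    return out

This path only applies to models that construct their LayerCommunicator with allow_reduce_scatter=True, per the comment added in the LayerCommunicator constructor.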
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
  import torch

  from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
+ from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
  from sglang.srt.utils import is_cuda

  _is_cuda = is_cuda()
@@ -123,6 +124,8 @@ def cutlass_fused_experts_fp8(

      if is_cuda:
          from sglang.srt.layers.quantization.fp8_kernel import (
+             per_group_transpose,
+             per_token_group_quant_fp8_hopper_moe_mn_major,
              sglang_per_token_group_quant_fp8,
          )

@@ -133,9 +136,7 @@ def cutlass_fused_experts_fp8(
      n = w2_q.size(1)

      topk = topk_ids.size(1)
-
-     a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-     device = a_q.device
+     device = a.device

      a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
      c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
@@ -152,9 +153,14 @@ def cutlass_fused_experts_fp8(
          k,
      )

+     a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
      rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
      rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))

+     if not is_sm100_supported():
+         rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
+         w1_scale = w1_scale.contiguous()
+
      c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
      c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

@@ -186,6 +192,9 @@ def cutlass_fused_experts_fp8(
      silu_and_mul(c1, intermediate)

      intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+     if not is_sm100_supported():
+         a2_scale = per_group_transpose(a2_scale, expert_offsets)
+         w2_scale = w2_scale.contiguous()

      fp8_blockwise_scaled_grouped_mm(
          c2,
@@ -11,7 +11,7 @@ from sgl_kernel import (
  )

  from sglang.srt.layers.moe.ep_moe.kernels import (
-     post_reorder_triton_kernel,
+     post_reorder_triton_kernel_for_cutlass_moe,
      pre_reorder_triton_kernel_for_cutlass_moe,
      run_cutlass_moe_ep_preproess,
  )
@@ -199,14 +199,13 @@ def cutlass_w4a8_moe(
      )

      output = torch.empty_like(a)
-     post_reorder_triton_kernel[(m,)](
+     post_reorder_triton_kernel_for_cutlass_moe[(m,)](
          c2,
          output,
          src2dst,
-         topk_ids_,
+         local_topk_ids,
          topk_weights,
-         start_expert_id,
-         end_expert_id,
+         num_experts,
          topk,
          k,
          0,