sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/wave_ops/decode_attention.py
@@ -0,0 +1,186 @@
+"""
+Memory-efficient attention for decoding.
+It supports page size = 1.
+"""
+
+import functools
+import logging
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType
+from wave_lang.kernel.wave.templates.paged_decode_attention import (
+    get_paged_decode_attention_kernels,
+    get_paged_decode_intermediate_arrays_shapes,
+    paged_decode_attention_shape,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+logger = logging.getLogger(__name__)
+import os
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache(maxsize=4096)
+def get_wave_kernel(
+    shape: paged_decode_attention_shape,
+    max_kv_splits,
+    input_dtype,
+    output_dtype,
+    logit_cap,
+):
+    mha = (shape.num_query_heads // shape.num_kv_heads) == 1
+
+    # Get the kernels (either compile or load from cache).
+    if mha:
+        mfma_variant = (
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1),
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64),
+        )
+    else:
+        mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+
+    (
+        phase_0,
+        phase_1,
+        hyperparams_0,
+        hyperparams_1,
+        dynamic_symbols_0,
+        dynamic_symbols_1,
+    ) = get_paged_decode_attention_kernels(
+        shape,
+        mfma_variant,
+        max_kv_splits,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        logit_cap=logit_cap,
+    )
+    hyperparams_0.update(get_default_scheduling_params())
+    hyperparams_1.update(get_default_scheduling_params())
+
+    options = WaveCompileOptions(
+        subs=hyperparams_0,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        dynamic_symbols=dynamic_symbols_0,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_0 = wave_compile(options, phase_0)
+
+    options = WaveCompileOptions(
+        subs=hyperparams_1,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=False,
+        use_buffer_store_ops=False,
+        waves_per_eu=4,
+        dynamic_symbols=dynamic_symbols_1,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_1 = wave_compile(options, phase_1)
+
+    return phase_0, phase_1
+
+
+def decode_attention_intermediate_arrays_shapes(
+    num_seqs, head_size_kv, num_query_heads, max_kv_splits
+):
+    # Not all fields are used, but we need to pass them to the function
+    shape = paged_decode_attention_shape(
+        num_query_heads=num_query_heads,
+        num_kv_heads=0,
+        head_size=0,
+        head_size_kv=head_size_kv,
+        block_size=0,
+        num_seqs=num_seqs,
+    )
+    return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits)
+
+
+def decode_attention_wave(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap,
+):
+    num_seqs, num_query_heads, head_size = q.shape
+    _, num_kv_heads, _ = k_buffer.shape
+    _, _, head_size_kv = v_buffer.shape
+    block_size = 32
+    shape = paged_decode_attention_shape(
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        head_size_kv,
+        block_size,
+        num_seqs,
+    )
+
+    phase_0, phase_1 = get_wave_kernel(
+        shape, max_kv_splits, q.dtype, o.dtype, logit_cap
+    )
+
+    mb_qk = phase_0(
+        q,
+        k_buffer,
+        v_buffer,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+    )
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_qk.module_op.get_asm())
+
+    mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o)
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_sv.module_op.get_asm())
+
+
+def decode_attention_fwd(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap=0.0,
+):
+    decode_attention_wave(
+        q,
+        k_buffer,
+        v_buffer,
+        o,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+        num_kv_splits,
+        max_kv_splits,
+        sm_scale,
+        logit_cap,
+    )
sglang/srt/layers/attention/wave_ops/extend_attention.py
@@ -0,0 +1,149 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import functools
+import os
+
+import torch
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache
+def get_wave_kernel(
+    shape: AttentionShape,
+    q_shape: tuple[int],
+    k_shape: tuple[int],
+    v_shape: tuple[int],
+    k_cache_shape: tuple[int],
+    v_cache_shape: tuple[int],
+    o_shape: tuple[int],
+    input_dtype: torch.dtype,
+    output_dtype: torch.dtype,
+    size_dtype: torch.dtype,
+    is_causal: bool,
+    logit_cap: float,
+    layer_scaling: float,
+):
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16)
+    (
+        extend_attention,
+        hyperparams,
+        dynamic_symbols,
+    ) = get_extend_attention_kernel(
+        shape,
+        mfma_variant,
+        q_shape,
+        k_shape,
+        v_shape,
+        k_cache_shape,
+        v_cache_shape,
+        o_shape,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        size_dtype=size_dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        schedule=SchedulingType.NONE,
+        use_scheduling_barriers=False,
+        dynamic_symbols=dynamic_symbols,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        denorm_fp_math_f32="preserve-sign",
+        gpu_native_math_precision=True,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    extend_attention = wave_compile(options, extend_attention)
+
+    return extend_attention
+
+
+def extend_attention_wave(
+    q_extend,
+    k_extend,
+    v_extend,
+    k_buffer,
+    v_buffer,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    custom_mask,
+    mask_indptr,
+    max_seq_len,
+    output,
+    is_causal=True,
+    layer_scaling=None,
+    logit_cap=0,
+):
+    shape = AttentionShape(
+        num_query_heads=q_extend.shape[1],
+        num_kv_heads=k_extend.shape[1],
+        head_size=q_extend.shape[2],
+        head_size_kv=k_extend.shape[2],
+        num_seqs=kv_indptr.shape[0] - 1,
+        max_seq_len=max_seq_len,
+    )
+
+    # Run the wave kernel.
+    extend_attention = get_wave_kernel(
+        shape,
+        q_extend.shape,
+        k_extend.shape,
+        v_extend.shape,
+        k_buffer.shape,
+        v_buffer.shape,
+        output.shape,
+        input_dtype=q_extend.dtype,
+        output_dtype=output.dtype,
+        size_dtype=qo_indptr.dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    mb = extend_attention(
+        q_extend,
+        k_extend,
+        v_extend,
+        k_buffer,
+        v_buffer,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        max_seq_len,
+        output,
+    )
+
+    if dump_generated_mlir:
+        shape_list = [
+            q_extend.shape[0],
+            q_extend.shape[1],
+            k_extend.shape[1],
+            q_extend.shape[2],
+            k_extend.shape[2],
+        ]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
sglang/srt/layers/attention/wave_ops/prefill_attention.py
@@ -0,0 +1,79 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import math
+import os
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.prefill_attention import (
+    get_prefill_attention_kernel,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+def prefill_attention_wave(
+    q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True
+):
+
+    shape = AttentionShape(
+        num_query_heads=q.shape[1],
+        num_kv_heads=k.shape[1],
+        head_size=q.shape[2],
+        head_size_kv=k.shape[2],
+        num_seqs=b_seq_len.shape[0],
+        max_seq_len=max_seq_len,
+        total_seq_len=q.shape[0],
+    )
+
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv)
+    # Run the wave kernel.
+    mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+    (prefill, hyperparams) = get_prefill_attention_kernel(
+        shape,
+        mfma_variant,
+        q.shape,
+        k.shape,
+        v.shape,
+        output_shape,
+        input_dtype=q.dtype,
+        output_dtype=o.dtype,
+        size_dtype=b_seq_len.dtype,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+
+    log2e = 1.44269504089
+    dk_sqrt = math.sqrt(1.0 / shape.head_size)
+
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        use_scheduling_barriers=False,
+    )
+    options = set_default_run_config(options)
+    prefill = wave_compile(options, prefill)
+
+    mb = prefill(
+        q * dk_sqrt * log2e,
+        k,
+        v,
+        b_start_loc,
+        b_seq_len,
+        o,
+    )
+    if dump_generated_mlir:
+        shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
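A note on the q * dk_sqrt * log2e prescaling in prefill_attention_wave above: 1.44269504089 is log2(e), so multiplying the query by 1/sqrt(head_size) * log2(e) folds both the usual attention scale and the exp-to-exp2 rewrite (exp(x) = 2**(x * log2(e))) into a single multiply, which suggests the compiled kernel exponentiates with exp2. A small standalone check of the arithmetic, for illustration only (not part of the diff):

import math

# log2(e) == 1 / ln(2), and exp(x) == 2 ** (x * log2(e))
assert math.isclose(1.44269504089, 1 / math.log(2), rel_tol=1e-9)
head_size = 128  # illustrative value
q_scale = math.sqrt(1.0 / head_size) * 1.44269504089  # factor applied to q before the kernel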
sglang/srt/layers/communicator.py
@@ -408,9 +408,9 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-                forward_batch.gathered_buffer[
-                    : forward_batch.input_ids.shape[0]
-                ].clone(),
+                torch.empty_like(
+                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
+                ),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -420,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
 
         # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
         use_layer_norm_before_gather = context.attn_tp_size == 1
-        if use_layer_norm_before_gather:
-            residual.copy_(hidden_states)
-            if hidden_states.shape[0] != 0:
-                hidden_states = layernorm(hidden_states)
-
+        if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
+            residual = hidden_states
+            hidden_states = layernorm(hidden_states)
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer,
+            torch.empty_like(forward_batch.gathered_buffer),
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -443,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <= 128
+            and hidden_states.shape[0] <= 2048
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
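For context on the .clone() to torch.empty_like() change in the communicator hunks above: both allocate a fresh tensor with the same shape, dtype, and device, but clone() also copies the contents, which is wasted work when the buffer is immediately overwritten by attn_tp_all_gather_into_tensor or dp_gather_partial. A minimal standalone illustration, not part of the diff:

import torch

buf = torch.randn(8, 4096)       # stands in for a slice of forward_batch.gathered_buffer
scratch = torch.empty_like(buf)  # same shape/dtype/device, contents left uninitialized
copied = buf.clone()             # also copies the data; unnecessary if overwritten right away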
sglang/srt/layers/flashinfer_comm_fusion.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()
 
 
 def ensure_workspace_initialized(
-    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 128,
-    use_oneshot: bool = True,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
sglang/srt/layers/linear.py
@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase):
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
             sm.tag(output_parallel)
+
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
sglang/srt/layers/moe/cutlass_moe.py
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
@@ -124,6 +124,7 @@ def cutlass_fused_experts_fp8(
 
     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_group_transpose,
            per_token_group_quant_fp8_hopper_moe_mn_major,
            sglang_per_token_group_quant_fp8,
        )
@@ -152,15 +153,12 @@ def cutlass_fused_experts_fp8(
        k,
    )
 
-    if is_sm100_supported():
-        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
-        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
-    else:
-        rep_a = shuffle_rows(a, a_map, (m * topk, k))
-        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
-            rep_a, expert_offsets, problem_sizes1, 128
-        )
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+
+    if not is_sm100_supported():
+        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
 
    w1_scale = w1_scale.contiguous()
 
    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
@@ -193,12 +191,9 @@ def cutlass_fused_experts_fp8(
    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
    silu_and_mul(c1, intermediate)
 
-    if is_sm100_supported():
-        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    else:
-        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
-            intermediate, expert_offsets, problem_sizes2, 128
-        )
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if not is_sm100_supported():
+        a2_scale = per_group_transpose(a2_scale, expert_offsets)
 
    w2_scale = w2_scale.contiguous()
 
    fp8_blockwise_scaled_grouped_mm(
sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -11,7 +11,7 @@ from sgl_kernel import (
 )
 
 from sglang.srt.layers.moe.ep_moe.kernels import (
-    post_reorder_triton_kernel,
+    post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_cutlass_moe_ep_preproess,
 )
@@ -199,14 +199,13 @@ def cutlass_w4a8_moe(
     )
 
     output = torch.empty_like(a)
-    post_reorder_triton_kernel[(m,)](
+    post_reorder_triton_kernel_for_cutlass_moe[(m,)](
         c2,
         output,
         src2dst,
-        topk_ids_,
+        local_topk_ids,
         topk_weights,
-        start_expert_id,
-        end_expert_id,
+        num_experts,
         topk,
         k,
         0,
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -581,6 +581,49 @@ def post_reorder_triton_kernel(
     )
 
 
+@triton.jit
+def post_reorder_triton_kernel_for_cutlass_moe(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    dst_start,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + vec
+        mask = offset < hidden_size
+
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            expert_id = tl.load(topk_ids_ptr + idx)
+            if expert_id != num_experts:
+                dst_idx_int32 = tl.load(src2dst_ptr + idx)
+                dst_idx = dst_idx_int32.to(tl.int64)
+                dst_idx = dst_idx - dst_start
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
 @triton.jit
 def compute_m_range(
     pid,
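To make the new post_reorder_triton_kernel_for_cutlass_moe above easier to read, here is a plain NumPy sketch of the per-token reduction it performs; the function name, array layout, and dtypes below are illustrative assumptions, not part of the diff:

import numpy as np

def post_reorder_reference(down_output, src2dst, topk_ids, topk_weights, num_experts, dst_start):
    # down_output: (num_dispatched_tokens, hidden_size) expert outputs in dispatched order
    # src2dst, topk_ids, topk_weights: (num_src_tokens, topk)
    num_src, topk = topk_ids.shape
    out = np.zeros((num_src, down_output.shape[1]), dtype=down_output.dtype)
    for src in range(num_src):
        for j in range(topk):
            if topk_ids[src, j] != num_experts:    # num_experts marks a pruned/non-local slot
                dst = src2dst[src, j] - dst_start  # local row in down_output
                out[src] += topk_weights[src, j] * down_output[dst]
    return out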