sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/wave_ops/decode_attention.py
ADDED
@@ -0,0 +1,186 @@
+"""
+Memory-efficient attention for decoding.
+It supports page size = 1.
+"""
+
+import functools
+import logging
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType
+from wave_lang.kernel.wave.templates.paged_decode_attention import (
+    get_paged_decode_attention_kernels,
+    get_paged_decode_intermediate_arrays_shapes,
+    paged_decode_attention_shape,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+logger = logging.getLogger(__name__)
+import os
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache(maxsize=4096)
+def get_wave_kernel(
+    shape: paged_decode_attention_shape,
+    max_kv_splits,
+    input_dtype,
+    output_dtype,
+    logit_cap,
+):
+    mha = (shape.num_query_heads // shape.num_kv_heads) == 1
+
+    # Get the kernels (either compile or load from cache).
+    if mha:
+        mfma_variant = (
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1),
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64),
+        )
+    else:
+        mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+
+    (
+        phase_0,
+        phase_1,
+        hyperparams_0,
+        hyperparams_1,
+        dynamic_symbols_0,
+        dynamic_symbols_1,
+    ) = get_paged_decode_attention_kernels(
+        shape,
+        mfma_variant,
+        max_kv_splits,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        logit_cap=logit_cap,
+    )
+    hyperparams_0.update(get_default_scheduling_params())
+    hyperparams_1.update(get_default_scheduling_params())
+
+    options = WaveCompileOptions(
+        subs=hyperparams_0,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        dynamic_symbols=dynamic_symbols_0,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_0 = wave_compile(options, phase_0)
+
+    options = WaveCompileOptions(
+        subs=hyperparams_1,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=False,
+        use_buffer_store_ops=False,
+        waves_per_eu=4,
+        dynamic_symbols=dynamic_symbols_1,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_1 = wave_compile(options, phase_1)
+
+    return phase_0, phase_1
+
+
+def decode_attention_intermediate_arrays_shapes(
+    num_seqs, head_size_kv, num_query_heads, max_kv_splits
+):
+    # Not all fields are used, but we need to pass them to the function
+    shape = paged_decode_attention_shape(
+        num_query_heads=num_query_heads,
+        num_kv_heads=0,
+        head_size=0,
+        head_size_kv=head_size_kv,
+        block_size=0,
+        num_seqs=num_seqs,
+    )
+    return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits)
+
+
+def decode_attention_wave(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap,
+):
+    num_seqs, num_query_heads, head_size = q.shape
+    _, num_kv_heads, _ = k_buffer.shape
+    _, _, head_size_kv = v_buffer.shape
+    block_size = 32
+    shape = paged_decode_attention_shape(
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        head_size_kv,
+        block_size,
+        num_seqs,
+    )
+
+    phase_0, phase_1 = get_wave_kernel(
+        shape, max_kv_splits, q.dtype, o.dtype, logit_cap
+    )
+
+    mb_qk = phase_0(
+        q,
+        k_buffer,
+        v_buffer,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+    )
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_qk.module_op.get_asm())
+
+    mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o)
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_sv.module_op.get_asm())
+
+
+def decode_attention_fwd(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap=0.0,
+):
+    decode_attention_wave(
+        q,
+        k_buffer,
+        v_buffer,
+        o,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+        num_kv_splits,
+        max_kv_splits,
+        sm_scale,
+        logit_cap,
+    )
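
For orientation, here is a minimal sketch of how this decode entry point could be exercised. It is an illustration only: the tensor sizes, dtypes, the contiguous req_to_token layout, and the guess that decode_attention_intermediate_arrays_shapes returns the two split-KV intermediate shapes in (attn_logits, attn_logits_max) order are all assumptions, and running it requires a GPU with wave_lang installed.

import torch

from sglang.srt.layers.attention.wave_ops.decode_attention import (
    decode_attention_fwd,
    decode_attention_intermediate_arrays_shapes,
)

num_seqs, num_q_heads, num_kv_heads = 4, 32, 8     # assumed sizes
head_size = head_size_kv = 128
seq_len, max_kv_splits = 64, 8
total_tokens = num_seqs * seq_len
dev, dtype = "cuda", torch.float16

q = torch.randn(num_seqs, num_q_heads, head_size, dtype=dtype, device=dev)
k_buffer = torch.randn(total_tokens, num_kv_heads, head_size, dtype=dtype, device=dev)
v_buffer = torch.randn(total_tokens, num_kv_heads, head_size_kv, dtype=dtype, device=dev)
o = torch.empty(num_seqs, num_q_heads, head_size_kv, dtype=dtype, device=dev)

b_req_idx = torch.arange(num_seqs, dtype=torch.int32, device=dev)
# Page size 1: req_to_token maps (request, position) directly to a KV slot.
req_to_token = torch.arange(total_tokens, dtype=torch.int32, device=dev).view(num_seqs, seq_len)

# Assumed to return the shapes of the two split-KV partial-output buffers.
logits_shape, logits_max_shape = decode_attention_intermediate_arrays_shapes(
    num_seqs, head_size_kv, num_q_heads, max_kv_splits
)
attn_logits = torch.empty(logits_shape, dtype=torch.float32, device=dev)
attn_logits_max = torch.empty(logits_max_shape, dtype=torch.float32, device=dev)

decode_attention_fwd(
    q, k_buffer, v_buffer, o, b_req_idx, req_to_token,
    attn_logits, attn_logits_max,
    num_kv_splits=max_kv_splits, max_kv_splits=max_kv_splits,
    sm_scale=head_size**-0.5,
)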

sglang/srt/layers/attention/wave_ops/extend_attention.py
ADDED
@@ -0,0 +1,149 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import functools
+import os
+
+import torch
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache
+def get_wave_kernel(
+    shape: AttentionShape,
+    q_shape: tuple[int],
+    k_shape: tuple[int],
+    v_shape: tuple[int],
+    k_cache_shape: tuple[int],
+    v_cache_shape: tuple[int],
+    o_shape: tuple[int],
+    input_dtype: torch.dtype,
+    output_dtype: torch.dtype,
+    size_dtype: torch.dtype,
+    is_causal: bool,
+    logit_cap: float,
+    layer_scaling: float,
+):
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16)
+    (
+        extend_attention,
+        hyperparams,
+        dynamic_symbols,
+    ) = get_extend_attention_kernel(
+        shape,
+        mfma_variant,
+        q_shape,
+        k_shape,
+        v_shape,
+        k_cache_shape,
+        v_cache_shape,
+        o_shape,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        size_dtype=size_dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        schedule=SchedulingType.NONE,
+        use_scheduling_barriers=False,
+        dynamic_symbols=dynamic_symbols,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        denorm_fp_math_f32="preserve-sign",
+        gpu_native_math_precision=True,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    extend_attention = wave_compile(options, extend_attention)
+
+    return extend_attention
+
+
+def extend_attention_wave(
+    q_extend,
+    k_extend,
+    v_extend,
+    k_buffer,
+    v_buffer,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    custom_mask,
+    mask_indptr,
+    max_seq_len,
+    output,
+    is_causal=True,
+    layer_scaling=None,
+    logit_cap=0,
+):
+    shape = AttentionShape(
+        num_query_heads=q_extend.shape[1],
+        num_kv_heads=k_extend.shape[1],
+        head_size=q_extend.shape[2],
+        head_size_kv=k_extend.shape[2],
+        num_seqs=kv_indptr.shape[0] - 1,
+        max_seq_len=max_seq_len,
+    )
+
+    # Run the wave kernel.
+    extend_attention = get_wave_kernel(
+        shape,
+        q_extend.shape,
+        k_extend.shape,
+        v_extend.shape,
+        k_buffer.shape,
+        v_buffer.shape,
+        output.shape,
+        input_dtype=q_extend.dtype,
+        output_dtype=output.dtype,
+        size_dtype=qo_indptr.dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    mb = extend_attention(
+        q_extend,
+        k_extend,
+        v_extend,
+        k_buffer,
+        v_buffer,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        max_seq_len,
+        output,
+    )
+
+    if dump_generated_mlir:
+        shape_list = [
+            q_extend.shape[0],
+            q_extend.shape[1],
+            k_extend.shape[1],
+            q_extend.shape[2],
+            k_extend.shape[2],
+        ]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
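
As a point of reference for the arguments above, the ragged index layout can be pictured like this. The numbers are made up for illustration; the only facts taken from the code are that num_seqs is derived as kv_indptr.shape[0] - 1 and that page size 1 means kv_indices addresses individual token slots.

import torch

# Two sequences: 10 and 7 tokens already in the KV cache, extending by 3 and 5 tokens.
qo_indptr = torch.tensor([0, 3, 8], dtype=torch.int32)    # prefix sums of extend lengths
kv_indptr = torch.tensor([0, 10, 17], dtype=torch.int32)  # prefix sums of cached KV lengths
kv_indices = torch.arange(17, dtype=torch.int32)          # page size 1: one slot per cached token
num_seqs = kv_indptr.shape[0] - 1                         # == 2, as AttentionShape(num_seqs=...) expects
max_seq_len = 10 + 5                                      # longest total sequence after the extend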

sglang/srt/layers/attention/wave_ops/prefill_attention.py
ADDED
@@ -0,0 +1,79 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import math
+import os
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.prefill_attention import (
+    get_prefill_attention_kernel,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+def prefill_attention_wave(
+    q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True
+):
+
+    shape = AttentionShape(
+        num_query_heads=q.shape[1],
+        num_kv_heads=k.shape[1],
+        head_size=q.shape[2],
+        head_size_kv=k.shape[2],
+        num_seqs=b_seq_len.shape[0],
+        max_seq_len=max_seq_len,
+        total_seq_len=q.shape[0],
+    )
+
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv)
+    # Run the wave kernel.
+    mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+    (prefill, hyperparams) = get_prefill_attention_kernel(
+        shape,
+        mfma_variant,
+        q.shape,
+        k.shape,
+        v.shape,
+        output_shape,
+        input_dtype=q.dtype,
+        output_dtype=o.dtype,
+        size_dtype=b_seq_len.dtype,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+
+    log2e = 1.44269504089
+    dk_sqrt = math.sqrt(1.0 / shape.head_size)
+
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        use_scheduling_barriers=False,
+    )
+    options = set_default_run_config(options)
+    prefill = wave_compile(options, prefill)
+
+    mb = prefill(
+        q * dk_sqrt * log2e,
+        k,
+        v,
+        b_start_loc,
+        b_seq_len,
+        o,
+    )
+    if dump_generated_mlir:
+        shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
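
One detail worth noting: the wrapper folds the 1/sqrt(head_size) softmax scale and a log2(e) factor into q before launching (the `q * dk_sqrt * log2e` argument), which suggests the compiled kernel exponentiates with exp2 rather than exp. That algebra can be checked with plain PyTorch, independent of wave_lang:

import math
import torch

d = 64
q = torch.randn(4, d)
k = torch.randn(6, d)
log2e = 1.44269504089
dk_sqrt = math.sqrt(1.0 / d)

reference = torch.softmax(q @ k.T / math.sqrt(d), dim=-1)

# exp2 of the pre-scaled scores, normalized, reproduces the standard softmax.
scaled = (q * dk_sqrt * log2e) @ k.T
probs = torch.exp2(scaled - scaled.amax(dim=-1, keepdim=True))
probs = probs / probs.sum(dim=-1, keepdim=True)

assert torch.allclose(reference, probs, atol=1e-5)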

sglang/srt/layers/communicator.py
CHANGED
@@ -408,9 +408,9 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-
-                    : forward_batch.input_ids.shape[0]
-
+                torch.empty_like(
+                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
+                ),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -420,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:

         # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
         use_layer_norm_before_gather = context.attn_tp_size == 1
-        if use_layer_norm_before_gather:
-            residual
-
-            hidden_states = layernorm(hidden_states)
-
+        if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
+            residual = hidden_states
+            hidden_states = layernorm(hidden_states)
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer,
+            torch.empty_like(forward_batch.gathered_buffer),
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -443,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <=
+            and hidden_states.shape[0] <= 2048
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual

sglang/srt/layers/flashinfer_comm_fusion.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int =
+    max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int =
-    use_oneshot: bool =
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:

sglang/srt/layers/linear.py
CHANGED
@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase):
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
             sm.tag(output_parallel)
+
         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:

sglang/srt/layers/moe/cutlass_moe.py
CHANGED
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()
@@ -124,6 +124,7 @@ def cutlass_fused_experts_fp8(

     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_group_transpose,
             per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )
@@ -152,15 +153,12 @@ def cutlass_fused_experts_fp8(
         k,
     )

-
-
-
-
-
-
-        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
-            rep_a, expert_offsets, problem_sizes1, 128
-        )
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+
+    if not is_sm100_supported():
+        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
     w1_scale = w1_scale.contiguous()

     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
@@ -193,12 +191,9 @@
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)

-
-
-
-        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
-            intermediate, expert_offsets, problem_sizes2, 128
-        )
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if not is_sm100_supported():
+        a2_scale = per_group_transpose(a2_scale, expert_offsets)
     w2_scale = w2_scale.contiguous()

     fp8_blockwise_scaled_grouped_mm(
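
The rewritten activation path quantizes with sglang_per_token_group_quant_fp8(a, 128) and then reorders rows, which is why the scale tensor is shaped (m * topk, int(k / 128)): one scale per token per 128-wide column group. Below is a reference-style sketch of that per-token-group FP8 quantization in plain PyTorch; it is an assumption about the fused kernel's semantics (absmax scaling per group), not a copy of it.

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    # One FP8 scale per (row, 128-column group): scale each group by its absmax.
    m, k = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    xg = x.view(m, k // group_size, group_size).float()
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / fp8_max
    q = (xg / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(m, k), scale.squeeze(-1)  # scales: [m, k // group_size]

a = torch.randn(8, 256, dtype=torch.bfloat16)
a_q, a1_scale = per_token_group_quant_fp8_ref(a, 128)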

sglang/srt/layers/moe/cutlass_w4a8_moe.py
CHANGED
@@ -11,7 +11,7 @@ from sgl_kernel import (
 )

 from sglang.srt.layers.moe.ep_moe.kernels import (
-
+    post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_cutlass_moe_ep_preproess,
 )
@@ -199,14 +199,13 @@ def cutlass_w4a8_moe(
     )

     output = torch.empty_like(a)
-
+    post_reorder_triton_kernel_for_cutlass_moe[(m,)](
         c2,
         output,
         src2dst,
-
+        local_topk_ids,
         topk_weights,
-
-        end_expert_id,
+        num_experts,
         topk,
         k,
         0,

sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -581,6 +581,49 @@ def post_reorder_triton_kernel(
     )


+@triton.jit
+def post_reorder_triton_kernel_for_cutlass_moe(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    dst_start,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + vec
+        mask = offset < hidden_size
+
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            expert_id = tl.load(topk_ids_ptr + idx)
+            if expert_id != num_experts:
+                dst_idx_int32 = tl.load(src2dst_ptr + idx)
+                dst_idx = dst_idx_int32.to(tl.int64)
+                dst_idx = dst_idx - dst_start
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
 @triton.jit
 def compute_m_range(
     pid,
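
The new kernel gathers each token's top-k expert outputs back into token order and accumulates them with the routing weights, skipping slots whose expert id equals num_experts (presumably a sentinel for tokens handled by another rank). A hedged launch sketch mirroring the call site added in cutlass_w4a8_moe above; the sizes, the random routing data, and BLOCK_SIZE=512 are assumptions.

import torch

from sglang.srt.layers.moe.ep_moe.kernels import (
    post_reorder_triton_kernel_for_cutlass_moe,
)

m, topk, hidden_size, num_experts = 16, 8, 4096, 256
down_output = torch.randn(m * topk, hidden_size, dtype=torch.bfloat16, device="cuda")
output = torch.empty(m, hidden_size, dtype=torch.bfloat16, device="cuda")
src2dst = torch.randperm(m * topk, device="cuda").to(torch.int32).view(m, topk)
topk_ids = torch.randint(0, num_experts, (m, topk), dtype=torch.int32, device="cuda")
topk_weights = torch.rand(m, topk, dtype=torch.float32, device="cuda")

post_reorder_triton_kernel_for_cutlass_moe[(m,)](
    down_output, output, src2dst, topk_ids, topk_weights,
    num_experts, topk, hidden_size,
    0,            # dst_start, 0 as in the cutlass_w4a8_moe call site
    BLOCK_SIZE=512,
)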