sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ sglang/srt/layers/attention/wave_ops/decode_attention.py
@@ -0,0 +1,186 @@
+"""
+Memory-efficient attention for decoding.
+It supports page size = 1.
+"""
+
+import functools
+import logging
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType
+from wave_lang.kernel.wave.templates.paged_decode_attention import (
+    get_paged_decode_attention_kernels,
+    get_paged_decode_intermediate_arrays_shapes,
+    paged_decode_attention_shape,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+logger = logging.getLogger(__name__)
+import os
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache(maxsize=4096)
+def get_wave_kernel(
+    shape: paged_decode_attention_shape,
+    max_kv_splits,
+    input_dtype,
+    output_dtype,
+    logit_cap,
+):
+    mha = (shape.num_query_heads // shape.num_kv_heads) == 1
+
+    # Get the kernels (either compile or load from cache).
+    if mha:
+        mfma_variant = (
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1),
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64),
+        )
+    else:
+        mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+
+    (
+        phase_0,
+        phase_1,
+        hyperparams_0,
+        hyperparams_1,
+        dynamic_symbols_0,
+        dynamic_symbols_1,
+    ) = get_paged_decode_attention_kernels(
+        shape,
+        mfma_variant,
+        max_kv_splits,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        logit_cap=logit_cap,
+    )
+    hyperparams_0.update(get_default_scheduling_params())
+    hyperparams_1.update(get_default_scheduling_params())
+
+    options = WaveCompileOptions(
+        subs=hyperparams_0,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        dynamic_symbols=dynamic_symbols_0,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_0 = wave_compile(options, phase_0)
+
+    options = WaveCompileOptions(
+        subs=hyperparams_1,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=False,
+        use_buffer_store_ops=False,
+        waves_per_eu=4,
+        dynamic_symbols=dynamic_symbols_1,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_1 = wave_compile(options, phase_1)
+
+    return phase_0, phase_1
+
+
+def decode_attention_intermediate_arrays_shapes(
+    num_seqs, head_size_kv, num_query_heads, max_kv_splits
+):
+    # Not all fields are used, but we need to pass them to the function
+    shape = paged_decode_attention_shape(
+        num_query_heads=num_query_heads,
+        num_kv_heads=0,
+        head_size=0,
+        head_size_kv=head_size_kv,
+        block_size=0,
+        num_seqs=num_seqs,
+    )
+    return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits)
+
+
+def decode_attention_wave(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap,
+):
+    num_seqs, num_query_heads, head_size = q.shape
+    _, num_kv_heads, _ = k_buffer.shape
+    _, _, head_size_kv = v_buffer.shape
+    block_size = 32
+    shape = paged_decode_attention_shape(
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        head_size_kv,
+        block_size,
+        num_seqs,
+    )
+
+    phase_0, phase_1 = get_wave_kernel(
+        shape, max_kv_splits, q.dtype, o.dtype, logit_cap
+    )
+
+    mb_qk = phase_0(
+        q,
+        k_buffer,
+        v_buffer,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+    )
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_qk.module_op.get_asm())
+
+    mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o)
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_sv.module_op.get_asm())
+
+
+def decode_attention_fwd(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap=0.0,
+):
+    decode_attention_wave(
+        q,
+        k_buffer,
+        v_buffer,
+        o,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+        num_kv_splits,
+        max_kv_splits,
+        sm_scale,
+        logit_cap,
+    )
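For orientation, here is a minimal sketch of how the new wave decode kernel might be driven. The sizes, dtypes, and the assumption that `decode_attention_intermediate_arrays_shapes` returns the shapes of the two split-KV scratch buffers are illustrative guesses, not taken from the diff; running it requires the `wave_lang` package and a supported AMD GPU.

```python
import torch

from sglang.srt.layers.attention.wave_ops.decode_attention import (
    decode_attention_fwd,
    decode_attention_intermediate_arrays_shapes,
)

# Illustrative sizes, not taken from the diff.
num_seqs, num_query_heads, num_kv_heads = 4, 32, 8
head_size = head_size_kv = 128
max_kv_splits = 8
pool_size = 4096   # slots in the paged KV pool (page size = 1)
max_ctx_len = 512  # columns of the request -> token-slot map

device, dtype = "cuda", torch.float16  # ROCm builds of torch also expose "cuda"
q = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype, device=device)
k_buffer = torch.randn(pool_size, num_kv_heads, head_size, dtype=dtype, device=device)
v_buffer = torch.randn(pool_size, num_kv_heads, head_size_kv, dtype=dtype, device=device)
o = torch.empty(num_seqs, num_query_heads, head_size_kv, dtype=dtype, device=device)

b_req_idx = torch.arange(num_seqs, dtype=torch.int32, device=device)
req_to_token = torch.zeros(num_seqs, max_ctx_len, dtype=torch.int32, device=device)

# Assumed contract: the helper returns the shapes of the two split-KV scratch
# buffers that phase 0 writes and phase 1 reduces.
logits_shape, logits_max_shape = decode_attention_intermediate_arrays_shapes(
    num_seqs, head_size_kv, num_query_heads, max_kv_splits
)
attn_logits = torch.empty(logits_shape, dtype=torch.float32, device=device)
attn_logits_max = torch.empty(logits_max_shape, dtype=torch.float32, device=device)

decode_attention_fwd(
    q, k_buffer, v_buffer, o,
    b_req_idx, req_to_token,
    attn_logits, attn_logits_max,
    num_kv_splits=max_kv_splits,
    max_kv_splits=max_kv_splits,
    sm_scale=head_size**-0.5,
)
```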
--- /dev/null
+++ sglang/srt/layers/attention/wave_ops/extend_attention.py
@@ -0,0 +1,149 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import functools
+import os
+
+import torch
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache
+def get_wave_kernel(
+    shape: AttentionShape,
+    q_shape: tuple[int],
+    k_shape: tuple[int],
+    v_shape: tuple[int],
+    k_cache_shape: tuple[int],
+    v_cache_shape: tuple[int],
+    o_shape: tuple[int],
+    input_dtype: torch.dtype,
+    output_dtype: torch.dtype,
+    size_dtype: torch.dtype,
+    is_causal: bool,
+    logit_cap: float,
+    layer_scaling: float,
+):
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16)
+    (
+        extend_attention,
+        hyperparams,
+        dynamic_symbols,
+    ) = get_extend_attention_kernel(
+        shape,
+        mfma_variant,
+        q_shape,
+        k_shape,
+        v_shape,
+        k_cache_shape,
+        v_cache_shape,
+        o_shape,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        size_dtype=size_dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        schedule=SchedulingType.NONE,
+        use_scheduling_barriers=False,
+        dynamic_symbols=dynamic_symbols,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        denorm_fp_math_f32="preserve-sign",
+        gpu_native_math_precision=True,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    extend_attention = wave_compile(options, extend_attention)
+
+    return extend_attention
+
+
+def extend_attention_wave(
+    q_extend,
+    k_extend,
+    v_extend,
+    k_buffer,
+    v_buffer,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    custom_mask,
+    mask_indptr,
+    max_seq_len,
+    output,
+    is_causal=True,
+    layer_scaling=None,
+    logit_cap=0,
+):
+    shape = AttentionShape(
+        num_query_heads=q_extend.shape[1],
+        num_kv_heads=k_extend.shape[1],
+        head_size=q_extend.shape[2],
+        head_size_kv=k_extend.shape[2],
+        num_seqs=kv_indptr.shape[0] - 1,
+        max_seq_len=max_seq_len,
+    )
+
+    # Run the wave kernel.
+    extend_attention = get_wave_kernel(
+        shape,
+        q_extend.shape,
+        k_extend.shape,
+        v_extend.shape,
+        k_buffer.shape,
+        v_buffer.shape,
+        output.shape,
+        input_dtype=q_extend.dtype,
+        output_dtype=output.dtype,
+        size_dtype=qo_indptr.dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    mb = extend_attention(
+        q_extend,
+        k_extend,
+        v_extend,
+        k_buffer,
+        v_buffer,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        max_seq_len,
+        output,
+    )
+
+    if dump_generated_mlir:
+        shape_list = [
+            q_extend.shape[0],
+            q_extend.shape[1],
+            k_extend.shape[1],
+            q_extend.shape[2],
+            k_extend.shape[2],
+        ]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
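A hedged usage sketch for the new extend (chunked-prefill) entry point follows. The CSR-style `qo_indptr`/`kv_indptr`/`kv_indices` layout mirrors other sglang extend backends but is an assumption here, as are all the sizes and dtypes.

```python
import torch

from sglang.srt.layers.attention.wave_ops.extend_attention import extend_attention_wave

device, dtype = "cuda", torch.float16
num_q_heads, num_kv_heads, head_dim = 32, 8, 128

# Two requests extending by 5 and 3 tokens on top of 16 and 8 cached tokens.
extend_lens = [5, 3]
prefix_lens = [16, 8]
total_extend = sum(extend_lens)
total_prefix = sum(prefix_lens)

q_extend = torch.randn(total_extend, num_q_heads, head_dim, dtype=dtype, device=device)
k_extend = torch.randn(total_extend, num_kv_heads, head_dim, dtype=dtype, device=device)
v_extend = torch.randn(total_extend, num_kv_heads, head_dim, dtype=dtype, device=device)
output = torch.empty(total_extend, num_q_heads, head_dim, dtype=dtype, device=device)

pool_size = 4096  # paged KV pool, page size = 1 per the module docstring
k_buffer = torch.randn(pool_size, num_kv_heads, head_dim, dtype=dtype, device=device)
v_buffer = torch.randn(pool_size, num_kv_heads, head_dim, dtype=dtype, device=device)

# CSR-style offsets: qo_indptr over the new tokens, kv_indptr/kv_indices over the
# cached prefix slots (assumed layout).
qo_indptr = torch.tensor([0, 5, 8], dtype=torch.int32, device=device)
kv_indptr = torch.tensor([0, 16, 24], dtype=torch.int32, device=device)
kv_indices = torch.arange(total_prefix, dtype=torch.int32, device=device)

extend_attention_wave(
    q_extend, k_extend, v_extend,
    k_buffer, v_buffer,
    qo_indptr, kv_indptr, kv_indices,
    custom_mask=None,   # accepted but not forwarded to the compiled kernel
    mask_indptr=None,
    max_seq_len=max(p + e for p, e in zip(prefix_lens, extend_lens)),
    output=output,
)
```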
--- /dev/null
+++ sglang/srt/layers/attention/wave_ops/prefill_attention.py
@@ -0,0 +1,79 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import math
+import os
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.prefill_attention import (
+    get_prefill_attention_kernel,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+def prefill_attention_wave(
+    q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True
+):
+
+    shape = AttentionShape(
+        num_query_heads=q.shape[1],
+        num_kv_heads=k.shape[1],
+        head_size=q.shape[2],
+        head_size_kv=k.shape[2],
+        num_seqs=b_seq_len.shape[0],
+        max_seq_len=max_seq_len,
+        total_seq_len=q.shape[0],
+    )
+
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv)
+    # Run the wave kernel.
+    mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+    (prefill, hyperparams) = get_prefill_attention_kernel(
+        shape,
+        mfma_variant,
+        q.shape,
+        k.shape,
+        v.shape,
+        output_shape,
+        input_dtype=q.dtype,
+        output_dtype=o.dtype,
+        size_dtype=b_seq_len.dtype,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+
+    log2e = 1.44269504089
+    dk_sqrt = math.sqrt(1.0 / shape.head_size)
+
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        use_scheduling_barriers=False,
+    )
+    options = set_default_run_config(options)
+    prefill = wave_compile(options, prefill)
+
+    mb = prefill(
+        q * dk_sqrt * log2e,
+        k,
+        v,
+        b_start_loc,
+        b_seq_len,
+        o,
+    )
+    if dump_generated_mlir:
+        shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
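And a similar sketch for the ragged prefill entry point, again with made-up sizes. Note that the wrapper folds the softmax scale (1/sqrt(head_size)) and log2(e) into `q` before launching the compiled kernel, so the caller passes no scale argument.

```python
import torch

from sglang.srt.layers.attention.wave_ops.prefill_attention import prefill_attention_wave

device, dtype = "cuda", torch.float16
num_q_heads, num_kv_heads, head_dim = 32, 8, 128

seq_lens = torch.tensor([7, 9], dtype=torch.int32, device=device)
start_locs = torch.tensor([0, 7], dtype=torch.int32, device=device)  # prefix sums of seq_lens
total_tokens = int(seq_lens.sum())

q = torch.randn(total_tokens, num_q_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device)
v = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=dtype, device=device)
o = torch.empty(total_tokens, num_q_heads, head_dim, dtype=dtype, device=device)

prefill_attention_wave(q, k, v, o, start_locs, seq_lens, max_seq_len=int(seq_lens.max()))
```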
--- sglang/srt/layers/communicator.py
+++ sglang/srt/layers/communicator.py
@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_global_dp_buffer,
+    get_local_dp_buffer,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
         context: CommunicateContext,
     ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
-
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-
-                    : forward_batch.input_ids.shape[0]
-                ].clone(),
+                get_local_dp_buffer(),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -420,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
 
         # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
         use_layer_norm_before_gather = context.attn_tp_size == 1
-        if use_layer_norm_before_gather:
-            residual
-
-            hidden_states = layernorm(hidden_states)
-
+        if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
+            residual = hidden_states
+            hidden_states = layernorm(hidden_states)
         hidden_states, local_hidden_states = (
-
+            get_global_dp_buffer(),
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -443,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             and _is_flashinfer_available
            and hasattr(layernorm, "forward_with_allreduce_fusion")
            and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-           and hidden_states.shape[0] <=
+           and hidden_states.shape[0] <= 2048
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
@@ -550,7 +548,7 @@ class CommunicateSummableTensorPairFn:
         allow_reduce_scatter: bool = False,
     ):
         hidden_states, global_hidden_states = (
-
+            get_local_dp_buffer(),
             hidden_states,
         )
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -571,7 +569,7 @@ class CommunicateSummableTensorPairFn:
             hidden_states += residual
             residual = None
             hidden_states, local_hidden_states = (
-
+                get_local_dp_buffer(),
                 hidden_states,
             )
             attn_tp_all_gather_into_tensor(