sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/wave_ops/decode_attention.py
ADDED
@@ -0,0 +1,186 @@
+"""
+Memory-efficient attention for decoding.
+It supports page size = 1.
+"""
+
+import functools
+import logging
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import GenericDot, MMAOperand, MMAType
+from wave_lang.kernel.wave.templates.paged_decode_attention import (
+    get_paged_decode_attention_kernels,
+    get_paged_decode_intermediate_arrays_shapes,
+    paged_decode_attention_shape,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+logger = logging.getLogger(__name__)
+import os
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache(maxsize=4096)
+def get_wave_kernel(
+    shape: paged_decode_attention_shape,
+    max_kv_splits,
+    input_dtype,
+    output_dtype,
+    logit_cap,
+):
+    mha = (shape.num_query_heads // shape.num_kv_heads) == 1
+
+    # Get the kernels (either compile or load from cache).
+    if mha:
+        mfma_variant = (
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=4, k_mult=1),
+            GenericDot(along_dim=MMAOperand.M, k_vec_size=1, k_mult=64),
+        )
+    else:
+        mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+
+    (
+        phase_0,
+        phase_1,
+        hyperparams_0,
+        hyperparams_1,
+        dynamic_symbols_0,
+        dynamic_symbols_1,
+    ) = get_paged_decode_attention_kernels(
+        shape,
+        mfma_variant,
+        max_kv_splits,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        logit_cap=logit_cap,
+    )
+    hyperparams_0.update(get_default_scheduling_params())
+    hyperparams_1.update(get_default_scheduling_params())
+
+    options = WaveCompileOptions(
+        subs=hyperparams_0,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        dynamic_symbols=dynamic_symbols_0,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_0 = wave_compile(options, phase_0)
+
+    options = WaveCompileOptions(
+        subs=hyperparams_1,
+        canonicalize=True,
+        run_bench=False,
+        use_buffer_load_ops=False,
+        use_buffer_store_ops=False,
+        waves_per_eu=4,
+        dynamic_symbols=dynamic_symbols_1,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    phase_1 = wave_compile(options, phase_1)
+
+    return phase_0, phase_1
+
+
+def decode_attention_intermediate_arrays_shapes(
+    num_seqs, head_size_kv, num_query_heads, max_kv_splits
+):
+    # Not all fields are used, but we need to pass them to the function
+    shape = paged_decode_attention_shape(
+        num_query_heads=num_query_heads,
+        num_kv_heads=0,
+        head_size=0,
+        head_size_kv=head_size_kv,
+        block_size=0,
+        num_seqs=num_seqs,
+    )
+    return get_paged_decode_intermediate_arrays_shapes(shape, max_kv_splits)
+
+
+def decode_attention_wave(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap,
+):
+    num_seqs, num_query_heads, head_size = q.shape
+    _, num_kv_heads, _ = k_buffer.shape
+    _, _, head_size_kv = v_buffer.shape
+    block_size = 32
+    shape = paged_decode_attention_shape(
+        num_query_heads,
+        num_kv_heads,
+        head_size,
+        head_size_kv,
+        block_size,
+        num_seqs,
+    )
+
+    phase_0, phase_1 = get_wave_kernel(
+        shape, max_kv_splits, q.dtype, o.dtype, logit_cap
+    )
+
+    mb_qk = phase_0(
+        q,
+        k_buffer,
+        v_buffer,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+    )
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase0_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_qk.module_op.get_asm())
+
+    mb_sv = phase_1(attn_logits, attn_logits_max, b_req_idx, o)
+    if dump_generated_mlir:
+        filename = f"wave_decode_attention_phase1_{'x'.join(map(str, shape))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb_sv.module_op.get_asm())
+
+
+def decode_attention_fwd(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    b_req_idx,
+    req_to_token,
+    attn_logits,
+    attn_logits_max,
+    num_kv_splits,
+    max_kv_splits,
+    sm_scale,
+    logit_cap=0.0,
+):
+    decode_attention_wave(
+        q,
+        k_buffer,
+        v_buffer,
+        o,
+        b_req_idx,
+        req_to_token,
+        attn_logits,
+        attn_logits_max,
+        num_kv_splits,
+        max_kv_splits,
+        sm_scale,
+        logit_cap,
+    )
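For orientation, a minimal sketch of how this new wave decode entry point might be driven. Only the two imported function names come from the file above; the tensor sizes, dtypes, page-table layout, and the assumption that the shapes helper returns the two intermediate-buffer shapes are illustrative guesses, and running it needs a ROCm GPU with wave_lang installed:

import torch

from sglang.srt.layers.attention.wave_ops.decode_attention import (
    decode_attention_fwd,
    decode_attention_intermediate_arrays_shapes,
)

# Hypothetical sizes for illustration only; head_size_kv is assumed equal to head_size.
num_seqs, num_q_heads, num_kv_heads, head_size = 4, 32, 8, 128
max_kv_splits, max_ctx_len, kv_pool_size = 8, 1024, 4096

q = torch.randn(num_seqs, num_q_heads, head_size, dtype=torch.float16, device="cuda")
k_buffer = torch.randn(kv_pool_size, num_kv_heads, head_size, dtype=torch.float16, device="cuda")
v_buffer = torch.randn(kv_pool_size, num_kv_heads, head_size, dtype=torch.float16, device="cuda")
o = torch.empty(num_seqs, num_q_heads, head_size, dtype=torch.float16, device="cuda")
b_req_idx = torch.arange(num_seqs, dtype=torch.int32, device="cuda")
req_to_token = torch.randint(
    0, kv_pool_size, (num_seqs, max_ctx_len), dtype=torch.int32, device="cuda"
)

# Assumption: the helper returns the shapes of the split-KV partial-output buffer
# and its running-max buffer, in that order.
logits_shape, logits_max_shape = decode_attention_intermediate_arrays_shapes(
    num_seqs, head_size, num_q_heads, max_kv_splits
)
attn_logits = torch.empty(logits_shape, dtype=torch.float32, device="cuda")
attn_logits_max = torch.empty(logits_max_shape, dtype=torch.float32, device="cuda")

decode_attention_fwd(
    q, k_buffer, v_buffer, o,
    b_req_idx, req_to_token,
    attn_logits, attn_logits_max,
    # num_kv_splits is accepted but not used by the wave path shown above.
    num_kv_splits=max_kv_splits,
    max_kv_splits=max_kv_splits,
    sm_scale=head_size**-0.5,
)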
sglang/srt/layers/attention/wave_ops/extend_attention.py
ADDED
@@ -0,0 +1,149 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import functools
+import os
+
+import torch
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.scheduling.schedule import SchedulingType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.extend_attention import get_extend_attention_kernel
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+@functools.lru_cache
+def get_wave_kernel(
+    shape: AttentionShape,
+    q_shape: tuple[int],
+    k_shape: tuple[int],
+    v_shape: tuple[int],
+    k_cache_shape: tuple[int],
+    v_cache_shape: tuple[int],
+    o_shape: tuple[int],
+    input_dtype: torch.dtype,
+    output_dtype: torch.dtype,
+    size_dtype: torch.dtype,
+    is_causal: bool,
+    logit_cap: float,
+    layer_scaling: float,
+):
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    mfma_variant = (MMAType.F32_16x16x32_K8_F16, MMAType.F32_16x16x16_F16)
+    (
+        extend_attention,
+        hyperparams,
+        dynamic_symbols,
+    ) = get_extend_attention_kernel(
+        shape,
+        mfma_variant,
+        q_shape,
+        k_shape,
+        v_shape,
+        k_cache_shape,
+        v_cache_shape,
+        o_shape,
+        input_dtype=input_dtype,
+        output_dtype=output_dtype,
+        size_dtype=size_dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        schedule=SchedulingType.NONE,
+        use_scheduling_barriers=False,
+        dynamic_symbols=dynamic_symbols,
+        use_buffer_load_ops=True,
+        use_buffer_store_ops=True,
+        waves_per_eu=2,
+        denorm_fp_math_f32="preserve-sign",
+        gpu_native_math_precision=True,
+        wave_runtime=True,
+    )
+    options = set_default_run_config(options)
+    extend_attention = wave_compile(options, extend_attention)
+
+    return extend_attention
+
+
+def extend_attention_wave(
+    q_extend,
+    k_extend,
+    v_extend,
+    k_buffer,
+    v_buffer,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    custom_mask,
+    mask_indptr,
+    max_seq_len,
+    output,
+    is_causal=True,
+    layer_scaling=None,
+    logit_cap=0,
+):
+    shape = AttentionShape(
+        num_query_heads=q_extend.shape[1],
+        num_kv_heads=k_extend.shape[1],
+        head_size=q_extend.shape[2],
+        head_size_kv=k_extend.shape[2],
+        num_seqs=kv_indptr.shape[0] - 1,
+        max_seq_len=max_seq_len,
+    )
+
+    # Run the wave kernel.
+    extend_attention = get_wave_kernel(
+        shape,
+        q_extend.shape,
+        k_extend.shape,
+        v_extend.shape,
+        k_buffer.shape,
+        v_buffer.shape,
+        output.shape,
+        input_dtype=q_extend.dtype,
+        output_dtype=output.dtype,
+        size_dtype=qo_indptr.dtype,
+        is_causal=is_causal,
+        layer_scaling=layer_scaling,
+        logit_cap=logit_cap,
+    )
+
+    mb = extend_attention(
+        q_extend,
+        k_extend,
+        v_extend,
+        k_buffer,
+        v_buffer,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        max_seq_len,
+        output,
+    )
+
+    if dump_generated_mlir:
+        shape_list = [
+            q_extend.shape[0],
+            q_extend.shape[1],
+            k_extend.shape[1],
+            q_extend.shape[2],
+            k_extend.shape[2],
+        ]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
sglang/srt/layers/attention/wave_ops/prefill_attention.py
ADDED
@@ -0,0 +1,79 @@
+"""
+Memory-efficient attention for prefill.
+It support page size = 1.
+"""
+
+import math
+import os
+
+from wave_lang.kernel.lang.global_symbols import *
+from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
+from wave_lang.kernel.wave.constraints import MMAType
+from wave_lang.kernel.wave.templates.attention_common import AttentionShape
+from wave_lang.kernel.wave.templates.prefill_attention import (
+    get_prefill_attention_kernel,
+)
+from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params
+from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
+
+dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0))
+
+
+def prefill_attention_wave(
+    q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=True
+):
+
+    shape = AttentionShape(
+        num_query_heads=q.shape[1],
+        num_kv_heads=k.shape[1],
+        head_size=q.shape[2],
+        head_size_kv=k.shape[2],
+        num_seqs=b_seq_len.shape[0],
+        max_seq_len=max_seq_len,
+        total_seq_len=q.shape[0],
+    )
+
+    assert shape.num_query_heads % shape.num_kv_heads == 0
+
+    output_shape = (shape.total_seq_len, shape.num_query_heads, shape.head_size_kv)
+    # Run the wave kernel.
+    mfma_variant = (MMAType.F32_16x16x16_F16, MMAType.F32_16x16x16_F16)
+    (prefill, hyperparams) = get_prefill_attention_kernel(
+        shape,
+        mfma_variant,
+        q.shape,
+        k.shape,
+        v.shape,
+        output_shape,
+        input_dtype=q.dtype,
+        output_dtype=o.dtype,
+        size_dtype=b_seq_len.dtype,
+    )
+
+    hyperparams.update(get_default_scheduling_params())
+
+    log2e = 1.44269504089
+    dk_sqrt = math.sqrt(1.0 / shape.head_size)
+
+    options = WaveCompileOptions(
+        subs=hyperparams,
+        canonicalize=True,
+        run_bench=False,
+        use_scheduling_barriers=False,
+    )
+    options = set_default_run_config(options)
+    prefill = wave_compile(options, prefill)
+
+    mb = prefill(
+        q * dk_sqrt * log2e,
+        k,
+        v,
+        b_start_loc,
+        b_seq_len,
+        o,
+    )
+    if dump_generated_mlir:
+        shape_list = [q.shape[0], q.shape[1], k.shape[1], q.shape[2], k.shape[2]]
+        filename = f"wave_prefill_attention_{'x'.join(map(str, shape_list))}.mlir"
+        with open(filename, "w") as f:
+            f.write(mb.module_op.get_asm())
sglang/srt/layers/communicator.py
CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather_into_tensor,
     attn_tp_reduce_scatter_tensor,
     dp_gather_partial,
+    dp_reduce_scatter_tensor,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -149,10 +150,13 @@ class LayerCommunicator:
         layer_scatter_modes: LayerScatterModes,
         input_layernorm: torch.nn.Module,
         post_attention_layernorm: torch.nn.Module,
+        # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
+        allow_reduce_scatter: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
+        self.allow_reduce_scatter = allow_reduce_scatter

         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -239,6 +243,15 @@ class LayerCommunicator:
             residual=residual,
             forward_batch=forward_batch,
             context=self._context,
+            allow_reduce_scatter=self.allow_reduce_scatter,
+        )
+
+    def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
+        return (
+            self.allow_reduce_scatter
+            and self._communicate_summable_tensor_pair_fn
+            is CommunicateSummableTensorPairFn._scatter_hidden_states
+            and forward_batch.dp_padding_mode.is_max_len()
         )


@@ -395,9 +408,9 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-                forward_batch.gathered_buffer[
-                    : forward_batch.input_ids.shape[0]
-                ],
+                torch.empty_like(
+                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
+                ),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -407,13 +420,11 @@

         # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
         use_layer_norm_before_gather = context.attn_tp_size == 1
-        if use_layer_norm_before_gather:
-            residual
-
-            hidden_states = layernorm(hidden_states)
-
+        if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
+            residual = hidden_states
+            hidden_states = layernorm(hidden_states)
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer,
+            torch.empty_like(forward_batch.gathered_buffer),
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -430,7 +441,7 @@
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <=
+            and hidden_states.shape[0] <= 2048
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
@@ -524,6 +535,7 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
     ):
         return hidden_states, residual

@@ -533,15 +545,17 @@
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        allow_reduce_scatter: bool = False,
     ):
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # important: forward batch.gathered_buffer is used both after scatter and after gather.
-        # be careful about this!
         hidden_states, global_hidden_states = (
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-        dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
+            # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
+            dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
+        else:
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
         return hidden_states, residual

     @staticmethod
@@ -550,6 +564,7 @@ class CommunicateSummableTensorPairFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         context: CommunicateContext,
+        **kwargs,
     ):
         hidden_states += residual
         residual = None
sglang/srt/layers/dp_attention.py
CHANGED
@@ -12,6 +12,7 @@ import triton.language as tl

 from sglang.srt.distributed import (
     GroupCoordinator,
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
@@ -355,6 +356,17 @@ def dp_scatter(
     )


+def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
+    if get_tensor_model_parallel_world_size() == get_attention_dp_size():
+        get_tp_group().reduce_scatter_tensor(output, input)
+    else:
+        scattered_local_tokens = input.tensor_split(
+            get_tensor_model_parallel_world_size()
+        )[get_tensor_model_parallel_rank()]
+        get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
+        get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)
+
+
 def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
     return get_attention_tp_group().reduce_scatter_tensor(output, input)

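As a sanity check on the semantics of the new dp_reduce_scatter_tensor (the else branch above, taken when the TP world size differs from the attention-DP size), here is a minimal single-process sketch of the math it performs. The group sizes, token counts, and rank layout are assumptions for illustration; no sglang or torch.distributed calls are involved:

import torch

tp_size, attn_tp_size = 4, 2   # hypothetical: attention dp_size = tp_size // attn_tp_size = 2
tokens, hidden = 8, 16         # max-len padded token count across DP ranks

# Each TP rank holds a partial sum over the full padded token range.
per_rank_partials = [torch.randn(tokens, hidden) for _ in range(tp_size)]

reduced = torch.stack(per_rank_partials).sum(0)  # what the reduce part computes
shards = reduced.tensor_split(tp_size)           # reduce-scatter: one shard per TP rank

# The all-gather inside one attention-TP subgroup (ranks 0 and 1 here) re-assembles
# the contiguous slice owned by attention-DP rank 0, i.e. the same result a full
# all-reduce followed by dp_scatter would give under max-length padding.
dp_rank0_out = torch.cat(shards[:attn_tp_size])
assert torch.allclose(dp_rank0_out, reduced[: tokens // 2])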
sglang/srt/layers/flashinfer_comm_fusion.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int =
+    max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int =
-    use_oneshot: bool =
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
sglang/srt/layers/linear.py
CHANGED
@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
                 else self.weight_loader
             ),
         )
-        if not reduce_results and (bias and not skip_bias_add):
-            raise ValueError(
-                "When not reduce the results, adding bias to the "
-                "results can lead to incorrect results"
-            )

         if bias:
             self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
@@ -1282,7 +1277,7 @@
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_,
+    def forward(self, input_, skip_all_reduce=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1299,7 +1294,8 @@
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
             sm.tag(output_parallel)
-
+
+        if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output = output_parallel
sglang/srt/layers/moe/cutlass_moe.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()
@@ -123,6 +124,8 @@ def cutlass_fused_experts_fp8(

     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_group_transpose,
+            per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )

@@ -133,9 +136,7 @@
     n = w2_q.size(1)

     topk = topk_ids.size(1)
-
-    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-    device = a_q.device
+    device = a.device

     a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
@@ -152,9 +153,14 @@
         k,
     )

+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))

+    if not is_sm100_supported():
+        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
+        w1_scale = w1_scale.contiguous()
+
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

@@ -186,6 +192,9 @@
     silu_and_mul(c1, intermediate)

     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if not is_sm100_supported():
+        a2_scale = per_group_transpose(a2_scale, expert_offsets)
+        w2_scale = w2_scale.contiguous()

     fp8_blockwise_scaled_grouped_mm(
         c2,
sglang/srt/layers/moe/cutlass_w4a8_moe.py
CHANGED
@@ -11,7 +11,7 @@ from sgl_kernel import (
 )

 from sglang.srt.layers.moe.ep_moe.kernels import (
-
+    post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_cutlass_moe_ep_preproess,
 )
@@ -199,14 +199,13 @@ def cutlass_w4a8_moe(
     )

     output = torch.empty_like(a)
-
+    post_reorder_triton_kernel_for_cutlass_moe[(m,)](
         c2,
         output,
         src2dst,
-
+        local_topk_ids,
         topk_weights,
-
-        end_expert_id,
+        num_experts,
         topk,
         k,
         0,