sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registries. It is provided for informational purposes only.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -74,6 +74,8 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
+    STORE_TRANSPOSE: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -159,7 +161,7 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

-        if USE_CUSTOM_MASK:
+        if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
@@ -272,9 +274,18 @@ def _fwd_kernel(
         + cur_head * stride_oh
         + offs_dv[None, :]
     )
-    tl.store(
-        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
-    )
+    if STORE_TRANSPOSE:
+        tl.store(
+            O_Extend + offs_o.T,
+            (acc / deno[:, None]).T,
+            mask=(mask_m[:, None] & mask_dv[None, :]).T,
+        )
+    else:
+        tl.store(
+            O_Extend + offs_o,
+            acc / deno[:, None],
+            mask=mask_m[:, None] & mask_dv[None, :],
+        )


 def extend_attention_fwd(
@@ -292,6 +303,7 @@ def extend_attention_fwd(
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
+    skip_prefix_custom_mask=True,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -345,6 +357,8 @@ def extend_attention_fwd(
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]

     USE_CUSTOM_MASK = custom_mask is not None
+    # Skip custom mask for prefix part
+    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1
@@ -388,6 +402,8 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
+        STORE_TRANSPOSE=is_hip_,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
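The two new flags serve different purposes: SKIP_PREFIX_CUSTOM_MASK lets the launcher skip loading the custom mask over the prefix portion of the sequence (the hunk at line 159 sits in the prefix loop), while STORE_TRANSPOSE=is_hip_ writes the output tile transposed on ROCm. A minimal sketch of driving the new keyword from the Python side follows; the wrapper name is hypothetical, and only extend_attention_fwd and skip_prefix_custom_mask come from the diff.

from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd

# Hypothetical convenience wrapper: forward all tensor arguments unchanged,
# but ask the kernel to evaluate the custom mask over the prefix part too
# (the new default, skip_prefix_custom_mask=True, skips that work).
def extend_attention_fwd_full_mask(*args, **kwargs):
    kwargs["skip_prefix_custom_mask"] = False
    return extend_attention_fwd(*args, **kwargs)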
sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py (new file)

@@ -0,0 +1,439 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Memory-efficient attention for decoding.
+It supports page size = 1.
+"""
+
+# Adapted from
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.attention.triton_ops.decode_attention import (
+    _decode_softmax_reducev_fwd,
+)
+
+
+def is_hip():
+    return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+is_hip_ = is_hip()
+
+
+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
+@triton.jit
+def _fwd_grouped_kernel_stage1_rope(
+    Q,  # Holds [Q_NOPE; Q_PE], b x h x (d+r)
+    K_Buffer,  # Holds [KV; K_PE], b*s x (c+r)
+    V_buffer,  # Holds [KV], b*s x (c)
+    cos_sin_cache,  # max_seq_len x (rotary_dim * 2)
+    positions,  # sequence positions
+    sm_scale,
+    kv_indptr,
+    kv_indices,
+    Att_Out,  # b x h x NUM_KV_SPLITS x (kv_lora_rank + 1)
+    k_pe_t_out,
+    stride_qb,
+    stride_qh,
+    stride_buf_kbs,
+    stride_buf_vbs,
+    stride_mid_ob,
+    stride_mid_oh,
+    stride_mid_os,
+    stride_kpe_tokens_out_b,
+    stride_cos_sin_cache_s,
+    stride_positions_b,
+    rotary_dim: tl.constexpr,
+    kv_lora_rank: tl.constexpr,
+    qk_rope_head_dim: tl.constexpr,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_C: tl.constexpr,
+    BLOCK_R: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    NUM_KV_SPLITS: tl.constexpr,
+    logit_cap: tl.constexpr,
+    USE_ROPE: tl.constexpr,
+    IS_NEOX_STYLE: tl.constexpr,
+):
+
+    cur_batch = tl.program_id(0)
+    cur_head_id = tl.program_id(1)
+    split_kv_id = tl.program_id(2)
+
+    if BLOCK_H < kv_group_num:
+        VALID_BLOCK_H: tl.constexpr = BLOCK_H
+    else:
+        VALID_BLOCK_H: tl.constexpr = kv_group_num
+    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    offs_c = tl.arange(0, BLOCK_C)
+    offs_qk_r = tl.arange(kv_lora_rank, kv_lora_rank + BLOCK_R)  # to get the k_pe
+
+    off_q_pe = (
+        cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_qk_r[None, :]
+    )
+    offs_q = cur_batch * stride_qb + cur_head[:, None] * stride_qh + offs_c[None, :]
+
+    mask_c = offs_c < kv_lora_rank
+    mask_qk_r = offs_qk_r < (kv_lora_rank + qk_rope_head_dim)
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+
+    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_c[None, :]), other=0.0)
+    q_pe = tl.load(
+        Q + off_q_pe, mask=(mask_h[:, None]) & (mask_qk_r[None, :]), other=0.0
+    )
+
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    split_kv_start = kv_len_per_split * split_kv_id
+    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
+
+    # apply rotary embedding for q_pe, and k_pe (last token per batch of K_PE)
+    LAST_SPLIT = split_kv_end == cur_batch_seq_len
+    k_pe_last_token = tl.zeros([BLOCK_R], dtype=q.dtype)
+
+    if USE_ROPE:
+        if IS_NEOX_STYLE:
+            # [BLOCK_ROTARY // 2, BLOCK_ROTARY // 2 + 1, BLOCK_ROTARY // 2 + 2, ..., 0, 1, 2, ..., BLOCK_ROTARY // 2 - 1, pass:]
+            offs_qk_rot_r = kv_lora_rank + (
+                (tl.arange(0, BLOCK_R) + (rotary_dim // 2)) % rotary_dim
+            )
+            # Which elements to flip
+            mask_rotate = tl.arange(0, BLOCK_R) < (rotary_dim // 2)
+            # [0 , 1, 2, ..., rotary_dim // 2 - 1, 0 , 1, 2, ..., rotary_dim // 2 - 1]
+            offs_rotary = tl.arange(0, BLOCK_R) % (rotary_dim // 2)
+        else:
+            # [1, 0, 3, 2, 5, 4, ..., BLOCK_R, BLOCK_R - 1]
+            offs_qk_rot_r = (
+                kv_lora_rank
+                + (((tl.arange(0, BLOCK_R) + 1) % 2) * 2)
+                - 1
+                + tl.arange(0, BLOCK_R)
+            )
+            mask_rotate = tl.arange(0, BLOCK_R) % 2 < 1
+            # [0, 0, 1, 1, ..., rotary_dim // 2 - 1, rotary_dim // 2 - 1]
+            offs_rotary = tl.arange(0, BLOCK_R) // 2
+
+        if qk_rope_head_dim > rotary_dim:
+            offs_qk_rot_r = tl.where(
+                tl.arange(0, BLOCK_R) < rotary_dim, offs_qk_rot_r, tl.arange(0, BLOCK_R)
+            )
+            offs_rotary = tl.where(
+                tl.arange(0, BLOCK_R) < rotary_dim, offs_rotary, tl.arange(0, BLOCK_R)
+            )
+
+        mask_rotary = tl.arange(0, BLOCK_R) < rotary_dim
+
+        pos = tl.load(positions + cur_batch * stride_positions_b)
+        cos = tl.load(
+            cos_sin_cache + pos * stride_cos_sin_cache_s + offs_rotary,
+            mask=mask_rotary,
+            other=1.0,
+        )
+        sin = tl.load(
+            cos_sin_cache
+            + pos * stride_cos_sin_cache_s
+            + offs_rotary
+            + rotary_dim // 2,
+            mask_rotary,
+            other=0.0,
+        )
+
+        off_q_pe_rot = (
+            cur_batch * stride_qb
+            + cur_head[:, None] * stride_qh
+            + offs_qk_rot_r[None, :]
+        )
+        mask_qk_rot_r = offs_qk_rot_r < (kv_lora_rank + qk_rope_head_dim)
+
+        # 0, 2, 4,.... 1, 3, 5...
+        q_pe_rot = tl.load(
+            Q + off_q_pe_rot,
+            mask=(mask_h[:, None]) & (mask_qk_rot_r[None, :]),
+            other=0.0,
+        )
+        q_pe_rot = tl.where(mask_rotate[None, :], -q_pe_rot, q_pe_rot)
+
+        q_pe = q_pe * cos + q_pe_rot * sin
+
+        # we only apply to the last token in the K_PE
+        if LAST_SPLIT:
+            # debug assert
+            if (cur_batch == 0 and cur_head == 0) and split_kv_id < NUM_KV_SPLITS - 1:
+                tl.device_assert(False, "Only last split should compute k_pe")
+
+            kv_loc = tl.load(
+                kv_indices + cur_batch_kv_start_idx + cur_batch_seq_len - 1
+            )
+            offs_buf_k_pe_last_token = kv_loc * stride_buf_kbs + offs_qk_r
+            offs_buf_k_pe_rot_last_token = kv_loc * stride_buf_kbs + offs_qk_rot_r
+            k_pe_last_token = tl.load(K_Buffer + offs_buf_k_pe_last_token)
+
+            k_pe_rot_last_token = tl.load(K_Buffer + offs_buf_k_pe_rot_last_token)
+            k_pe_rot_last_token = tl.where(
+                mask_rotate, -k_pe_rot_last_token, k_pe_rot_last_token
+            )
+
+            k_pe_last_token = k_pe_last_token * cos + k_pe_rot_last_token * sin
+
+    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
+    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_H, BLOCK_C], dtype=tl.float32)
+
+    if split_kv_end > split_kv_start:
+        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
+            offs_n = start_n + tl.arange(0, BLOCK_N)
+            kv_loc = tl.load(
+                kv_indices + cur_batch_kv_start_idx + offs_n,
+                mask=offs_n < split_kv_end,
+                other=0,
+            )
+
+            offs_buf_kv = kv_loc[None, :] * stride_buf_kbs + offs_c[:, None]
+            offs_buf_k_pe = kv_loc[None, :] * stride_buf_kbs + offs_qk_r[:, None]
+
+            k_pe = tl.load(
+                K_Buffer + offs_buf_k_pe,
+                mask=(offs_n[None, :] < split_kv_end) & (mask_qk_r[:, None]),
+                other=0.0,
+            )  # positional embedding part of keys
+
+            if (USE_ROPE and LAST_SPLIT) and start_n >= cur_batch_seq_len - BLOCK_N:
+                k_pe = tl.where(
+                    offs_n[None, :] != (split_kv_end - 1),
+                    k_pe,
+                    k_pe_last_token[:, None],
+                )
+
+            # (16, 64) x (64, 32)
+            # dot product of rope parts
+            qk = tl.dot(q_pe, k_pe.to(q_pe.dtype))
+
+            kv = tl.load(
+                K_Buffer + offs_buf_kv,
+                mask=(offs_n[None, :] < split_kv_end) & (mask_c[:, None]),
+                other=0.0,
+            )  # the shared latent tensor for keys and values
+
+            # (16, 512) x (512, 32)
+            # dot product of nope parts
+            qk += tl.dot(q, kv)
+
+            qk *= sm_scale
+
+            if logit_cap > 0:
+                qk = logit_cap * tanh(qk / logit_cap)
+
+            qk = tl.where(
+                mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
+            )
+
+            offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_c[None, :]
+            v = tl.load(
+                V_buffer + offs_buf_v,
+                mask=(offs_n[:, None] < split_kv_end) & (mask_c[None, :]),
+                other=0.0,
+            )
+
+            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+            re_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max[:, None])
+            acc *= re_scale[:, None]
+            # (16, 32) x (32, 512)
+            acc += tl.dot(p.to(v.dtype), v)
+
+            e_sum = e_sum * re_scale + tl.sum(p, 1)
+            e_max = n_e_max
+
+        offs_mid_o = (
+            cur_batch * stride_mid_ob
+            + cur_head[:, None] * stride_mid_oh
+            + split_kv_id * stride_mid_os
+            + offs_c[None, :]
+        )
+
+        if USE_ROPE:
+            if LAST_SPLIT:
+                k_pe_last_token_ptrs = (
+                    k_pe_t_out
+                    + cur_batch * stride_kpe_tokens_out_b
+                    + tl.arange(0, BLOCK_R)
+                )
+                tl.store(k_pe_last_token_ptrs, k_pe_last_token, mask=mask_qk_r)
+
+        tl.store(
+            Att_Out + offs_mid_o,
+            acc / e_sum[:, None],
+            mask=(mask_h[:, None]) & (mask_c[None, :]),
+        )
+
+        offs_mid_o_1 = (
+            cur_batch * stride_mid_ob
+            + cur_head * stride_mid_oh
+            + split_kv_id * stride_mid_os
+            + kv_lora_rank
+        )
+
+        tl.store(
+            Att_Out + offs_mid_o_1,
+            e_max + tl.log(e_sum),
+            mask=mask_h,
+        )
+
+
+# TODO rope offset
+def _decode_grouped_att_m_fwd_rope(
+    q,
+    k_buffer,
+    v_buffer,
+    att_out,
+    k_pe_tokens_out,
+    kv_lora_rank,  # c
+    cos_sin_cache,
+    positions,
+    rotary_dim,
+    kv_indptr,
+    kv_indices,
+    num_kv_splits,
+    sm_scale,
+    logit_cap,
+    use_rope,
+    is_neox_style=True,
+):
+    if use_rope:
+        assert (
+            k_pe_tokens_out is not None
+        ), "We must output the k_pe tokens with rope applied if rope fusion enabled."
+
+    BLOCK = 32
+
+    # # [TODO] work around shmem limit on MI3xx
+    # if is_hip_ and kv_lora_rank >= 576:
+    #     BLOCK = 16
+
+    qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
+    kv_group_num = q.shape[1] // k_buffer.shape[1]
+
+    BLOCK_C = triton.next_power_of_2(kv_lora_rank)
+    BLOCK_R = triton.next_power_of_2(qk_rope_head_dim)
+
+    BLOCK_H = 16
+    NUM_KV_SPLITS = num_kv_splits
+    grid = (
+        batch,
+        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
+        NUM_KV_SPLITS,
+    )
+
+    extra_kargs = {}
+    num_stages = 2
+    if is_hip_:
+        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
+        num_stages = 1
+
+    _fwd_grouped_kernel_stage1_rope[grid](
+        q,
+        k_buffer,
+        v_buffer,
+        cos_sin_cache,
+        positions,
+        sm_scale,
+        kv_indptr,
+        kv_indices,
+        att_out,
+        k_pe_tokens_out,
+        q.stride(0),
+        q.stride(1),
+        k_buffer.stride(0),
+        v_buffer.stride(0),
+        att_out.stride(0),
+        att_out.stride(1),
+        att_out.stride(2),
+        k_pe_tokens_out.stride(0) if use_rope else 0,
+        cos_sin_cache.stride(0) if use_rope else 0,
+        positions.stride(0) if use_rope else 0,
+        rotary_dim,
+        kv_lora_rank,
+        qk_rope_head_dim,
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_C=BLOCK_C,
+        BLOCK_R=BLOCK_R,
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        logit_cap=logit_cap,
+        USE_ROPE=use_rope,
+        IS_NEOX_STYLE=is_neox_style,
+        num_warps=4,
+        num_stages=num_stages,
+        **extra_kargs
+    )
+
+
+def decode_attention_fwd_grouped_rope(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    kv_indptr,
+    kv_indices,
+    k_pe_tokens,
+    kv_lora_rank,
+    rotary_dim,
+    cos_sin_cache,
+    positions,
+    attn_logits,
+    num_kv_splits,
+    sm_scale,
+    logit_cap=0.0,
+    use_rope=False,
+    is_neox_style=False,
+):
+    _decode_grouped_att_m_fwd_rope(
+        q,
+        k_buffer,
+        v_buffer,
+        attn_logits,
+        k_pe_tokens,
+        kv_lora_rank,
+        cos_sin_cache,
+        positions,
+        rotary_dim,
+        kv_indptr,
+        kv_indices,
+        num_kv_splits,
+        sm_scale,
+        logit_cap,
+        use_rope,
+        is_neox_style,
+    )
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
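The new file adds a grouped MLA flash-decoding path for ROCm that can fuse rotary embedding into stage 1 and reuses _decode_softmax_reducev_fwd for the reduction. Below is a shape sketch for the entry point; the signature and the layout comments (b x h x (c+r) queries, (c+r)-wide latent KV buffer, b x h x NUM_KV_SPLITS x (kv_lora_rank + 1) logits) come from the file above, while the concrete sizes are illustrative assumptions.

import torch
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
    decode_attention_fwd_grouped_rope,
)

bs, heads = 2, 16                          # made-up sizes for illustration
kv_lora_rank, qk_rope_head_dim = 512, 64   # c and r in the kernel comments
rotary_dim, num_kv_splits, tokens = 64, 4, 128
dev, dt = "cuda", torch.bfloat16

q = torch.randn(bs, heads, kv_lora_rank + qk_rope_head_dim, device=dev, dtype=dt)
k_buffer = torch.randn(tokens, 1, kv_lora_rank + qk_rope_head_dim, device=dev, dtype=dt)
v_buffer = torch.randn(tokens, 1, kv_lora_rank, device=dev, dtype=dt)
o = torch.empty(bs, heads, kv_lora_rank, device=dev, dtype=dt)
kv_indptr = torch.tensor([0, 64, 128], device=dev, dtype=torch.int32)
kv_indices = torch.arange(tokens, device=dev, dtype=torch.int32)
k_pe_tokens = torch.empty(bs, qk_rope_head_dim, device=dev, dtype=dt)
cos_sin_cache = torch.randn(4096, rotary_dim * 2, device=dev, dtype=dt)
positions = torch.tensor([63, 127], device=dev, dtype=torch.int64)
attn_logits = torch.empty(
    bs, heads, num_kv_splits, kv_lora_rank + 1, device=dev, dtype=torch.float32
)

decode_attention_fwd_grouped_rope(
    q, k_buffer, v_buffer, o, kv_indptr, kv_indices,
    k_pe_tokens, kv_lora_rank, rotary_dim, cos_sin_cache, positions,
    attn_logits, num_kv_splits,
    sm_scale=(kv_lora_rank + qk_rope_head_dim) ** -0.5,
    logit_cap=0.0, use_rope=True, is_neox_style=True,
)

With use_rope=True the stage-1 kernel rotates q_pe in-register and writes the rotated last-token k_pe into k_pe_tokens, which is why the assert requires that buffer when rope fusion is enabled.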
sglang/srt/layers/attention/utils.py (new file)

@@ -0,0 +1,39 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
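This helper flattens each request's row of the req_to_token table into one contiguous kv_indices array (the FlashInfer-style page table), one Triton program per request. A hedged launch sketch, assuming the kernel above is importable from sglang.srt.layers.attention.utils; all tensor values are illustrative.

import torch
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton

max_batch, max_context_len = 4, 16
req_to_token = torch.arange(
    max_batch * max_context_len, device="cuda", dtype=torch.int32
).view(max_batch, max_context_len)
req_pool_indices = torch.tensor([2, 0], device="cuda", dtype=torch.int32)  # two live requests
seq_lens = torch.tensor([5, 3], device="cuda", dtype=torch.int32)
kv_indptr = torch.zeros(3, device="cuda", dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
kv_indices = torch.empty(int(kv_indptr[-1]), device="cuda", dtype=torch.int32)

# One program per request (grid = batch size); kv_start_idx=None means
# no prefix offset, so copying starts at token 0 of each request's row.
create_flashinfer_kv_indices_triton[(2,)](
    req_to_token,
    req_pool_indices,
    seq_lens,
    kv_indptr,
    None,
    kv_indices,
    req_to_token.stride(0),
)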