sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1052 @@
|
|
1
|
+
# Copyright (c) 2024, Tri Dao.
|
2
|
+
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
3
|
+
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
|
4
|
+
|
5
|
+
from typing import Optional, Union
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
import torch
|
9
|
+
|
10
|
+
# Default value of the `pad_slot_id` argument of causal_conv1d_fn. The forward
# kernel returns early for any batch entry equal to the pad-slot id (and, when
# USE_PAD_SLOT is set, for any conv-state cache index equal to it).
PAD_SLOT_ID: int = -1
|
11
|
+
import triton
|
12
|
+
import triton.language as tl
|
13
|
+
|
14
|
+
|
15
|
+
# Triton kernel for the varlen / continuous-batching path of a per-channel
# (depthwise) causal conv1d.
#
# Launch grid, as read from the program ids below:
#   axis 0 -> one entry of `batch_ptr` / `token_chunk_offset_ptr`, i.e. one
#             BLOCK_M-token chunk of one sequence;
#   axis 1 -> one BLOCK_N-wide slice of the feature (channel) dimension.
#
# For every token t of its chunk a program computes
#     out[t] = bias + sum_{j=0}^{KERNEL_WIDTH-1} w[:, j] * x[t - (KERNEL_WIDTH - 1) + j]
# and optionally applies SiLU. Tokens before the start of the sequence come
# from the cached conv_state when has_initial_states marks the sequence,
# otherwise they are zero. Chunks after the first read their prior tokens
# straight from `x`.
#
# The program with chunk_offset == 0 also rewrites the sequence's conv_state
# with the last KERNEL_WIDTH - 1 tokens of (previous state ++ x). The previous
# state is treated as zeros when there is no initial state.
#
# Supported KERNEL_WIDTH values: 2, 3, 4, 5.
@triton.jit()
def _causal_conv1d_fwd_kernel(  # continuous batching
    # Pointers to matrices
    x_ptr,  # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
    w_ptr,  # (dim, width)
    bias_ptr,
    initial_states_ptr,  # conv_states_ptr
    cache_indices_ptr,  # conv_state_indices_ptr
    has_initial_states_ptr,
    query_start_loc_ptr,
    batch_ptr,
    token_chunk_offset_ptr,
    o_ptr,  # (dim, seqlen) - actually pointing to x_ptr
    # Matrix dimensions
    batch: tl.int32,  # actually padded_batch
    dim: tl.constexpr,
    seqlen: tl.int32,  # cu_seqlen (overwritten below with the per-sequence length)
    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
    # Strides
    stride_x_seq: tl.constexpr,  # stride to get to next sequence,
    stride_x_dim: tl.constexpr,  # stride to get to next feature-value,
    stride_x_token: tl.constexpr,  # stride to get to next token (same feature-index, same sequence-index)
    stride_w_dim: tl.constexpr,  # stride to get to next dim-axis value
    stride_w_width: tl.constexpr,  # stride to get to next width-axis value
    stride_istate_seq: tl.constexpr,
    stride_istate_dim: tl.constexpr,
    stride_istate_token: tl.constexpr,
    stride_o_seq: tl.constexpr,
    stride_o_dim: tl.constexpr,
    stride_o_token: tl.constexpr,
    # others
    pad_slot_id: tl.constexpr,
    # Meta-parameters
    HAS_BIAS: tl.constexpr,
    KERNEL_WIDTH: tl.constexpr,
    SILU_ACTIVATION: tl.constexpr,
    HAS_INITIAL_STATES: tl.constexpr,
    HAS_CACHE: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    USE_PAD_SLOT: tl.constexpr,
    NP2_STATELEN: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # The conv_state cache doubles as the initial-state buffer.
    conv_states_ptr = initial_states_ptr
    conv_state_indices_ptr = cache_indices_ptr
    stride_conv_state_seq = stride_istate_seq
    stride_conv_state_dim = stride_istate_dim
    stride_conv_state_tok = stride_istate_token
    state_len = (
        KERNEL_WIDTH - 1
    )  # can be passed via argument if it's not the same as this value

    # one program handles one chunk in a single sequence
    # rather than mixing sequences - to make updating initial_states across sequences efficiently

    # single-sequence id
    idx_seq = tl.load(batch_ptr + tl.program_id(0))
    chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))

    # BLOCK_N elements along the feature-dimension (channel)
    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    if idx_seq == pad_slot_id:
        return

    sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
    sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
    # find the actual sequence length
    seqlen = sequence_end_index - sequence_start_index

    token_offset = BLOCK_M * chunk_offset
    segment_len = min(BLOCK_M, seqlen - token_offset)

    # base of the sequence
    x_base = (
        x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
    )  # [BLOCK_N,]

    if IS_CONTINUOUS_BATCHING:
        # cache_idx
        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64)
    else:
        # cache_idx
        conv_state_batch_coord = idx_seq
    if USE_PAD_SLOT:  # noqa
        if conv_state_batch_coord == pad_slot_id:
            # not processing as this is not the actual sequence
            return
    conv_states_base = (
        conv_states_ptr
        + (conv_state_batch_coord * stride_conv_state_seq)
        + (idx_feats * stride_conv_state_dim)
    )  # [BLOCK_N,]

    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]

    # Does 2 things:
    # 1. READ prior-block init-state data - [done by every Triton programs]
    # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0]
    # col0..col{KERNEL_WIDTH-2} hold the KERNEL_WIDTH-1 tokens preceding the
    # current one, oldest first (col0 is the oldest).
    if chunk_offset == 0:
        # read from conv_states
        load_init_state = False
        if HAS_INITIAL_STATES:  # the new HAS_INITIAL_STATES
            load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
        if load_init_state:
            # load from conv_states
            prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok
            mask_w = idx_feats < dim
            if KERNEL_WIDTH == 2:
                conv_states_ptrs = prior_tokens  # [BLOCK_N]
                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
            if KERNEL_WIDTH == 3:
                conv_states_ptrs = prior_tokens  # [BLOCK_N]
                col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
                conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]
                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
            if KERNEL_WIDTH == 4:
                conv_states_ptrs = prior_tokens  # [BLOCK_N]
                col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
                conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]
                col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
                conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok  # [BLOCK_N]
                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
            if KERNEL_WIDTH == 5:
                conv_states_ptrs = prior_tokens  # [BLOCK_N]
                col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
                conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]
                col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
                conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok  # [BLOCK_N]
                col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
                conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok  # [BLOCK_N]
                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
        else:
            # prior-tokens are zeros
            if KERNEL_WIDTH >= 2:  # STRATEGY1
                # first chunk and does not have prior-token, so just set to 0
                col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
            if KERNEL_WIDTH >= 3:  # STRATEGY1
                col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
            if KERNEL_WIDTH >= 4:  # STRATEGY1
                col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
            if KERNEL_WIDTH >= 5:  # STRATEGY1
                col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)

        # STEP 2:
        # here prepare data for updating conv_state
        if (
            state_len <= seqlen
        ):  # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
            # just read from 'x'
            # copy the last state_len tokens of 'x' into conv_state
            idx_tokens_last = (seqlen - state_len) + tl.arange(
                0, NP2_STATELEN
            )  # [BLOCK_M]
            x_ptrs = (
                x_ptr
                + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None]
                + (idx_feats * stride_x_dim)[None, :]
            )  # [BLOCK_M,BLOCK_N,]
            mask_x = (
                (idx_tokens_last >= 0)[:, None]
                & (idx_tokens_last < seqlen)[:, None]
                & (idx_feats < dim)[None, :]
            )  # token-index # token-index # feature-index
            # (A duplicated, unused load of the same tile was removed here.)
            new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
            idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
            conv_states_ptrs_target = (
                conv_states_base[None, :]
                + (idx_tokens_conv * stride_conv_state_tok)[:, None]
            )

            mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]
            tl.debug_barrier()  # NOTE: use this due to bug in Triton compiler
            tl.store(conv_states_ptrs_target, new_conv_state, mask)

        else:
            if load_init_state:
                # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
                idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]

                conv_states_ptrs_source = (
                    conv_states_ptr
                    + (conv_state_batch_coord * stride_conv_state_seq)
                    + (idx_feats * stride_conv_state_dim)[None, :]
                    + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
                )  # [BLOCK_M, BLOCK_N]
                mask = (
                    (conv_state_batch_coord < num_cache_lines)
                    & ((idx_tokens_conv + seqlen) < state_len)[:, None]
                    & (idx_feats < dim)[None, :]
                )
                conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)

                # Number of old-state tokens that survive the shift.
                VAL = state_len - seqlen

                x_ptrs = (
                    x_base[None, :]
                    + ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
                )  # [BLOCK_M, BLOCK_N]

                mask_x = (
                    (idx_tokens_conv - VAL >= 0)[:, None]
                    & (idx_tokens_conv - VAL < seqlen)[:, None]
                    & (idx_feats < dim)[None, :]
                )  # token-index # token-index # feature-index
                loaded_x = tl.load(x_ptrs, mask_x, 0.0)

                tl.debug_barrier()  # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
                new_conv_state = tl.where(
                    mask, conv_state, loaded_x
                )  # BUG in 'tl.where' which requires a barrier before this
                conv_states_ptrs_target = (
                    conv_states_base
                    + (idx_tokens_conv * stride_conv_state_tok)[:, None]
                )  # [BLOCK_M, BLOCK_N]
                mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
                    None, :
                ]
                tl.store(conv_states_ptrs_target, new_conv_state, mask)
            else:  # load_init_state == False
                # update conv_state by shifting left, BUT
                # set cols prior to 'x' as zeros + cols from 'x'
                idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]

                VAL = state_len - seqlen

                x_ptrs = (
                    x_base[None, :]
                    + ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
                )  # [BLOCK_M, BLOCK_N]

                mask_x = (
                    (idx_tokens_conv - VAL >= 0)[:, None]
                    & (idx_tokens_conv - VAL < seqlen)[:, None]
                    & (idx_feats < dim)[None, :]
                )  # token-index # token-index # feature-index
                new_conv_state = tl.load(x_ptrs, mask_x, 0.0)

                conv_states_ptrs_target = (
                    conv_states_base
                    + (idx_tokens_conv * stride_conv_state_tok)[:, None]
                )  # [BLOCK_M, BLOCK_N]
                mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
                    None, :
                ]
                tl.store(conv_states_ptrs_target, new_conv_state, mask)

    else:  # chunk_offset > 0
        # read prior-token data from `x`
        # NOTE(review): o_ptr is documented above as "actually pointing to
        # x_ptr". If the output really aliases x, the tokens read here may
        # already have been overwritten by the program handling the previous
        # chunk. Confirm that the caller passes a separate output buffer.
        load_init_state = True
        prior_tokens = x_base + (token_offset - 1) * stride_x_token
        mask_w = idx_feats < dim
        if KERNEL_WIDTH == 2:
            conv_states_ptrs = prior_tokens  # [BLOCK_N]
            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
        if KERNEL_WIDTH == 3:
            conv_states_ptrs = prior_tokens  # [BLOCK_N]
            col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
            conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]
            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
        if KERNEL_WIDTH == 4:
            conv_states_ptrs = prior_tokens  # [BLOCK_N]
            col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
            conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]
            col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
            conv_states_ptrs = prior_tokens - 2 * stride_x_token  # [BLOCK_N]
            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
        if KERNEL_WIDTH == 5:
            # ruff: noqa: F841
            conv_states_ptrs = prior_tokens  # [BLOCK_N]
            col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
            conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]
            col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
            conv_states_ptrs = prior_tokens - 2 * stride_x_token  # [BLOCK_N]
            col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
            conv_states_ptrs = prior_tokens - 3 * stride_x_token  # [BLOCK_N]
            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")

    if HAS_BIAS:
        bias = bias_ptr + idx_feats
        mask_bias = idx_feats < dim
        acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
            tl.float32
        )  # [BLOCK_N]
    else:
        acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)

    x_base_1d = x_base + token_offset * stride_x_token  # starting of chunk

    # PRE-LOAD WEIGHTS (one column per tap, w_col{j} multiplies the j-th
    # oldest token of the window)
    mask_w = idx_feats < dim
    if KERNEL_WIDTH >= 2:
        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
        w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
        w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 3:
        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
        w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 4:
        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
        w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 5:
        # 5th tap: previously never loaded, so width-5 convolutions were wrong.
        w_ptrs = w_base + (4 * stride_w_width)  # [BLOCK_N] tensor
        w_col4 = tl.load(w_ptrs, mask_w, other=0.0)
    mask_x_1d = idx_feats < dim
    for idx_token in range(segment_len):
        acc = acc_preload

        # Tap 0 always pairs the oldest cached token with the first weight.
        matrix_w = w_col0
        matrix_x = col0
        for j in tl.static_range(KERNEL_WIDTH):
            if KERNEL_WIDTH == 2:
                if j == 1:  # KERNEL_WIDTH-1:
                    matrix_w = w_col1
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 3:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 4:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    matrix_x = col2
                elif j == 3:
                    matrix_w = w_col3
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 5:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    matrix_x = col2
                elif j == 3:
                    matrix_w = w_col3
                    matrix_x = col3
                elif j == 4:
                    matrix_w = w_col4
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)

            acc += matrix_x * matrix_w  # [BLOCK_N]

        # Slide the window: drop the oldest token, append the current one.
        if KERNEL_WIDTH == 2:
            col0 = matrix_x
        elif KERNEL_WIDTH == 3:
            col0 = col1
            col1 = matrix_x
        elif KERNEL_WIDTH == 4:
            col0 = col1
            col1 = col2
            col2 = matrix_x
        elif KERNEL_WIDTH == 5:
            col0 = col1
            col1 = col2
            col2 = col3
            col3 = matrix_x

        if SILU_ACTIVATION:
            acc = acc / (1 + tl.exp(-acc))
        mask_1d = (idx_token < segment_len) & (
            idx_feats < dim
        )  # token-index # feature-index
        o_ptrs = (
            o_ptr
            + (sequence_start_index + token_offset + idx_token) * stride_o_token
            + (idx_feats * stride_o_dim)
        )

        tl.store(o_ptrs, acc, mask=mask_1d)
|
377
|
+
|
378
|
+
|
379
|
+
def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Union[torch.Tensor, None],
    conv_states: torch.Tensor,
    query_start_loc: torch.Tensor,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
    metadata=None,
    validate_data=False,
):
    """Causal per-channel conv1d over a varlen batch; supports continuous batching.

    Args:
        x: (dim, cu_seq_len), where cu_seq_len is the total number of tokens
            of all sequences in the batch. Sequences are concatenated from
            left to right along the token axis. Must be channel-last
            (``x.stride(0) == 1``); checked when ``validate_data`` is True.
        weight: (dim, width). Only widths 2, 3 and 4 are supported.
        bias: (dim,) or None.
        conv_states: (..., dim, width - 1), updated in place if provided.
            ``conv_states[cache_indices[i]]`` is used as the initial state of
            sequence i when ``has_initial_state[i]`` is True. Afterwards it
            is shifted left and refilled with values from ``x``.
        query_start_loc: (batch + 1,) int32 cumulative sequence lengths,
            prepended by 0. For seqlens [5, 1, 1, 1] (batch=4) it is
            [0, 5, 6, 7, 8]: each entry is the start index of a sequence,
            and the last one is the end index of the last sequence, so
            ``len(query_start_loc) - 1 == batch``. As another example,
            query_start_loc = [0, 10, 16, 17] goes with x.shape == (dim, 17).
        cache_indices: (batch,) int32. The state of sequence ``batch_id``
            is ``conv_states[cache_indices[batch_id]]``.
        has_initial_state: (batch,) bool. One flag per sequence telling
            whether the kernel should use the cached state as its initial
            state.
        activation: None, "silu", "swish" or True (an alias for "silu").
        pad_slot_id: when ``cache_indices`` is given, entries equal to this
            value mark padded sequences that are not processed. For example,
            with cache_indices = [pad_slot_id, 1, 20, pad_slot_id] the
            entries at indices 0 and 3 are skipped.
        metadata: optional precomputed launch bookkeeping. Its attributes
            ``nums_dict``, ``batch_ptr`` and ``token_chunk_offset_ptr`` are
            read. ``nums_dict[BLOCK_M]`` must provide the keys "tot",
            "mlist", "mlist_len", "offsetlist", "batch_ptr" and
            "token_chunk_offset_ptr".
        validate_data: when True, assert the shape/layout preconditions.

    Returns:
        out: a new tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``weight`` has a kernel width other than 2, 3 or 4
            (the kernel's compute loop only implements those widths).
    """
    if isinstance(activation, bool) and activation:
        activation = "silu"

    out = torch.empty_like(x)
    if metadata is not None:
        # Launch bookkeeping precomputed by the caller, keyed by BLOCK_M.
        args = metadata.nums_dict
        batch_ptr = metadata.batch_ptr
        token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
    else:
        # Per-sequence lengths; computed on the host to size the grid.
        args = np.diff(query_start_loc.to("cpu"))
        MAX_NUM_PROGRAMS = 1024

        # Which sequence index each Triton program handles
        # (grown on demand inside the grid callback).
        batch_ptr = torch.full(
            (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
        )
        # Which BLOCK_M-sized chunk of its sequence each Triton program handles.
        token_chunk_offset_ptr = torch.full(
            (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
        )

    is_channel_last = x.stride(0) == 1 and x.stride(1) > 1
    dim, cu_seqlen = x.shape
    _, width = weight.shape
    if width not in (2, 3, 4):
        raise ValueError(
            f"causal_conv1d_fn only supports kernel width 2, 3 or 4, got {width}"
        )
    state_len = width - 1
    np2_statelen = triton.next_power_of_2(state_len)

    padded_batch = query_start_loc.size(0) - 1
    # x is 2D: sequences are concatenated along the token axis, so there is
    # no separate sequence stride.
    stride_x_seq = 0
    stride_x_dim, stride_x_token = x.stride()
    stride_w_dim, stride_w_width = weight.stride()
    stride_istate_seq = 0
    stride_istate_dim = 0
    stride_istate_token = 0
    num_cache_lines = 0
    if conv_states is not None:
        # Extensions to support vLLM:
        # 1. conv_states replaces initial_states;
        # 2. conv_states is a cache whose number of lines can exceed the batch size;
        # 3. sequence x[idx] maps to cache line cache_indices[idx];
        # 4. computation is skipped when cache_indices[idx] == pad_slot_id.
        num_cache_lines = conv_states.size(0)
        assert dim == conv_states.shape[1] and width - 1 <= conv_states.shape[2]
        stride_istate_seq = conv_states.stride(0)
        stride_istate_dim = conv_states.stride(1)
        stride_istate_token = conv_states.stride(2)
    if out.dim() == 2:
        stride_o_seq = 0
        stride_o_dim = out.stride(0)
        stride_o_token = out.stride(1)
    else:
        stride_o_seq = out.stride(0)
        stride_o_dim = out.stride(1)
        stride_o_token = out.stride(2)

    if validate_data:
        assert x.dim() == 2
        assert query_start_loc is not None
        assert query_start_loc.dim() == 1
        assert x.stride(0) == 1 or x.stride(1) == 1
        if bias is not None:
            assert bias.dim() == 1
            assert dim == bias.size(0)
        if cache_indices is not None:
            assert cache_indices.dim() == 1
            assert padded_batch == cache_indices.size(0)
        if has_initial_state is not None:
            assert has_initial_state.size() == (padded_batch,)
            assert (
                conv_states is not None
            ), "ERROR: `has_initial_state` is used, which needs also `conv_states`"
        assert weight.stride(1) == 1
        assert (dim, width) == weight.shape
        assert is_channel_last, "Need to run in channel-last layout"

    if metadata is None:

        def num_program(META, seqlens):
            # Split every sequence into ceil(seqlen / BLOCK_M) chunks, one
            # Triton program per chunk. Record, for each program, the index of
            # its sequence (mlist) and its chunk index within that sequence
            # (offsetlist). Returns the total number of programs.
            nums = -(-seqlens // META["BLOCK_M"])  # ceil division
            tot = nums.sum().item()
            mlist = np.repeat(np.arange(len(nums)), nums)
            # 0..nums[i]-1 for each sequence i: global position minus the
            # position where that sequence's chunks start.
            offsetlist = np.arange(tot) - np.repeat(np.cumsum(nums) - nums, nums)

            if META["batch_ptr"].nelement() < len(mlist):
                newlen = len(mlist) + 1
                META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
                META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)

            META["batch_ptr"][0 : len(mlist)].copy_(torch.from_numpy(mlist))
            META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
                torch.from_numpy(offsetlist)
            )

            # NOTE(review): rebinding META entries presumably lets the launcher
            # pick up the (possibly moved) tensors; confirm for the Triton
            # version in use before removing.
            META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
            META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
                META["x_ptr"].device
            )
            return tot

    else:

        def num_program(META, nums_dict):
            # Same bookkeeping as above, but taken from the caller's
            # precomputed per-BLOCK_M entry.
            entry = nums_dict[META["BLOCK_M"]]
            tot = entry["tot"]
            mlist = entry["mlist"]
            mlist_len = entry["mlist_len"]
            offsetlist = entry["offsetlist"]

            if entry["batch_ptr"] is not None:
                META["batch_ptr"] = entry["batch_ptr"]
                META["token_chunk_offset_ptr"] = entry["token_chunk_offset_ptr"]
            else:
                if META["batch_ptr"].nelement() < mlist_len:
                    newlen = mlist_len + 1
                    META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
                    META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
                META["batch_ptr"][0:mlist_len].copy_(mlist)
                META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
            return tot

    def grid(META):
        return (
            num_program(META, args),
            triton.cdiv(dim, META["BLOCK_N"]),
        )

    if batch_ptr.device != x.device:
        batch_ptr = batch_ptr.to(x.device)
        token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)

    _causal_conv1d_fwd_kernel[grid](
        # Pointers to matrices
        x,
        weight,
        bias,
        conv_states,
        cache_indices,
        has_initial_state,
        query_start_loc,
        batch_ptr,
        token_chunk_offset_ptr,
        out,
        # Matrix dimensions
        padded_batch,
        dim,
        cu_seqlen,
        num_cache_lines,
        # stride
        stride_x_seq,
        stride_x_dim,
        stride_x_token,
        stride_w_dim,
        stride_w_width,
        stride_istate_seq,
        stride_istate_dim,
        stride_istate_token,
        stride_o_seq,
        stride_o_dim,
        stride_o_token,
        # others
        pad_slot_id,
        # META
        HAS_BIAS=bias is not None,
        KERNEL_WIDTH=width,
        SILU_ACTIVATION=activation in ["silu", "swish"],
        HAS_INITIAL_STATES=has_initial_state is not None,
        HAS_CACHE=conv_states is not None,
        IS_CONTINUOUS_BATCHING=cache_indices is not None,
        USE_PAD_SLOT=pad_slot_id is not None,
        NP2_STATELEN=np2_statelen,
        # launch_cooperative_grid=True
        BLOCK_M=8,
        BLOCK_N=256,
        num_stages=2,
    )
    return out
|
637
|
+
|
638
|
+
|
639
|
+
# Single-step / few-token causal conv1d update (decode path).
#
# Grid: axis 0 = sequence in the batch, axis 1 = BLOCK_N-wide slice of the
# feature (channel) axis. Each program:
#   1. reads the last K-1 cached tokens of its sequence's conv_state line,
#   2. rewrites that conv_state line as (older cached tokens, shifted) + the
#      `seqlen` new tokens from x,
#   3-5. for every new token, computes
#        out = act(bias + sum_j w[:, j] * window[j])
#      with window = (previous K-1 tokens, current token), per channel
#      (no mixing across channels: every weight/x/out pointer is offset by
#      idx_feats only).
# NOTE(review): only KERNEL_WIDTH 2, 3 and 4 are implemented in STEP 5; for
# other widths the result does not depend on x.
@triton.jit()
def _causal_conv1d_update_kernel(
    # Pointers to matrices
    x_ptr,  # (batch, dim, seqlen)
    w_ptr,  # (dim, width)
    bias_ptr,
    conv_state_ptr,
    cache_seqlens_ptr,  # circular buffer (NOTE(review): not read by this kernel)
    conv_state_indices_ptr,
    num_accepted_tokens_ptr,
    intermediate_conv_window_ptr,
    o_ptr,  # (batch, dim, seqlen)
    # Matrix dimensions
    batch: int,
    dim: tl.constexpr,
    seqlen: tl.constexpr,
    state_len: tl.constexpr,
    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
    # Strides
    stride_x_seq: tl.constexpr,
    stride_x_dim: tl.constexpr,
    stride_x_token: tl.constexpr,
    stride_w_dim: tl.constexpr,
    stride_w_width: tl.constexpr,
    stride_conv_state_seq: tl.constexpr,
    stride_conv_state_dim: tl.constexpr,
    stride_conv_state_tok: tl.constexpr,
    stride_state_indices: tl.constexpr,
    stride_inter_seq: tl.constexpr,
    stride_inter_step: tl.constexpr,
    stride_inter_dim: tl.constexpr,
    stride_inter_win: tl.constexpr,
    stride_o_seq: tl.constexpr,
    stride_o_dim: tl.constexpr,
    stride_o_token: tl.constexpr,
    # others
    pad_slot_id: tl.constexpr,
    # Meta-parameters
    HAS_BIAS: tl.constexpr,
    KERNEL_WIDTH: tl.constexpr,
    SILU_ACTIVATION: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    NP2_STATELEN: tl.constexpr,
    USE_PAD_SLOT: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SAVE_INTERMEDIATE: tl.constexpr,
):
    # ruff: noqa: E501
    idx_seq = tl.program_id(0)
    if idx_seq >= batch:
        return

    # [BLOCK_N,] elements along the feature-dimension (channel)
    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    # Resolve which conv_state cache line this sequence uses.
    if IS_CONTINUOUS_BATCHING:
        # mask = idx_seq < batch
        conv_state_batch_coord = tl.load(
            conv_state_indices_ptr + idx_seq * stride_state_indices
        ).to(tl.int64)
    else:
        conv_state_batch_coord = idx_seq
    if USE_PAD_SLOT:  # noqa
        if conv_state_batch_coord == pad_slot_id:
            # not processing as this is not the actual sequence
            return

    if IS_SPEC_DECODING:
        # The rolling of conv state:
        #
        # Before forward, the conv_state is:
        # [history1, history2, ..., historyM].
        #
        # After forward, the conv_state becomes:
        # [history2, ..., historyM, draft1, draft2, ..., draftN].
        #
        # After acceptance, it becomes:
        #
        # - accept 1 tokens: [history2, ..., historyM, draft1]
        # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
        # - and so on.
        #
        # So reading starts (num_accepted_tokens - 1) tokens into the line.
        conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1
    else:
        conv_state_token_offset = 0

    # STEP 1: READ init_state data
    # col0..col{K-2} hold the K-1 tokens preceding the first new token.
    conv_states_base = (
        conv_state_ptr
        + (conv_state_batch_coord * stride_conv_state_seq)
        + (idx_feats * stride_conv_state_dim)
    )
    mask_w = idx_feats < dim

    prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
    if KERNEL_WIDTH >= 2:
        conv_states_ptrs = prior_tokens  # [BLOCK_N]
        col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH >= 3:
        conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok  # [BLOCK_N]
        col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH >= 4:
        conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
        col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH == 5:
        # NOTE(review): col3 is loaded but STEP 5 has no KERNEL_WIDTH == 5 branch.
        conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
        col3 = tl.load(conv_states_ptrs, mask_w, 0.0)

    # STEP 2: assume state_len > seqlen
    idx_tokens = tl.arange(0, NP2_STATELEN)  # [NP2_STATELEN]

    # The conv_state updates works in a sliding window manner,
    # at each forward pass, the tokens are shift by 1, so we
    # load since idx_tokens + 1.
    conv_state_ptrs_source = (
        conv_state_ptr
        + (conv_state_batch_coord * stride_conv_state_seq)
        + conv_state_token_offset * stride_conv_state_tok
        + (idx_feats * stride_conv_state_dim)[None, :]
        + ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
    )  # [NP2_STATELEN, BLOCK_N]
    # Keep old tokens only for the first (state_len - seqlen) slots.
    mask = (
        (conv_state_batch_coord < num_cache_lines)
        & ((idx_tokens + seqlen) < state_len)[:, None]
        & (idx_feats < dim)[None, :]
    )
    conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

    # Slots [VAL, state_len) of the new state are filled with x[0:seqlen].
    VAL = state_len - seqlen
    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim)  # [BLOCK_N]

    x_ptrs = (
        x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
    )  # [NP2_STATELEN, BLOCK_N]

    mask_x = (
        (idx_tokens - VAL >= 0)[:, None]
        & (idx_tokens - VAL < seqlen)[:, None]
        & (idx_feats < dim)[None, :]
    )  # token-index # token-index # feature-index
    loaded_x = tl.load(x_ptrs, mask_x, 0.0)
    # NOTE(review): purpose of this barrier is not evident here; presumably
    # ordering, since causal_conv1d_update passes out = x (aliased) — confirm.
    tl.debug_barrier()

    new_conv_state = tl.where(mask, conv_state, loaded_x)

    # Write the new state back at token offset 0 of the same cache line.
    conv_state_base = (
        conv_state_ptr
        + (conv_state_batch_coord * stride_conv_state_seq)
        + (idx_feats * stride_conv_state_dim)
    )  # [BLOCK_N,]
    conv_state_ptrs_target = (
        conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None]
    )  # [NP2_STATELEN, BLOCK_N]
    # NOTE(review): unlike the load mask, this store mask has no
    # `conv_state_batch_coord < num_cache_lines` bound; it also assumes the
    # cache line holds at least `state_len` (effective) tokens.
    mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
    tl.store(conv_state_ptrs_target, new_conv_state, mask)

    # STEP 3: init accumulator
    if HAS_BIAS:
        bias = bias_ptr + idx_feats
        mask_bias = idx_feats < dim
        acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
            tl.float32
        )  # [BLOCK_N]
    else:
        acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)

    # STEP 4:
    # PRE-LOAD WEIGHTS
    # first kernel column, configured for weights to handle BLOCK_N features in range
    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
    mask_w = idx_feats < dim
    if KERNEL_WIDTH >= 2:
        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
        w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
        w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 3:
        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
        w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 4:
        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
        w_col3 = tl.load(w_ptrs, mask_w, other=0.0)

    x_base_1d = x_base  # starting of chunk [BLOCK_N]
    mask_x_1d = idx_feats < dim

    # STEP 5: compute each token
    for idx_token in tl.static_range(seqlen):
        acc = acc_preload

        # Tap j pairs weight column j with window element j; the last tap
        # (j == KERNEL_WIDTH - 1) uses the current token loaded from x.
        matrix_w = w_col0
        matrix_x = col0
        for j in tl.static_range(KERNEL_WIDTH):
            if KERNEL_WIDTH == 2:
                if j == 1:  # KERNEL_WIDTH-1:
                    matrix_w = w_col1
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 3:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 4:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    matrix_x = col2
                elif j == 3:
                    matrix_w = w_col3
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)

            acc += matrix_x * matrix_w  # [BLOCK_N]

        # Slide the window left by one, appending the token just consumed.
        if KERNEL_WIDTH == 2:
            col0 = matrix_x
        elif KERNEL_WIDTH == 3:
            col0 = col1
            col1 = matrix_x
        elif KERNEL_WIDTH == 4:
            col0 = col1
            col1 = col2
            col2 = matrix_x

        if SILU_ACTIVATION:
            acc = acc / (1 + tl.exp(-acc))
        mask_1d = (idx_token < seqlen) & (
            idx_feats < dim
        )  # token-index # feature-index
        o_ptrs = (
            o_ptr
            + (idx_seq) * stride_o_seq
            + idx_token * stride_o_token
            + (idx_feats * stride_o_dim)
        )

        tl.store(o_ptrs, acc, mask=mask_1d)

        if SAVE_INTERMEDIATE:
            # Save the window state after consuming this token
            # Layout: [seq(cache line), step, dim, win(K-1)]
            base_ptr = (
                intermediate_conv_window_ptr
                + conv_state_batch_coord * stride_inter_seq
                + idx_token * stride_inter_step
                + idx_feats * stride_inter_dim
            )
            if KERNEL_WIDTH >= 2:
                tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
            if KERNEL_WIDTH >= 3:
                tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
            if KERNEL_WIDTH >= 4:
                tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
|
898
|
+
|
899
|
+
|
900
|
+
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Union[bool, str, None] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    intermediate_conv_window: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
    metadata=None,
    validate_data=False,
):
    """Apply a causal conv1d step to new token(s), rolling ``conv_state`` in place.

    Args:
        x: (batch, dim) or (batch, dim, seqlen).
            shape=2: single token prediction.
            shape=3: single or multiple tokens prediction.
            The output is written into ``x`` itself, i.e. ``x`` is overwritten.
        conv_state: (..., dim, state_len), where state_len >= width - 1.
            Updated in place. When seqlen > 1, the kernel writes
            ``width - 1 + (seqlen - 1)`` tokens per cache line.
        weight: (dim, width). Only widths 2, 3 and 4 are supported.
        bias: (dim,) or None.
        activation: None, "silu", "swish", or a bool (True means "silu").
        cache_seqlens: must be None. Circular-buffer mode is not
            implemented: the kernel does not read this pointer, and
            ``validate_data`` asserts it is None.
        conv_state_indices: (batch,), dtype int32. If not None, conv_state
            is a larger tensor along the batch dim, and the batch coords
            specified by conv_state_indices are selected. Useful for a
            continuous batching scenario.
        num_accepted_tokens: optional, one value per sequence (speculative
            decoding). The cached window is read starting
            ``num_accepted_tokens[i] - 1`` tokens into the cache line.
        intermediate_conv_window: optional 4D buffer with layout
            [cache line, step, dim, width - 1]. When given, the conv window
            after each consumed token is stored into it.
        pad_slot_id: when ``conv_state_indices`` is passed, lets the kernel
            identify padded entries that will not be processed. For example,
            with conv_state_indices = [pad_slot_id, 1, 20, pad_slot_id] the
            entries at indices 0 and 3 are skipped.
        metadata: not read by this function.
        validate_data: when True, assert the shape/layout preconditions.

    Returns:
        out: ``x`` itself, holding the result. Shape is (batch, dim) when the
        input was 2D, otherwise (batch, dim, seqlen).

    Raises:
        ValueError: if ``weight`` has a kernel width other than 2, 3 or 4
            (the kernel's compute loop only implements those widths).
    """
    if validate_data:
        assert cache_seqlens is None  # not implemented yet - ok for vLLM
        assert pad_slot_id is not None
        assert x.stride(1) == 1
    if isinstance(activation, bool):
        activation = "silu" if activation is True else None
    elif activation is not None:
        assert activation in ["silu", "swish"]
    unsqueeze = x.dim() == 2
    if unsqueeze:
        # make it (batch, dim, seqlen) with seqlen == 1
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    _, width = weight.shape
    if width not in (2, 3, 4):
        raise ValueError(
            f"causal_conv1d_update only supports kernel width 2, 3 or 4, got {width}"
        )
    # conv_state: (..., dim, cache_state_len), where cache_state_len >= width - 1
    num_cache_lines, _, cache_state_len = conv_state.size()
    # Effective state length the kernel reads/writes: the width - 2 newest
    # cached tokens plus all `seqlen` new tokens.
    state_len = width - 1 + (seqlen - 1)

    if validate_data:
        assert dim == weight.size(0)
        assert (
            conv_state.stride(-2) == 1
        ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
        assert cache_state_len >= width - 1
        # when above happens, we don't shift-left to keep any records in conv_state
        # The kernel stores `state_len` tokens per cache line; a shorter line
        # would be written out of bounds.
        assert cache_state_len >= state_len
        assert dim == conv_state.size(1)
        if conv_state_indices is None:
            assert conv_state.size(0) >= batch
        else:
            assert (batch,) == conv_state_indices.shape

        assert num_cache_lines >= batch
        assert weight.stride(1) == 1  # Need this
        assert cache_seqlens is None  # not needed for vLLM - circular buffer

    # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
    out = x
    stride_w_dim, stride_w_width = weight.stride()

    stride_x_seq, stride_x_dim, stride_x_token = x.stride()  # X (batch, dim, seqlen)

    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()
    stride_state_indices = (
        conv_state_indices.stride(0) if conv_state_indices is not None else 0
    )
    np2_statelen = triton.next_power_of_2(state_len)

    def grid(META):
        return (
            batch,
            triton.cdiv(dim, META["BLOCK_N"]),
        )

    # prepare intermediate buffer strides if provided
    if intermediate_conv_window is not None:
        stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
            intermediate_conv_window.stride(0),
            intermediate_conv_window.stride(1),
            intermediate_conv_window.stride(2),
            intermediate_conv_window.stride(3),
        )
    else:
        stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0

    _causal_conv1d_update_kernel[grid](
        # Pointers to matrices
        x,
        weight,
        bias,
        conv_state,
        cache_seqlens,
        conv_state_indices,
        num_accepted_tokens,
        # Placeholder pointer when no buffer is given; the kernel only writes
        # through it when SAVE_INTERMEDIATE is True.
        intermediate_conv_window if intermediate_conv_window is not None else x,
        out,
        # Matrix dimensions
        batch,
        dim,
        seqlen,
        state_len,
        num_cache_lines,
        # stride
        stride_x_seq,
        stride_x_dim,
        stride_x_token,
        stride_w_dim,
        stride_w_width,
        stride_istate_seq,
        stride_istate_dim,
        stride_istate_token,
        stride_state_indices,
        stride_inter_seq,
        stride_inter_step,
        stride_inter_dim,
        stride_inter_win,
        stride_o_seq,
        stride_o_dim,
        stride_o_token,
        # others
        pad_slot_id,
        # META
        HAS_BIAS=bias is not None,
        KERNEL_WIDTH=width,
        SILU_ACTIVATION=activation in ["silu", "swish"],
        IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
        IS_SPEC_DECODING=num_accepted_tokens is not None,
        NP2_STATELEN=np2_statelen,
        USE_PAD_SLOT=pad_slot_id is not None,
        BLOCK_N=256,
        SAVE_INTERMEDIATE=intermediate_conv_window is not None,
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out
|