sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/causal_conv1d_triton.py (new file)
@@ -0,0 +1,1052 @@
+ # Copyright (c) 2024, Tri Dao.
+ # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
+ # and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+
+ from typing import Optional, Union
+
+ import numpy as np
+ import torch
+
+ PAD_SLOT_ID = -1
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit()
+ def _causal_conv1d_fwd_kernel(  # continuous batching
+     # Pointers to matrices
+     x_ptr,  # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
+     w_ptr,  # (dim, width)
+     bias_ptr,
+     initial_states_ptr,  # conv_states_ptr
+     cache_indices_ptr,  # conv_state_indices_ptr
+     has_initial_states_ptr,
+     query_start_loc_ptr,
+     batch_ptr,
+     token_chunk_offset_ptr,
+     o_ptr,  # (dim, seqlen) - actually pointing to x_ptr
+     # Matrix dimensions
+     batch: tl.int32,  # actually padded_batch
+     dim: tl.constexpr,
+     seqlen: tl.int32,  # cu_seqlen
+     num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
+     # Strides
+     stride_x_seq: tl.constexpr,  # stride to get to next sequence
+     stride_x_dim: tl.constexpr,  # stride to get to next feature-value
+     stride_x_token: tl.constexpr,  # stride to get to next token (same feature-index, same sequence-index)
+     stride_w_dim: tl.constexpr,  # stride to get to next dim-axis value
+     stride_w_width: tl.constexpr,  # stride to get to next width-axis value
+     stride_istate_seq: tl.constexpr,
+     stride_istate_dim: tl.constexpr,
+     stride_istate_token: tl.constexpr,
+     stride_o_seq: tl.constexpr,
+     stride_o_dim: tl.constexpr,
+     stride_o_token: tl.constexpr,
+     # others
+     pad_slot_id: tl.constexpr,
+     # Meta-parameters
+     HAS_BIAS: tl.constexpr,
+     KERNEL_WIDTH: tl.constexpr,
+     SILU_ACTIVATION: tl.constexpr,
+     HAS_INITIAL_STATES: tl.constexpr,
+     HAS_CACHE: tl.constexpr,
+     IS_CONTINUOUS_BATCHING: tl.constexpr,
+     USE_PAD_SLOT: tl.constexpr,
+     NP2_STATELEN: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+ ):
+     conv_states_ptr = initial_states_ptr
+     conv_state_indices_ptr = cache_indices_ptr
+     stride_conv_state_seq = stride_istate_seq
+     stride_conv_state_dim = stride_istate_dim
+     stride_conv_state_tok = stride_istate_token
+     state_len = (
+         KERNEL_WIDTH - 1
+     )  # can be passed via argument if it's not the same as this value
+
+     # one program handles one chunk in a single sequence
+     # rather than mixing sequences - to make updating initial_states across sequences efficient
+
+     # single-sequence id
+     idx_seq = tl.load(batch_ptr + tl.program_id(0))
+     chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
+
+     # BLOCK_N elements along the feature-dimension (channel)
+     idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+     if idx_seq == pad_slot_id:
+         return
+
+     sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
+     sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
+     # find the actual sequence length
+     seqlen = sequence_end_index - sequence_start_index
+
+     token_offset = BLOCK_M * chunk_offset
+     segment_len = min(BLOCK_M, seqlen - token_offset)
+
+     # base of the sequence
+     x_base = (
+         x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
+     )  # [BLOCK_N,]
+
+     if IS_CONTINUOUS_BATCHING:
+         # cache_idx
+         conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64)
+     else:
+         # cache_idx
+         conv_state_batch_coord = idx_seq
+     if USE_PAD_SLOT:  # noqa
+         if conv_state_batch_coord == pad_slot_id:
+             # not processing as this is not the actual sequence
+             return
+     conv_states_base = (
+         conv_states_ptr
+         + (conv_state_batch_coord * stride_conv_state_seq)
+         + (idx_feats * stride_conv_state_dim)
+     )  # [BLOCK_N,]
+
+     w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
+
+     # Does 2 things:
+     # 1. READ prior-block init-state data - [done by every Triton program]
+     # 2. update conv_state with new data [only by the Triton program that handles chunk_offset=0]
+     if chunk_offset == 0:
+         # read from conv_states
+         load_init_state = False
+         if HAS_INITIAL_STATES:  # the new HAS_INITIAL_STATES
+             load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
+         if load_init_state:
+             # load from conv_states
+             prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok
+             mask_w = idx_feats < dim
+             if KERNEL_WIDTH == 2:
+                 conv_states_ptrs = prior_tokens  # [BLOCK_N]
+                 col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+             if KERNEL_WIDTH == 3:
+                 conv_states_ptrs = prior_tokens  # [BLOCK_N]
+                 col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                 conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]
+                 col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+             if KERNEL_WIDTH == 4:
+                 conv_states_ptrs = prior_tokens  # [BLOCK_N]
+                 col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                 conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]
+                 col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                 conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok  # [BLOCK_N]
+                 col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+             if KERNEL_WIDTH == 5:
+                 conv_states_ptrs = prior_tokens  # [BLOCK_N]
+                 col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                 conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]
+                 col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                 conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok  # [BLOCK_N]
+                 col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                 conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok  # [BLOCK_N]
+                 col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+         else:
+             # prior-tokens are zeros
+             if KERNEL_WIDTH >= 2:  # STRATEGY1
+                 # first chunk has no prior tokens, so just set to 0
+                 col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
+             if KERNEL_WIDTH >= 3:  # STRATEGY1
+                 col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
+             if KERNEL_WIDTH >= 4:  # STRATEGY1
+                 col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
+             if KERNEL_WIDTH >= 5:  # STRATEGY1
+                 col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
+
+         # STEP 2:
+         # here prepare data for updating conv_state
+         if (
+             state_len <= seqlen
+         ):  # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
+             # just read from 'x'
+             # copy 'x' data to conv_state
+             # load only 'x' data (and set 0 before 'x' if seqlen < state_len)
+             idx_tokens_last = (seqlen - state_len) + tl.arange(
+                 0, NP2_STATELEN
+             )  # [BLOCK_M]
+             x_ptrs = (
+                 x_ptr
+                 + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None]
+                 + (idx_feats * stride_x_dim)[None, :]
+             )  # [BLOCK_M,BLOCK_N,]
+             mask_x = (
+                 (idx_tokens_last >= 0)[:, None]
+                 & (idx_tokens_last < seqlen)[:, None]
+                 & (idx_feats < dim)[None, :]
+             )  # token-index # token-index # feature-index
+             loaded_x = tl.load(x_ptrs, mask_x, 0.0)
+             new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
+             idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+             conv_states_ptrs_target = (
+                 conv_states_base[None, :]
+                 + (idx_tokens_conv * stride_conv_state_tok)[:, None]
+             )
+
+             mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]
+             tl.debug_barrier()  # NOTE: use this due to a bug in the Triton compiler
+             tl.store(conv_states_ptrs_target, new_conv_state, mask)
+
+         else:
+             if load_init_state:
+                 # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
+                 idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+
+                 conv_states_ptrs_source = (
+                     conv_states_ptr
+                     + (conv_state_batch_coord * stride_conv_state_seq)
+                     + (idx_feats * stride_conv_state_dim)[None, :]
+                     + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
+                 )  # [BLOCK_M, BLOCK_N]
+                 mask = (
+                     (conv_state_batch_coord < num_cache_lines)
+                     & ((idx_tokens_conv + seqlen) < state_len)[:, None]
+                     & (idx_feats < dim)[None, :]
+                 )
+                 conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
+
+                 VAL = state_len - seqlen
+
+                 x_ptrs = (
+                     x_base[None, :]
+                     + ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
+                 )  # [BLOCK_M, BLOCK_N]
+
+                 mask_x = (
+                     (idx_tokens_conv - VAL >= 0)[:, None]
+                     & (idx_tokens_conv - VAL < seqlen)[:, None]
+                     & (idx_feats < dim)[None, :]
+                 )  # token-index # token-index # feature-index
+                 loaded_x = tl.load(x_ptrs, mask_x, 0.0)
+
+                 tl.debug_barrier()  # needed due to a bug in tl.where not enforcing ordering when data is the result of another tl.load
+                 new_conv_state = tl.where(
+                     mask, conv_state, loaded_x
+                 )  # BUG in 'tl.where' which requires a barrier before this
+                 conv_states_ptrs_target = (
+                     conv_states_base
+                     + (idx_tokens_conv * stride_conv_state_tok)[:, None]
+                 )  # [BLOCK_M, BLOCK_N]
+                 mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
+                     None, :
+                 ]
+                 tl.store(conv_states_ptrs_target, new_conv_state, mask)
+             else:  # load_init_state == False
+                 # update conv_state by shifting left, BUT
+                 # set cols prior to 'x' as zeros + cols from 'x'
+                 idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+
+                 VAL = state_len - seqlen
+
+                 x_ptrs = (
+                     x_base[None, :]
+                     + ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
+                 )  # [BLOCK_M, BLOCK_N]
+
+                 mask_x = (
+                     (idx_tokens_conv - VAL >= 0)[:, None]
+                     & (idx_tokens_conv - VAL < seqlen)[:, None]
+                     & (idx_feats < dim)[None, :]
+                 )  # token-index # token-index # feature-index
+                 new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
+
+                 conv_states_ptrs_target = (
+                     conv_states_base
+                     + (idx_tokens_conv * stride_conv_state_tok)[:, None]
+                 )  # [BLOCK_M, BLOCK_N]
+                 mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
+                     None, :
+                 ]
+                 tl.store(conv_states_ptrs_target, new_conv_state, mask)
+
+     else:  # chunk_offset > 0
+         # read prior-token data from `x`
+         load_init_state = True
+         prior_tokens = x_base + (token_offset - 1) * stride_x_token
+         mask_w = idx_feats < dim
+         if KERNEL_WIDTH == 2:
+             conv_states_ptrs = prior_tokens  # [BLOCK_N]
+             col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+         if KERNEL_WIDTH == 3:
+             conv_states_ptrs = prior_tokens  # [BLOCK_N]
+             col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+             conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]
+             col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+         if KERNEL_WIDTH == 4:
+             conv_states_ptrs = prior_tokens  # [BLOCK_N]
+             col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+             conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]
+             col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+             conv_states_ptrs = prior_tokens - 2 * stride_x_token  # [BLOCK_N]
+             col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+         if KERNEL_WIDTH == 5:
+             # ruff: noqa: F841
+             conv_states_ptrs = prior_tokens  # [BLOCK_N]
+             col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+             conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]
+             col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+             conv_states_ptrs = prior_tokens - 2 * stride_x_token  # [BLOCK_N]
+             col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+             conv_states_ptrs = prior_tokens - 3 * stride_x_token  # [BLOCK_N]
+             col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+
+     if HAS_BIAS:
+         bias = bias_ptr + idx_feats
+         mask_bias = idx_feats < dim
+         acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
+             tl.float32
+         )  # [BLOCK_N]
+     else:
+         acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+     x_base_1d = x_base + token_offset * stride_x_token  # starting of chunk
+
+     # PRE-LOAD WEIGHTS
+     mask_w = idx_feats < dim
+     if KERNEL_WIDTH >= 2:
+         w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
+         w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
+         w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
+         w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
+     if KERNEL_WIDTH >= 3:
+         w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
+         w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
+     if KERNEL_WIDTH >= 4:
+         w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
+         w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
+     mask_x_1d = idx_feats < dim
+     for idx_token in range(segment_len):
+         acc = acc_preload
+
+         matrix_w = w_col0
+         matrix_x = col0
+         for j in tl.static_range(KERNEL_WIDTH):
+
+             if KERNEL_WIDTH == 2:
+                 if j == 1:  # KERNEL_WIDTH-1:
+                     matrix_w = w_col1
+                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+             elif KERNEL_WIDTH == 3:
+                 if j == 1:
+                     matrix_w = w_col1
+                     matrix_x = col1
+                 elif j == 2:
+                     matrix_w = w_col2
+                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+             elif KERNEL_WIDTH == 4:
+                 if j == 1:
+                     matrix_w = w_col1
+                     matrix_x = col1
+                 elif j == 2:
+                     matrix_w = w_col2
+                     matrix_x = col2
+                 elif j == 3:
+                     matrix_w = w_col3
+                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+
+             acc += matrix_x * matrix_w  # [BLOCK_N]
+
+         if KERNEL_WIDTH == 2:
+             col0 = matrix_x
+         elif KERNEL_WIDTH == 3:
+             col0 = col1
+             col1 = matrix_x
+         elif KERNEL_WIDTH == 4:
+             col0 = col1
+             col1 = col2
+             col2 = matrix_x
+
+         if SILU_ACTIVATION:
+             acc = acc / (1 + tl.exp(-acc))
+         mask_1d = (idx_token < segment_len) & (
+             idx_feats < dim
+         )  # token-index # feature-index
+         o_ptrs = (
+             o_ptr
+             + (sequence_start_index + token_offset + idx_token) * stride_o_token
+             + (idx_feats * stride_o_dim)
+         )
+
+         tl.store(o_ptrs, acc, mask=mask_1d)
+
+
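Editor's note: the fwd kernel above assigns one Triton program to each (sequence, BLOCK_M-chunk) pair through the batch_ptr and token_chunk_offset_ptr tables. A minimal NumPy sketch of that mapping (the helper name chunk_program_tables is hypothetical; the real tables are filled by num_program inside causal_conv1d_fn below):

import numpy as np

def chunk_program_tables(seqlens, block_m):
    # ceil-divide each sequence into block_m-token chunks
    nums = -(-np.asarray(seqlens) // block_m)
    # program i handles sequence batch_ptr[i] ...
    batch_ptr = np.repeat(np.arange(len(nums)), nums)
    # ... starting at token token_chunk_offset_ptr[i] * block_m
    token_chunk_offset_ptr = np.concatenate([np.arange(n) for n in nums])
    return batch_ptr, token_chunk_offset_ptr

# seqlens [5, 1] with block_m=4 -> two chunks for seq 0, one for seq 1:
# batch_ptr = [0, 0, 1], token_chunk_offset_ptr = [0, 1, 0]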
+ def causal_conv1d_fn(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: Union[torch.Tensor, None],
+     conv_states: torch.Tensor,
+     query_start_loc: torch.Tensor,
+     cache_indices: Optional[torch.Tensor] = None,
+     has_initial_state: Optional[torch.Tensor] = None,
+     activation: Optional[str] = "silu",
+     pad_slot_id: int = PAD_SLOT_ID,
+     metadata=None,
+     validate_data=False,
+ ):
+     """support varlen + continuous batching when x is a 2D tensor
+
+     x: (dim, cu_seq_len)
+         cu_seq_len = total tokens of all seqs in that batch
+         sequences are concatenated from left to right for varlen
+     weight: (dim, width)
+     conv_states: (..., dim, width - 1) itype
+         updated inplace if provided
+         [it uses `cache_indices` to get the index to the cache of conv_state for that sequence
+
+         conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
+         and after that conv_state[cache_indices[i]] needs to be shifted left and updated with values from 'x'
+         ]
+     query_start_loc: (batch + 1) int32
+         The cumulative sequence lengths of the sequences in
+         the batch, used to index into sequence; prepended by 0.
+         if
+         x = [5, 1, 1, 1] <- continuous batching (batch=4)
+         then
+         query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; the last value is
+         the ending index of the last sequence
+         [length(query_start_loc) - 1 == batch]
+         for example: query_start_loc = torch.Tensor([0, 10, 16, 17]),
+         x.shape=(dim, 17)
+     cache_indices: (batch) int32
+         indicates the corresponding state index,
+         like so: conv_state = conv_states[cache_indices[batch_id]]
+     has_initial_state: (batch) bool
+         indicates whether the kernel should take the current state as the initial
+         state for the calculations
+         [single boolean for each sequence in the batch: True or False]
+     bias: (dim,)
+     activation: either None or "silu" or "swish" or True
+     pad_slot_id: int
+         if cache_indices is passed, lets the kernel identify padded
+         entries that will not be processed,
+         for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+         in this case, the kernel will not process entries at
+         indices 0 and 3
+
+     out: same shape as `x`
+     """
+     if isinstance(activation, bool) and activation:
+         activation = "silu"
+
+     args = None
+     out = torch.empty_like(x)
+     if metadata is not None:
+         cu_seqlen = metadata.cu_seqlen
+         nums_dict = metadata.nums_dict
+         # x = metadata.x
+         args = nums_dict
+         batch_ptr = metadata.batch_ptr
+         token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
+     else:
+         seqlens = np.diff(query_start_loc.to("cpu"))
+         args = seqlens
+         MAX_NUM_PROGRAMS = 1024
+
+         batch_ptr = torch.full(
+             (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
+         )  # tracking which seq-idx the Triton program is handling
+         token_chunk_offset_ptr = torch.full(
+             (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
+         )  # tracking BLOCK_M-based index in the sequence the Triton program is handling
+
+     is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
+     dim, cu_seqlen = x.shape
+     _, width = weight.shape
+     state_len = width - 1
+     np2_statelen = triton.next_power_of_2(state_len)
+
+     padded_batch = query_start_loc.size(0) - 1
+     stride_x_seq = 0
+     stride_x_dim = x.stride(0)
+     stride_x_token = x.stride(1)
+     stride_w_dim = weight.stride(0)
+     stride_w_width = weight.stride(1)
+     stride_istate_seq = 0
+     stride_istate_dim = 0
+     stride_istate_token = 0
+     num_cache_lines = 0
+     if conv_states is not None:
+         # extensions to support vLLM:
+         # 1. conv_states is used to replace initial_states
+         # 2. conv_states serves as a cache whose number of cache lines can be larger than the batch size
+         # 3. mapping from sequence x[idx] to a cache line at the index specified via cache_indices[idx]
+         # 4. computation can be skipped if cache_indices[idx] == pad_slot_id
+         num_cache_lines = conv_states.size(0)
+         assert (
+             num_cache_lines == conv_states.shape[0]
+             and dim == conv_states.shape[1]
+             and width - 1 <= conv_states.shape[2]
+         )
+         stride_istate_seq = conv_states.stride(0)
+         stride_istate_dim = conv_states.stride(1)
+         stride_istate_token = conv_states.stride(2)
+         # assert stride_istate_dim == 1
+     if out.dim() == 2:
+         stride_o_seq = 0
+         stride_o_dim = out.stride(0)
+         stride_o_token = out.stride(1)
+     else:
+         stride_o_seq = out.stride(0)
+         stride_o_dim = out.stride(1)
+         stride_o_token = out.stride(2)
+
+     if validate_data:
+         assert x.dim() == 2
+         assert query_start_loc is not None
+         assert query_start_loc.dim() == 1
+         assert x.stride(0) == 1 or x.stride(1) == 1
+         if bias is not None:
+             assert bias.dim() == 1
+             assert dim == bias.size(0)
+         if cache_indices is not None:
+             assert cache_indices.dim() == 1
+             assert padded_batch == cache_indices.size(0)
+         if has_initial_state is not None:
+             assert has_initial_state.size() == (padded_batch,)
+             assert (
+                 conv_states is not None
+             ), "ERROR: `has_initial_state` is used, which also requires `conv_states`"
+         assert weight.stride(1) == 1
+         assert (dim, width) == weight.shape
+         assert is_channel_last, "Need to run in channel-last layout"
+
+     if metadata is None:
+
+         def num_program(META, seqlens):
+             tot = 0
+
+             mlist = []
+             offsetlist = []  # type: ignore
+
+             nums = -(-seqlens // META["BLOCK_M"])
+
+             tot = nums.sum().item()
+             mlist = np.repeat(np.arange(len(nums)), nums)
+             for idx, num in enumerate(nums):
+                 offsetlist.extend(
+                     range(num)
+                 )  # chunk-idx if a sequence is split into multiple chunks
+
+             if META["batch_ptr"].nelement() < len(mlist):
+                 newlen = len(mlist) + 1
+                 META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
+                 META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
+
+             if META["batch_ptr"].nelement() >= len(mlist):
+                 META["batch_ptr"][0 : len(mlist)].copy_(
+                     torch.from_numpy(np.array(mlist))
+                 )
+                 META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
+                     torch.from_numpy(np.array(offsetlist))
+                 )
+
+             META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
+             META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
+                 META["x_ptr"].device
+             )
+             return tot
+
+     else:
+
+         def num_program(META, nums_dict):
+             tot = nums_dict[META["BLOCK_M"]]["tot"]
+
+             mlist = nums_dict[META["BLOCK_M"]]["mlist"]
+             mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
+
+             offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
+
+             if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
+                 META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
+                 META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
+                     "token_chunk_offset_ptr"
+                 ]
+             else:
+                 if META["batch_ptr"].nelement() < mlist_len:
+                     newlen = mlist_len + 1
+                     META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
+                     META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
+
+                 if META["batch_ptr"].nelement() >= mlist_len:
+                     META["batch_ptr"][0:mlist_len].copy_(mlist)
+                     META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
+             return tot
+
+     def grid(META):
+         return (
+             num_program(META, args),
+             triton.cdiv(dim, META["BLOCK_N"]),
+         )
+
+     if batch_ptr.device != x.device:
+         batch_ptr = batch_ptr.to(x.device)
+         token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
+
+     _causal_conv1d_fwd_kernel[grid](
+         # Pointers to matrices
+         x,
+         weight,
+         bias,
+         conv_states,
+         cache_indices,
+         has_initial_state,
+         query_start_loc,
+         batch_ptr,
+         token_chunk_offset_ptr,
+         out,
+         # Matrix dimensions
+         padded_batch,
+         dim,
+         cu_seqlen,
+         num_cache_lines,
+         # stride
+         stride_x_seq,
+         stride_x_dim,
+         stride_x_token,
+         stride_w_dim,
+         stride_w_width,
+         stride_istate_seq,
+         stride_istate_dim,
+         stride_istate_token,
+         stride_o_seq,
+         stride_o_dim,
+         stride_o_token,
+         # others
+         pad_slot_id,
+         # META
+         HAS_BIAS=bias is not None,
+         KERNEL_WIDTH=width,
+         SILU_ACTIVATION=activation in ["silu", "swish"],
+         HAS_INITIAL_STATES=has_initial_state is not None,
+         HAS_CACHE=conv_states is not None,
+         IS_CONTINUOUS_BATCHING=cache_indices is not None,
+         USE_PAD_SLOT=pad_slot_id is not None,
+         NP2_STATELEN=np2_statelen,
+         # launch_cooperative_grid=True
+         BLOCK_M=8,
+         BLOCK_N=256,
+         num_stages=2,
+     )
+     return out
+
+
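Editor's note: a minimal invocation sketch for causal_conv1d_fn (all shapes illustrative, CUDA device assumed; the transpose produces the channels-last layout that validate_data asserts):

import torch

dim, width, num_cache_lines = 64, 4, 8
seqlens = [5, 1, 1, 1]  # the continuous-batching example from the docstring
total_tokens = sum(seqlens)
# channels-last 2D input: x.stride(0) == 1
x = torch.randn(total_tokens, dim, device="cuda").t()
weight = torch.randn(dim, width, device="cuda")
conv_states = torch.zeros(num_cache_lines, dim, width - 1, device="cuda")
query_start_loc = torch.tensor([0, 5, 6, 7, 8], dtype=torch.int32, device="cuda")
cache_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda")
has_initial_state = torch.zeros(len(seqlens), dtype=torch.bool, device="cuda")

out = causal_conv1d_fn(
    x,
    weight,
    bias=None,
    conv_states=conv_states,  # updated in place per cache_indices
    query_start_loc=query_start_loc,
    cache_indices=cache_indices,
    has_initial_state=has_initial_state,
    activation="silu",
    validate_data=True,
)  # out: (dim, total_tokens), same layout as x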
+ @triton.jit()
+ def _causal_conv1d_update_kernel(
+     # Pointers to matrices
+     x_ptr,  # (batch, dim, seqlen)
+     w_ptr,  # (dim, width)
+     bias_ptr,
+     conv_state_ptr,
+     cache_seqlens_ptr,  # circular buffer
+     conv_state_indices_ptr,
+     num_accepted_tokens_ptr,
+     intermediate_conv_window_ptr,
+     o_ptr,  # (batch, dim, seqlen)
+     # Matrix dimensions
+     batch: int,
+     dim: tl.constexpr,
+     seqlen: tl.constexpr,
+     state_len: tl.constexpr,
+     num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
+     # Strides
+     stride_x_seq: tl.constexpr,
+     stride_x_dim: tl.constexpr,
+     stride_x_token: tl.constexpr,
+     stride_w_dim: tl.constexpr,
+     stride_w_width: tl.constexpr,
+     stride_conv_state_seq: tl.constexpr,
+     stride_conv_state_dim: tl.constexpr,
+     stride_conv_state_tok: tl.constexpr,
+     stride_state_indices: tl.constexpr,
+     stride_inter_seq: tl.constexpr,
+     stride_inter_step: tl.constexpr,
+     stride_inter_dim: tl.constexpr,
+     stride_inter_win: tl.constexpr,
+     stride_o_seq: tl.constexpr,
+     stride_o_dim: tl.constexpr,
+     stride_o_token: tl.constexpr,
+     # others
+     pad_slot_id: tl.constexpr,
+     # Meta-parameters
+     HAS_BIAS: tl.constexpr,
+     KERNEL_WIDTH: tl.constexpr,
+     SILU_ACTIVATION: tl.constexpr,
+     IS_CONTINUOUS_BATCHING: tl.constexpr,
+     IS_SPEC_DECODING: tl.constexpr,
+     NP2_STATELEN: tl.constexpr,
+     USE_PAD_SLOT: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     SAVE_INTERMEDIATE: tl.constexpr,
+ ):
+     # ruff: noqa: E501
+     idx_seq = tl.program_id(0)
+     if idx_seq >= batch:
+         return
+
+     # [BLOCK_N,] elements along the feature-dimension (channel)
+     idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+     if IS_CONTINUOUS_BATCHING:
+         # mask = idx_seq < batch
+         conv_state_batch_coord = tl.load(
+             conv_state_indices_ptr + idx_seq * stride_state_indices
+         ).to(tl.int64)
+     else:
+         conv_state_batch_coord = idx_seq
+     if USE_PAD_SLOT:  # noqa
+         if conv_state_batch_coord == pad_slot_id:
+             # not processing as this is not the actual sequence
+             return
+
+     if IS_SPEC_DECODING:
+         # The rolling of conv state:
+         #
+         # Before forward, the conv_state is:
+         # [history1, history2, ..., historyM].
+         #
+         # After forward, the conv_state becomes:
+         # [history2, ..., historyM, draft1, draft2, ..., draftN].
+         #
+         # After acceptance, it becomes:
+         #
+         # - accept 1 token: [history2, ..., historyM, draft1]
+         # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
+         # - and so on.
+         conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1
+     else:
+         conv_state_token_offset = 0
+
+     # STEP 1: READ init_state data
+     conv_states_base = (
+         conv_state_ptr
+         + (conv_state_batch_coord * stride_conv_state_seq)
+         + (idx_feats * stride_conv_state_dim)
+     )
+     mask_w = idx_feats < dim
+
+     prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
+     if KERNEL_WIDTH >= 2:
+         conv_states_ptrs = prior_tokens  # [BLOCK_N]
+         col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+     if KERNEL_WIDTH >= 3:
+         conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok  # [BLOCK_N]
+         col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
+     if KERNEL_WIDTH >= 4:
+         conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
+         col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
+     if KERNEL_WIDTH == 5:
+         conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
+         col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
+
+     # STEP 2: assume state_len > seqlen
+     idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+
+     # The conv_state update works in a sliding-window manner:
+     # at each forward pass the tokens are shifted by 1, so we
+     # load starting from idx_tokens + 1.
+     conv_state_ptrs_source = (
+         conv_state_ptr
+         + (conv_state_batch_coord * stride_conv_state_seq)
+         + conv_state_token_offset * stride_conv_state_tok
+         + (idx_feats * stride_conv_state_dim)[None, :]
+         + ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
+     )  # [BLOCK_M, BLOCK_N]
+     mask = (
+         (conv_state_batch_coord < num_cache_lines)
+         & ((idx_tokens + seqlen) < state_len)[:, None]
+         & (idx_feats < dim)[None, :]
+     )
+     conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
+
+     VAL = state_len - seqlen
+     x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim)  # [BLOCK_N]
+
+     x_ptrs = (
+         x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
+     )  # [BLOCK_M, BLOCK_N]
+
+     mask_x = (
+         (idx_tokens - VAL >= 0)[:, None]
+         & (idx_tokens - VAL < seqlen)[:, None]
+         & (idx_feats < dim)[None, :]
+     )  # token-index # token-index # feature-index
+     loaded_x = tl.load(x_ptrs, mask_x, 0.0)
+     tl.debug_barrier()
+
+     new_conv_state = tl.where(mask, conv_state, loaded_x)
+
+     conv_state_base = (
+         conv_state_ptr
+         + (conv_state_batch_coord * stride_conv_state_seq)
+         + (idx_feats * stride_conv_state_dim)
+     )  # [BLOCK_N,]
+     conv_state_ptrs_target = (
+         conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None]
+     )  # [BLOCK_M, BLOCK_N]
+     mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
+     tl.store(conv_state_ptrs_target, new_conv_state, mask)
+
+     # STEP 3: init accumulator
+     if HAS_BIAS:
+         bias = bias_ptr + idx_feats
+         mask_bias = idx_feats < dim
+         acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
+             tl.float32
+         )  # [BLOCK_N]
+     else:
+         acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+     # STEP 4:
+     # PRE-LOAD WEIGHTS
+     # first kernel column, configured for weights to handle BLOCK_N features in range
+     w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
+     mask_w = idx_feats < dim
+     if KERNEL_WIDTH >= 2:
+         w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
+         w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
+         w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
+         w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
+     if KERNEL_WIDTH >= 3:
+         w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
+         w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
+     if KERNEL_WIDTH >= 4:
+         w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
+         w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
+
+     x_base_1d = x_base  # starting of chunk [BLOCK_N]
+     mask_x_1d = idx_feats < dim
+
+     # STEP 5: compute each token
+     for idx_token in tl.static_range(seqlen):
+         acc = acc_preload
+
+         matrix_w = w_col0
+         matrix_x = col0
+         for j in tl.static_range(KERNEL_WIDTH):
+             if KERNEL_WIDTH == 2:
+                 if j == 1:  # KERNEL_WIDTH-1:
+                     matrix_w = w_col1
+                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+             elif KERNEL_WIDTH == 3:
+                 if j == 1:
+                     matrix_w = w_col1
+                     matrix_x = col1
+                 elif j == 2:
+                     matrix_w = w_col2
+                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+             elif KERNEL_WIDTH == 4:
+                 if j == 1:
+                     matrix_w = w_col1
+                     matrix_x = col1
+                 elif j == 2:
+                     matrix_w = w_col2
+                     matrix_x = col2
+                 elif j == 3:
+                     matrix_w = w_col3
+                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+
+             acc += matrix_x * matrix_w  # [BLOCK_N]
+
+         if KERNEL_WIDTH == 2:
+             col0 = matrix_x
+         elif KERNEL_WIDTH == 3:
+             col0 = col1
+             col1 = matrix_x
+         elif KERNEL_WIDTH == 4:
+             col0 = col1
+             col1 = col2
+             col2 = matrix_x
+
+         if SILU_ACTIVATION:
+             acc = acc / (1 + tl.exp(-acc))
+         mask_1d = (idx_token < seqlen) & (
+             idx_feats < dim
+         )  # token-index # feature-index
+         o_ptrs = (
+             o_ptr
+             + (idx_seq) * stride_o_seq
+             + idx_token * stride_o_token
+             + (idx_feats * stride_o_dim)
+         )
+
+         tl.store(o_ptrs, acc, mask=mask_1d)
+
+         if SAVE_INTERMEDIATE:
+             # Save the window state after consuming this token
+             # Layout: [seq(cache line), step, dim, win(K-1)]
+             base_ptr = (
+                 intermediate_conv_window_ptr
+                 + conv_state_batch_coord * stride_inter_seq
+                 + idx_token * stride_inter_step
+                 + idx_feats * stride_inter_dim
+             )
+             if KERNEL_WIDTH >= 2:
+                 tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
+             if KERNEL_WIDTH >= 3:
+                 tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
+             if KERNEL_WIDTH >= 4:
+                 tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
+
+
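Editor's note: STEP 2 of the update kernel amounts to a shift-left-by-seqlen on each (dim, state_len) window followed by appending the incoming tokens. A NumPy reference of that semantics (a sketch under the state_len >= seqlen assumption the kernel makes; not part of the diff):

import numpy as np

def rolled_conv_state(conv_state, x_new):
    # conv_state: (dim, state_len); x_new: (dim, seqlen) with seqlen <= state_len
    seqlen = x_new.shape[1]
    return np.concatenate([conv_state[:, seqlen:], x_new], axis=1)

state = np.array([[1, 2, 3]])  # dim=1, state_len=3 (width=4)
step = np.array([[9]])         # one new decode token
# rolled_conv_state(state, step) -> [[2, 3, 9]]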
+ def causal_conv1d_update(
+     x: torch.Tensor,
+     conv_state: torch.Tensor,
+     weight: torch.Tensor,
+     bias: Optional[torch.Tensor] = None,
+     activation: Union[bool, str, None] = None,
+     cache_seqlens: Optional[torch.Tensor] = None,
+     conv_state_indices: Optional[torch.Tensor] = None,
+     num_accepted_tokens: Optional[torch.Tensor] = None,
+     intermediate_conv_window: Optional[torch.Tensor] = None,
+     pad_slot_id: int = PAD_SLOT_ID,
+     metadata=None,
+     validate_data=False,
+ ):
+     """
+     x: (batch, dim) or (batch, dim, seqlen)
+         [shape=2: single token prediction]
+         [shape=3: single or multiple tokens prediction]
+     conv_state: (..., dim, state_len), where state_len >= width - 1
+     weight: (dim, width)
+     bias: (dim,)
+     cache_seqlens: (batch,), dtype int32.
+         If not None, the conv_state is treated as a circular buffer.
+         The conv_state will be updated by copying x to the conv_state
+         starting at the index
+         @cache_seqlens % state_len.
+     conv_state_indices: (batch,), dtype int32
+         If not None, the conv_state is a larger tensor along the batch dim,
+         and we are selecting the batch coords specified by conv_state_indices.
+         Useful for a continuous batching scenario.
+     pad_slot_id: int
+         if cache_indices is passed, lets the kernel identify padded
+         entries that will not be processed,
+         for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+         in this case, the kernel will not process entries at
+         indices 0 and 3
+     out: (batch, dim) or (batch, dim, seqlen)
+     """
+     if validate_data:
+         assert cache_seqlens is None  # not implemented yet - ok for vLLM
+         assert pad_slot_id is not None
+         assert x.stride(1) == 1
+     if isinstance(activation, bool):
+         activation = "silu" if activation is True else None
+     elif activation is not None:
+         assert activation in ["silu", "swish"]
+     unsqueeze = x.dim() == 2
+     if unsqueeze:
+         # make it (batch, dim, seqlen) with seqlen == 1
+         x = x.unsqueeze(-1)
+     batch, dim, seqlen = x.shape
+     _, width = weight.shape
+     # conv_state: (..., dim, state_len), where state_len >= width - 1
+     num_cache_lines, _, state_len = conv_state.size()
+
+     if validate_data:
+         assert dim == weight.size(0)
+         assert (
+             conv_state.stride(-2) == 1
+         ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
+         assert state_len >= width - 1
+         # when the above holds, we don't need to shift left to keep any records in conv_state
+         assert dim == conv_state.size(1)
+         if conv_state_indices is None:
+             assert conv_state.size(0) >= batch
+         else:
+             assert (batch,) == conv_state_indices.shape
+
+         assert num_cache_lines >= batch
+         assert weight.stride(1) == 1  # Need this
+         assert cache_seqlens is None  # not needed for vLLM - circular buffer
+
+     # adopt the strategy from vLLM of overwriting 'x' directly, rather than creating a new tensor 'o'
+     out = x
+     stride_w_dim, stride_w_width = weight.stride()
+
+     stride_x_seq, stride_x_dim, stride_x_token = x.stride()  # X (batch, dim, seqlen)
+
+     stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+     stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()
+     stride_state_indices = (
+         conv_state_indices.stride(0) if conv_state_indices is not None else 0
+     )
+     state_len = width - 1 + (seqlen - 1)  # effective state_len needed
+     np2_statelen = triton.next_power_of_2(state_len)
+
+     def grid(META):
+         return (
+             batch,
+             triton.cdiv(dim, META["BLOCK_N"]),
+         )
+
+     # prepare intermediate buffer strides if provided
+     if intermediate_conv_window is not None:
+         stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
+             intermediate_conv_window.stride(0),
+             intermediate_conv_window.stride(1),
+             intermediate_conv_window.stride(2),
+             intermediate_conv_window.stride(3),
+         )
+     else:
+         stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0
+
+     _causal_conv1d_update_kernel[grid](
+         # Pointers to matrices
+         x,
+         weight,
+         bias,
+         conv_state,
+         cache_seqlens,
+         conv_state_indices,
+         num_accepted_tokens,
+         intermediate_conv_window if intermediate_conv_window is not None else x,
+         out,
+         # Matrix dimensions
+         batch,
+         dim,
+         seqlen,
+         state_len,
+         num_cache_lines,
+         # stride
+         stride_x_seq,
+         stride_x_dim,
+         stride_x_token,
+         stride_w_dim,
+         stride_w_width,
+         stride_istate_seq,
+         stride_istate_dim,
+         stride_istate_token,
+         stride_state_indices,
+         stride_inter_seq,
+         stride_inter_step,
+         stride_inter_dim,
+         stride_inter_win,
+         stride_o_seq,
+         stride_o_dim,
+         stride_o_token,
+         # others
+         pad_slot_id,
+         # META
+         HAS_BIAS=bias is not None,
+         KERNEL_WIDTH=width,
+         SILU_ACTIVATION=activation in ["silu", "swish"],
+         IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
+         IS_SPEC_DECODING=num_accepted_tokens is not None,
+         NP2_STATELEN=np2_statelen,
+         USE_PAD_SLOT=pad_slot_id is not None,
+         BLOCK_N=256,
+         SAVE_INTERMEDIATE=intermediate_conv_window is not None,
+     )
+     if unsqueeze:
+         out = out.squeeze(-1)
+     return out
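Editor's note: a matching decode-time sketch for causal_conv1d_update (shapes illustrative, CUDA device assumed; the transpose gives conv_state the feature-contiguous layout that validate_data asserts, and x is overwritten in place and returned):

import torch

batch, dim, width, num_cache_lines = 4, 64, 4, 8
x = torch.randn(batch, dim, device="cuda")  # one new token per sequence
# (num_cache_lines, dim, width - 1) with stride(-2) == 1
conv_state = torch.zeros(num_cache_lines, width - 1, dim, device="cuda").transpose(1, 2)
weight = torch.randn(dim, width, device="cuda")
conv_state_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda")

out = causal_conv1d_update(
    x,
    conv_state,  # rolled forward in place
    weight,
    activation="silu",
    conv_state_indices=conv_state_indices,
    validate_data=True,
)  # out: (batch, dim)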