sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/causal_conv1d_triton.py (new file)
@@ -0,0 +1,969 @@
+# Copyright (c) 2024, Tri Dao.
+# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
+# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+PAD_SLOT_ID = -1
+import triton
+import triton.language as tl
+
+
+@triton.jit()
+def _causal_conv1d_fwd_kernel(  # continuous batching
+    # Pointers to matrices
+    x_ptr,  # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
+    w_ptr,  # (dim, width)
+    bias_ptr,
+    initial_states_ptr,  # conv_states_ptr
+    cache_indices_ptr,  # conv_state_indices_ptr
+    has_initial_states_ptr,
+    query_start_loc_ptr,
+    o_ptr,  # (dim, seqlen) - actually pointing to x_ptr
+    # Matrix dimensions
+    dim: tl.constexpr,
+    seqlen: tl.int32,  # cu_seqlen
+    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
+    # Strides
+    stride_x_seq: tl.constexpr,  # stride to get to next sequence,
+    stride_x_dim: tl.constexpr,  # stride to get to next feature-value,
+    stride_x_token: tl.constexpr,  # stride to get to next token (same feature-index, same sequence-index)
+    stride_w_dim: tl.constexpr,  # stride to get to next dim-axis value
+    stride_w_width: tl.constexpr,  # stride to get to next width-axis value
+    stride_istate_seq: tl.constexpr,
+    stride_istate_dim: tl.constexpr,
+    stride_istate_token: tl.constexpr,
+    stride_o_seq: tl.constexpr,
+    stride_o_dim: tl.constexpr,
+    stride_o_token: tl.constexpr,
+    # others
+    pad_slot_id: tl.constexpr,
+    # Meta-parameters
+    HAS_BIAS: tl.constexpr,
+    KERNEL_WIDTH: tl.constexpr,
+    SILU_ACTIVATION: tl.constexpr,
+    HAS_INITIAL_STATES: tl.constexpr,
+    HAS_CACHE: tl.constexpr,
+    IS_CONTINUOUS_BATCHING: tl.constexpr,
+    USE_PAD_SLOT: tl.constexpr,
+    NP2_STATELEN: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    conv_states_ptr = initial_states_ptr
+    conv_state_indices_ptr = cache_indices_ptr
+    stride_conv_state_seq = stride_istate_seq
+    stride_conv_state_dim = stride_istate_dim
+    stride_conv_state_tok = stride_istate_token
+    state_len = (
+        KERNEL_WIDTH - 1
+    )  # can be passed via argument if it's not the same as this value
+
+    # one program handles one chunk in a single sequence
+    # rather than mixing sequences - to make updating initial_states across sequences efficiently
+
+    # single-sequence id
+    idx_seq = tl.program_id(0)
+    chunk_offset = tl.program_id(1)
+
+    # BLOCK_N elements along the feature-dimension (channel)
+    idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    if idx_seq == pad_slot_id:
+        return
+
+    sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
+    sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
+    # find the actual sequence length
+    seqlen = sequence_end_index - sequence_start_index
+
+    token_offset = BLOCK_M * chunk_offset
+    segment_len = min(BLOCK_M, seqlen - token_offset)
+
+    if segment_len <= 0:
+        return
+
+    # base of the sequence
+    x_base = (
+        x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
+    )  # [BLOCK_N,]
+
+    if IS_CONTINUOUS_BATCHING:
+        # cache_idx
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64)
+    else:
+        # cache_idx
+        conv_state_batch_coord = idx_seq
+    if USE_PAD_SLOT:  # noqa
+        if conv_state_batch_coord == pad_slot_id:
+            # not processing as this is not the actual sequence
+            return
+    conv_states_base = (
+        conv_states_ptr
+        + (conv_state_batch_coord * stride_conv_state_seq)
+        + (idx_feats * stride_conv_state_dim)
+    )  # [BLOCK_N,]
+
+    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
+
+    # Does 2 things:
+    # 1. READ prior-block init-state data - [done by every Triton programs]
+    # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0]
+    if chunk_offset == 0:
+        # read from conv_states
+        load_init_state = False
+        if HAS_INITIAL_STATES:  # the new HAS_INITIAL_STATES
+            load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
+        if load_init_state:
+            # load from conv_states
+            prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok
+            mask_w = idx_feats < dim
+            if KERNEL_WIDTH == 2:
+                conv_states_ptrs = prior_tokens  # [BLOCK_N]
+                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+            if KERNEL_WIDTH == 3:
+                conv_states_ptrs = prior_tokens  # [BLOCK_N]
+                col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]
+                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+            if KERNEL_WIDTH == 4:
+                conv_states_ptrs = prior_tokens  # [BLOCK_N]
+                col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]
+                col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok  # [BLOCK_N]
+                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+            if KERNEL_WIDTH == 5:
+                conv_states_ptrs = prior_tokens  # [BLOCK_N]
+                col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok  # [BLOCK_N]
+                col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok  # [BLOCK_N]
+                col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
+                conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok  # [BLOCK_N]
+                col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+        else:
+            # prior-tokens are zeros
+            if KERNEL_WIDTH >= 2:  # STRATEGY1
+                # first chunk and does not have prior-token, so just set to 0
+                col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
+            if KERNEL_WIDTH >= 3:  # STRATEGY1
+                col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
+            if KERNEL_WIDTH >= 4:  # STRATEGY1
+                col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
+            if KERNEL_WIDTH >= 5:  # STRATEGY1
+                col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
+
+        # STEP 2:
+        # here prepare data for updating conv_state
+        if (
+            state_len <= seqlen
+        ):  # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
+            # just read from 'x'
+            # copy 'x' data to conv_state
+            # load only 'x' data (and set 0 before 'x' if seqlen < state_len)
+            idx_tokens_last = (seqlen - state_len) + tl.arange(
+                0, NP2_STATELEN
+            )  # [BLOCK_M]
+            x_ptrs = (
+                x_ptr
+                + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None]
+                + (idx_feats * stride_x_dim)[None, :]
+            )  # [BLOCK_M,BLOCK_N,]
+            mask_x = (
+                (idx_tokens_last >= 0)[:, None]
+                & (idx_tokens_last < seqlen)[:, None]
+                & (idx_feats < dim)[None, :]
+            )  # token-index # token-index # feature-index
+            loaded_x = tl.load(x_ptrs, mask_x, 0.0)
+            new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
+            idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+            conv_states_ptrs_target = (
+                conv_states_base[None, :]
+                + (idx_tokens_conv * stride_conv_state_tok)[:, None]
+            )
+
+            mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]
+            tl.debug_barrier()  # NOTE: use this due to bug in Triton compiler
+            tl.store(conv_states_ptrs_target, new_conv_state, mask)
+
+        else:
+            if load_init_state:
+                # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
+                idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+
+                conv_states_ptrs_source = (
+                    conv_states_ptr
+                    + (conv_state_batch_coord * stride_conv_state_seq)
+                    + (idx_feats * stride_conv_state_dim)[None, :]
+                    + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
+                )  # [BLOCK_M, BLOCK_N]
+                mask = (
+                    (conv_state_batch_coord < num_cache_lines)
+                    & ((idx_tokens_conv + seqlen) < state_len)[:, None]
+                    & (idx_feats < dim)[None, :]
+                )
+                conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
+
+                VAL = state_len - seqlen
+
+                x_ptrs = (
+                    x_base[None, :]
+                    + ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
+                )  # [BLOCK_M, BLOCK_N]
+
+                mask_x = (
+                    (idx_tokens_conv - VAL >= 0)[:, None]
+                    & (idx_tokens_conv - VAL < seqlen)[:, None]
+                    & (idx_feats < dim)[None, :]
+                )  # token-index # token-index # feature-index
+                loaded_x = tl.load(x_ptrs, mask_x, 0.0)
+
+                tl.debug_barrier()  # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
+                new_conv_state = tl.where(
+                    mask, conv_state, loaded_x
+                )  # BUG in 'tl.where' which requires a barrier before this
+                conv_states_ptrs_target = (
+                    conv_states_base
+                    + (idx_tokens_conv * stride_conv_state_tok)[:, None]
+                )  # [BLOCK_M, BLOCK_N]
+                mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
+                    None, :
+                ]
+                tl.store(conv_states_ptrs_target, new_conv_state, mask)
+            else:  # load_init_state == False
+                # update conv_state by shifting left, BUT
+                # set cols prior to 'x' as zeros + cols from 'x'
+                idx_tokens_conv = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+
+                VAL = state_len - seqlen
+
+                x_ptrs = (
+                    x_base[None, :]
+                    + ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
+                )  # [BLOCK_M, BLOCK_N]
+
+                mask_x = (
+                    (idx_tokens_conv - VAL >= 0)[:, None]
+                    & (idx_tokens_conv - VAL < seqlen)[:, None]
+                    & (idx_feats < dim)[None, :]
+                )  # token-index # token-index # feature-index
+                new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
+
+                conv_states_ptrs_target = (
+                    conv_states_base
+                    + (idx_tokens_conv * stride_conv_state_tok)[:, None]
+                )  # [BLOCK_M, BLOCK_N]
+                mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
+                    None, :
+                ]
+                tl.store(conv_states_ptrs_target, new_conv_state, mask)
+
+    else:  # chunk_offset > 0
+        # read prior-token data from `x`
+        load_init_state = True
+        prior_tokens = x_base + (token_offset - 1) * stride_x_token
+        mask_w = idx_feats < dim
+        if KERNEL_WIDTH == 2:
+            conv_states_ptrs = prior_tokens  # [BLOCK_N]
+            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+        if KERNEL_WIDTH == 3:
+            conv_states_ptrs = prior_tokens  # [BLOCK_N]
+            col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+            conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]
+            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+        if KERNEL_WIDTH == 4:
+            conv_states_ptrs = prior_tokens  # [BLOCK_N]
+            col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+            conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]
+            col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+            conv_states_ptrs = prior_tokens - 2 * stride_x_token  # [BLOCK_N]
+            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+        if KERNEL_WIDTH == 5:
+            # ruff: noqa: F841
+            conv_states_ptrs = prior_tokens  # [BLOCK_N]
+            col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+            conv_states_ptrs = prior_tokens - 1 * stride_x_token  # [BLOCK_N]
+            col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+            conv_states_ptrs = prior_tokens - 2 * stride_x_token  # [BLOCK_N]
+            col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+            conv_states_ptrs = prior_tokens - 3 * stride_x_token  # [BLOCK_N]
+            col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
+
+    if HAS_BIAS:
+        bias = bias_ptr + idx_feats
+        mask_bias = idx_feats < dim
+        acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
+            tl.float32
+        )  # [BLOCK_N]
+    else:
+        acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+    x_base_1d = x_base + token_offset * stride_x_token  # starting of chunk
+
+    # PRE-LOAD WEIGHTS
+    mask_w = idx_feats < dim
+    if KERNEL_WIDTH >= 2:
+        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
+        w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
+        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
+        w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 3:
+        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
+        w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 4:
+        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
+        w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
+    mask_x_1d = idx_feats < dim
+    for idx_token in range(segment_len):
+        acc = acc_preload
+
+        matrix_w = w_col0
+        matrix_x = col0
+        for j in tl.static_range(KERNEL_WIDTH):
+
+            if KERNEL_WIDTH == 2:
+                if j == 1:  # KERNEL_WIDTH-1:
+                    matrix_w = w_col1
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 3:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 4:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+
+            acc += matrix_x * matrix_w  # [BLOCK_N]
+
+        if KERNEL_WIDTH == 2:
+            col0 = matrix_x
+        elif KERNEL_WIDTH == 3:
+            col0 = col1
+            col1 = matrix_x
+        elif KERNEL_WIDTH == 4:
+            col0 = col1
+            col1 = col2
+            col2 = matrix_x
+
+        if SILU_ACTIVATION:
+            acc = acc / (1 + tl.exp(-acc))
+        mask_1d = (idx_token < segment_len) & (
+            idx_feats < dim
+        )  # token-index # feature-index
+        o_ptrs = (
+            o_ptr
+            + (sequence_start_index + token_offset + idx_token) * stride_o_token
+            + (idx_feats * stride_o_dim)
+        )
+
+        tl.store(o_ptrs, acc, mask=mask_1d)
+
+
+def causal_conv1d_fn(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Union[torch.Tensor, None],
+    conv_states: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    seq_lens_cpu: List[int],
+    cache_indices: Optional[torch.Tensor] = None,
+    has_initial_state: Optional[torch.Tensor] = None,
+    activation: Optional[str] = "silu",
+    pad_slot_id: int = PAD_SLOT_ID,
+    validate_data=False,
+    **kwargs,
+):
+    """support varlen + continuous batching when x is 2D tensor
+
+    x: (dim,cu_seq_len)
+        cu_seq_len = total tokens of all seqs in that batch
+        sequences are concatenated from left to right for varlen
+    weight: (dim, width)
+    conv_states: (...,dim,width - 1) itype
+        updated inplace if provided
+        [it use `cache_indices` to get the index to the cache of conv_state for that sequence
+
+        conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
+        and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x'
+        ]
+    query_start_loc: (batch + 1) int32
+        The cumulative sequence lengths of the sequences in
+        the batch, used to index into sequence. prepended by 0.
+        if
+        x = [5, 1, 1, 1] <- continuous batching (batch=4)
+        then
+        query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is
+        the ending index of the last sequence
+        [length(query_start_loc)-1 == batch]
+        for example: query_start_loc = torch.Tensor([0,10,16,17]),
+        x.shape=(dim,17)
+    seq_lens_cpu: (batch) int32
+        The sequence lengths of the sequences in the batch
+    cache_indices: (batch) int32
+        indicates the corresponding state index,
+        like so: conv_state = conv_states[cache_indices[batch_id]]
+    has_initial_state: (batch) bool
+        indicates whether should the kernel take the current state as initial
+        state for the calculations
+        [single boolean for each sequence in the batch: True or False]
+    bias: (dim,)
+    activation: either None or "silu" or "swish" or True
+    pad_slot_id: int
+        if cache_indices is passed, lets the kernel identify padded
+        entries that will not be processed,
+        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+        in this case, the kernel will not process entries at
+        indices 0 and 3
+
+    out: same shape as `x`
+    """
+    if isinstance(activation, bool) and activation:
+        activation = "silu"
+
+    out = torch.empty_like(x)
+
+    is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
+    dim, cu_seqlen = x.shape
+    _, width = weight.shape
+    state_len = width - 1
+    np2_statelen = triton.next_power_of_2(state_len)
+
+    stride_x_seq = 0
+    stride_x_dim = x.stride(0)
+    stride_x_token = x.stride(1)
+    stride_w_dim = weight.stride(0)
+    stride_w_width = weight.stride(1)
+    stride_istate_seq = 0
+    stride_istate_dim = 0
+    stride_istate_token = 0
+    num_cache_lines = 0
+    if conv_states is not None:
+        # extensions to support vLLM:
+        # 1. conv_states is used to replaced initial_states
+        # 2. conv_states serve as a cache with num cache lines can be larger than batch size
+        # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
+        # 4. computation can be skipped if cache_indices[idx] == pad_slot_id
+        num_cache_lines = conv_states.size(0)
+        assert (
+            num_cache_lines == conv_states.shape[0]
+            and dim == conv_states.shape[1]
+            and width - 1 <= conv_states.shape[2]
+        )
+        stride_istate_seq = conv_states.stride(0)
+        stride_istate_dim = conv_states.stride(1)
+        stride_istate_token = conv_states.stride(2)
+        # assert stride_istate_dim == 1
+    if out.dim() == 2:
+        stride_o_seq = 0
+        stride_o_dim = out.stride(0)
+        stride_o_token = out.stride(1)
+    else:
+        stride_o_seq = out.stride(0)
+        stride_o_dim = out.stride(1)
+        stride_o_token = out.stride(2)
+
+    if validate_data:
+        assert x.dim() == 2
+        assert query_start_loc is not None
+        assert query_start_loc.dim() == 1
+        assert x.stride(0) == 1 or x.stride(1) == 1
+        padded_batch = query_start_loc.size(0) - 1
+        if bias is not None:
+            assert bias.dim() == 1
+            assert dim == bias.size(0)
+        if cache_indices is not None:
+            assert cache_indices.dim() == 1
+            assert padded_batch == cache_indices.size(0)
+        if has_initial_state is not None:
+            assert has_initial_state.size() == (padded_batch,)
+            assert (
+                conv_states is not None
+            ), "ERROR: `has_initial_state` is used, which needs also `conv_states`"
+        assert weight.stride(1) == 1
+        assert (dim, width) == weight.shape
+        assert is_channel_last, "Need to run in channel-last layout"
+
+    def grid(META):
+        max_seq_len = max(seq_lens_cpu)
+        return (
+            len(seq_lens_cpu),  # batch_size
+            (max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"],
+            triton.cdiv(dim, META["BLOCK_N"]),
+        )
+
+    _causal_conv1d_fwd_kernel[grid](
+        # Pointers to matrices
+        x,
+        weight,
+        bias,
+        conv_states,
+        cache_indices,
+        has_initial_state,
+        query_start_loc,
+        out,
+        # Matrix dimensions
+        dim,
+        cu_seqlen,
+        num_cache_lines,
+        # stride
+        stride_x_seq,
+        stride_x_dim,
+        stride_x_token,
+        stride_w_dim,
+        stride_w_width,
+        stride_istate_seq,
+        stride_istate_dim,
+        stride_istate_token,
+        stride_o_seq,
+        stride_o_dim,
+        stride_o_token,
+        # others
+        pad_slot_id,
+        # META
+        HAS_BIAS=bias is not None,
+        KERNEL_WIDTH=width,
+        SILU_ACTIVATION=activation in ["silu", "swish"],
+        HAS_INITIAL_STATES=has_initial_state is not None,
+        HAS_CACHE=conv_states is not None,
+        IS_CONTINUOUS_BATCHING=cache_indices is not None,
+        USE_PAD_SLOT=pad_slot_id is not None,
+        NP2_STATELEN=np2_statelen,
+        # launch_cooperative_grid=True
+        BLOCK_M=8,
+        BLOCK_N=256,
+        num_stages=2,
+    )
+    return out
+
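Editor's note: the snippet below is an illustrative calling sketch, not part of the released diff. It shows one way to drive causal_conv1d_fn above for a varlen batch of four sequences with lengths [5, 1, 1, 1], following the query_start_loc convention from the docstring; the CUDA device, float16 dtype, dim=256, width=4, and cache-line count are assumed values.

import torch

dim, width, num_cache_lines = 256, 4, 8
seq_lens = [5, 1, 1, 1]
total_tokens = sum(seq_lens)

# x is (dim, total_tokens) and must be channel-last: x.stride(0) == 1
x = torch.randn(total_tokens, dim, device="cuda", dtype=torch.float16).t()
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)
# conv-state cache: (cache lines, dim, width - 1), contiguous along the feature dim
conv_states = torch.zeros(
    num_cache_lines, width - 1, dim, device="cuda", dtype=torch.float16
).transpose(1, 2)

query_start_loc = torch.tensor([0, 5, 6, 7, 8], dtype=torch.int32, device="cuda")
cache_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda")
has_initial_state = torch.tensor([False, False, False, False], device="cuda")

out = causal_conv1d_fn(
    x,
    weight,
    bias,
    conv_states,
    query_start_loc,
    seq_lens,
    cache_indices=cache_indices,
    has_initial_state=has_initial_state,
    activation="silu",
    validate_data=True,
)
# out has shape (dim, total_tokens); conv_states[cache_indices] is updated in place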
+
+@triton.jit()
+def _causal_conv1d_update_kernel(
+    # Pointers to matrices
+    x_ptr,  # (batch, dim, seqlen)
+    w_ptr,  # (dim, width)
+    bias_ptr,
+    conv_state_ptr,
+    cache_seqlens_ptr,  # circular buffer
+    conv_state_indices_ptr,
+    num_accepted_tokens_ptr,
+    intermediate_conv_window_ptr,
+    o_ptr,  # (batch, dim, seqlen)
+    # Matrix dimensions
+    batch: int,
+    dim: tl.constexpr,
+    seqlen: tl.constexpr,
+    state_len: tl.constexpr,
+    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
+    # Strides
+    stride_x_seq: tl.constexpr,
+    stride_x_dim: tl.constexpr,
+    stride_x_token: tl.constexpr,
+    stride_w_dim: tl.constexpr,
+    stride_w_width: tl.constexpr,
+    stride_conv_state_seq: tl.constexpr,
+    stride_conv_state_dim: tl.constexpr,
+    stride_conv_state_tok: tl.constexpr,
+    stride_state_indices: tl.constexpr,
+    stride_inter_seq: tl.constexpr,
+    stride_inter_step: tl.constexpr,
+    stride_inter_dim: tl.constexpr,
+    stride_inter_win: tl.constexpr,
+    stride_o_seq: tl.constexpr,
+    stride_o_dim: tl.constexpr,
+    stride_o_token: tl.constexpr,
+    # others
+    pad_slot_id: tl.constexpr,
+    # Meta-parameters
+    HAS_BIAS: tl.constexpr,
+    KERNEL_WIDTH: tl.constexpr,
+    SILU_ACTIVATION: tl.constexpr,
+    IS_CONTINUOUS_BATCHING: tl.constexpr,
+    IS_SPEC_DECODING: tl.constexpr,
+    NP2_STATELEN: tl.constexpr,
+    USE_PAD_SLOT: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    SAVE_INTERMEDIATE: tl.constexpr,
+):
+    # ruff: noqa: E501
+    idx_seq = tl.program_id(0)
+    if idx_seq >= batch:
+        return
+
+    # [BLOCK_N,] elements along the feature-dimension (channel)
+    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    if IS_CONTINUOUS_BATCHING:
+        # mask = idx_seq < batch
+        conv_state_batch_coord = tl.load(
+            conv_state_indices_ptr + idx_seq * stride_state_indices
+        ).to(tl.int64)
+    else:
+        conv_state_batch_coord = idx_seq
+    if USE_PAD_SLOT:  # noqa
+        if conv_state_batch_coord == pad_slot_id:
+            # not processing as this is not the actual sequence
+            return
+
+    if IS_SPEC_DECODING:
+        # The rolling of conv state:
+        #
+        # Before forward, the conv_state is:
+        # [history1, history2, ..., historyM].
+        #
+        # After forward, the conv_state becomes:
+        # [history2, ..., historyM, draft1, draft2, ..., draftN].
+        #
+        # After acceptance, it becomes:
+        #
+        # - accept 1 tokens: [history2, ..., historyM, draft1]
+        # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
+        # - and so on.
+        conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1
+    else:
+        conv_state_token_offset = 0
+
+    # STEP 1: READ init_state data
+    conv_states_base = (
+        conv_state_ptr
+        + (conv_state_batch_coord * stride_conv_state_seq)
+        + (idx_feats * stride_conv_state_dim)
+    )
+    mask_w = idx_feats < dim
+
+    prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
+    if KERNEL_WIDTH >= 2:
+        conv_states_ptrs = prior_tokens  # [BLOCK_N]
+        col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
+    if KERNEL_WIDTH >= 3:
+        conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok  # [BLOCK_N]
+        col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
+    if KERNEL_WIDTH >= 4:
+        conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
+        col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
+    if KERNEL_WIDTH == 5:
+        conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
+        col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
+
+    # STEP 2: assume state_len > seqlen
+    idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+
+    # The conv_state updates works in a sliding window manner,
+    # at each forward pass, the tokens are shift by 1, so we
+    # load since idx_tokens + 1.
+    conv_state_ptrs_source = (
+        conv_state_ptr
+        + (conv_state_batch_coord * stride_conv_state_seq)
+        + conv_state_token_offset * stride_conv_state_tok
+        + (idx_feats * stride_conv_state_dim)[None, :]
+        + ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
+    )  # [BLOCK_M, BLOCK_N]
+    mask = (
+        (conv_state_batch_coord < num_cache_lines)
+        & ((idx_tokens + seqlen) < state_len)[:, None]
+        & (idx_feats < dim)[None, :]
+    )
+    conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
+
+    VAL = state_len - seqlen
+    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim)  # [BLOCK_N]
+
+    x_ptrs = (
+        x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
+    )  # [BLOCK_M, BLOCK_N]
+
+    mask_x = (
+        (idx_tokens - VAL >= 0)[:, None]
+        & (idx_tokens - VAL < seqlen)[:, None]
+        & (idx_feats < dim)[None, :]
+    )  # token-index # token-index # feature-index
+    loaded_x = tl.load(x_ptrs, mask_x, 0.0)
+    tl.debug_barrier()
+
+    new_conv_state = tl.where(mask, conv_state, loaded_x)
+
+    conv_state_base = (
+        conv_state_ptr
+        + (conv_state_batch_coord * stride_conv_state_seq)
+        + (idx_feats * stride_conv_state_dim)
+    )  # [BLOCK_N,]
+    conv_state_ptrs_target = (
+        conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None]
+    )  # [BLOCK_M, BLOCK_N]
+    mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
+    tl.store(conv_state_ptrs_target, new_conv_state, mask)
+
+    # STEP 3: init accumulator
+    if HAS_BIAS:
+        bias = bias_ptr + idx_feats
+        mask_bias = idx_feats < dim
+        acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
+            tl.float32
+        )  # [BLOCK_N]
+    else:
+        acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+    # STEP 4:
+    # PRE-LOAD WEIGHTS
+    # first kernel column, configured for weights to handle BLOCK_N features in range
+    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
+    mask_w = idx_feats < dim
+    if KERNEL_WIDTH >= 2:
+        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
+        w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
+        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
+        w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 3:
+        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
+        w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 4:
+        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
+        w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
+
+    x_base_1d = x_base  # starting of chunk [BLOCK_N]
+    mask_x_1d = idx_feats < dim
+
+    # STEP 5: compute each token
+    for idx_token in tl.static_range(seqlen):
+        acc = acc_preload
+
+        matrix_w = w_col0
+        matrix_x = col0
+        for j in tl.static_range(KERNEL_WIDTH):
+            if KERNEL_WIDTH == 2:
+                if j == 1:  # KERNEL_WIDTH-1:
+                    matrix_w = w_col1
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 3:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 4:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+
+            acc += matrix_x * matrix_w  # [BLOCK_N]
+
+        if KERNEL_WIDTH == 2:
+            col0 = matrix_x
+        elif KERNEL_WIDTH == 3:
+            col0 = col1
+            col1 = matrix_x
+        elif KERNEL_WIDTH == 4:
+            col0 = col1
+            col1 = col2
+            col2 = matrix_x
+
+        if SILU_ACTIVATION:
+            acc = acc / (1 + tl.exp(-acc))
+        mask_1d = (idx_token < seqlen) & (
+            idx_feats < dim
+        )  # token-index # feature-index
+        o_ptrs = (
+            o_ptr
+            + (idx_seq) * stride_o_seq
+            + idx_token * stride_o_token
+            + (idx_feats * stride_o_dim)
+        )
+
+        tl.store(o_ptrs, acc, mask=mask_1d)
+
+        if SAVE_INTERMEDIATE:
+            # Save the window state after consuming this token
+            # Layout: [seq(cache line), step, dim, win(K-1)]
+            base_ptr = (
+                intermediate_conv_window_ptr
+                + conv_state_batch_coord * stride_inter_seq
+                + idx_token * stride_inter_step
+                + idx_feats * stride_inter_dim
+            )
+            if KERNEL_WIDTH >= 2:
+                tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
+            if KERNEL_WIDTH >= 3:
+                tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
+            if KERNEL_WIDTH >= 4:
+                tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
+
+
+def causal_conv1d_update(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    activation: Union[bool, str, None] = None,
+    cache_seqlens: Optional[torch.Tensor] = None,
+    conv_state_indices: Optional[torch.Tensor] = None,
+    num_accepted_tokens: Optional[torch.Tensor] = None,
+    intermediate_conv_window: Optional[torch.Tensor] = None,
+    pad_slot_id: int = PAD_SLOT_ID,
+    metadata=None,
+    validate_data=False,
+):
+    """
+    x: (batch, dim) or (batch, dim, seqlen)
+        [shape=2: single token prediction]
+        [shape=3: single or multiple tokens prediction]
+    conv_state: (..., dim, state_len), where state_len >= width - 1
+    weight: (dim, width)
+    bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state
+        starting at the index
+        @cache_seqlens % state_len.
+    conv_state_indices: (batch,), dtype int32
+        If not None, the conv_state is a larger tensor along the batch dim,
+        and we are selecting the batch coords specified by conv_state_indices.
+        Useful for a continuous batching scenario.
+    pad_slot_id: int
+        if cache_indices is passed, lets the kernel identify padded
+        entries that will not be processed,
+        for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
+        in this case, the kernel will not process entries at
+        indices 0 and 3
+    out: (batch, dim) or (batch, dim, seqlen)
+    """
+    if validate_data:
+        assert cache_seqlens is None  # not implemented yet - ok for vLLM
+        assert pad_slot_id is not None
+        assert x.stride(1) == 1
+    if isinstance(activation, bool):
+        activation = "silu" if activation is True else None
+    elif activation is not None:
+        assert activation in ["silu", "swish"]
+    unsqueeze = x.dim() == 2
+    if unsqueeze:
+        # make it (batch, dim, seqlen) with seqlen == 1
+        x = x.unsqueeze(-1)
+    batch, dim, seqlen = x.shape
+    _, width = weight.shape
+    # conv_state: (..., dim, state_len), where state_len >= width - 1
+    num_cache_lines, _, state_len = conv_state.size()
+
+    if validate_data:
+        assert dim == weight.size(0)
+        assert (
+            conv_state.stride(-2) == 1
+        ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
+        assert state_len >= width - 1
+        # when above happens, we don't shift-left to keep any records in conv_state
+        assert dim == conv_state.size(1)
+        if conv_state_indices is None:
+            assert conv_state.size(0) >= batch
+        else:
+            assert (batch,) == conv_state_indices.shape
+
+        assert num_cache_lines >= batch
+        assert weight.stride(1) == 1  # Need this
+        assert cache_seqlens is None  # not needed for vLLM - circular buffer
+
+    # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
+    out = x
+    stride_w_dim, stride_w_width = weight.stride()
+
+    stride_x_seq, stride_x_dim, stride_x_token = x.stride()  # X (batch, dim, seqlen)
+
+    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()
+    stride_state_indices = (
+        conv_state_indices.stride(0) if conv_state_indices is not None else 0
+    )
+    state_len = width - 1 + (seqlen - 1)  # effective state_len needed
+    np2_statelen = triton.next_power_of_2(state_len)
+
+    def grid(META):
+        return (
+            batch,
+            triton.cdiv(dim, META["BLOCK_N"]),
+        )
+
+    # prepare intermediate buffer strides if provided
+    if intermediate_conv_window is not None:
+        stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
+            intermediate_conv_window.stride(0),
+            intermediate_conv_window.stride(1),
+            intermediate_conv_window.stride(2),
+            intermediate_conv_window.stride(3),
+        )
+    else:
+        stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0
+
+    _causal_conv1d_update_kernel[grid](
+        # Pointers to matrices
+        x,
+        weight,
+        bias,
+        conv_state,
+        cache_seqlens,
+        conv_state_indices,
+        num_accepted_tokens,
+        intermediate_conv_window if intermediate_conv_window is not None else x,
+        out,
+        # Matrix dimensions
+        batch,
+        dim,
+        seqlen,
+        state_len,
+        num_cache_lines,
+        # stride
+        stride_x_seq,
+        stride_x_dim,
+        stride_x_token,
+        stride_w_dim,
+        stride_w_width,
+        stride_istate_seq,
+        stride_istate_dim,
+        stride_istate_token,
+        stride_state_indices,
+        stride_inter_seq,
+        stride_inter_step,
+        stride_inter_dim,
+        stride_inter_win,
+        stride_o_seq,
+        stride_o_dim,
+        stride_o_token,
+        # others
+        pad_slot_id,
+        # META
+        HAS_BIAS=bias is not None,
+        KERNEL_WIDTH=width,
+        SILU_ACTIVATION=activation in ["silu", "swish"],
+        IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
+        IS_SPEC_DECODING=num_accepted_tokens is not None,
+        NP2_STATELEN=np2_statelen,
+        USE_PAD_SLOT=pad_slot_id is not None,
+        BLOCK_N=256,
+        SAVE_INTERMEDIATE=intermediate_conv_window is not None,
+    )
+    if unsqueeze:
+        out = out.squeeze(-1)
+    return out
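
Editor's note: likewise an illustrative sketch, not part of the diff, for the single-token decode path. causal_conv1d_update above rolls each selected conv_state cache line in place and applies the convolution to one new token per sequence, with cache lines picked through conv_state_indices; the device, dtype, and sizes are assumed values consistent with the earlier sketch.

import torch

batch, dim, width, num_cache_lines = 4, 256, 4, 8

# one new token per sequence, (batch, dim), contiguous along dim
x_step = torch.randn(batch, dim, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)
# conv_state cache: (cache lines, dim, state_len) with state_len >= width - 1,
# contiguous along the feature dim as the validate_data checks require
conv_state = torch.zeros(
    num_cache_lines, width - 1, dim, device="cuda", dtype=torch.float16
).transpose(1, 2)
conv_state_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda")

out = causal_conv1d_update(
    x_step,
    conv_state,
    weight,
    bias,
    activation="silu",
    conv_state_indices=conv_state_indices,
    validate_data=True,
)
# out has shape (batch, dim); the result is written over x_step and the selected
# conv_state cache lines are shifted left by one token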