sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,562 @@
1
+ # Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
2
+
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
5
+
6
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
7
+ # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
8
+
9
+ # ruff: noqa: E501,SIM102
10
+
11
+ import torch
12
+ import triton
13
+ import triton.language as tl
14
+ from packaging import version
15
+
16
+ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
17
+
18
+
19
+ @triton.jit
20
+ def _chunk_scan_fwd_kernel(
21
+ # Pointers to matrices
22
+ cb_ptr,
23
+ x_ptr,
24
+ z_ptr,
25
+ out_ptr,
26
+ out_x_ptr,
27
+ dt_ptr,
28
+ dA_cumsum_ptr,
29
+ seq_idx_ptr,
30
+ C_ptr,
31
+ states_ptr,
32
+ D_ptr,
33
+ initstates_ptr,
34
+ chunk_indices_ptr,
35
+ chunk_offsets_ptr,
36
+ chunk_meta_num,
37
+ # Matrix dimensions
38
+ chunk_size,
39
+ hdim,
40
+ dstate,
41
+ batch,
42
+ seqlen,
43
+ nheads_ngroups_ratio,
44
+ # Strides
45
+ stride_cb_batch,
46
+ stride_cb_chunk,
47
+ stride_cb_head,
48
+ stride_cb_csize_m,
49
+ stride_cb_csize_k,
50
+ stride_x_batch,
51
+ stride_x_seqlen,
52
+ stride_x_head,
53
+ stride_x_hdim,
54
+ stride_z_batch,
55
+ stride_z_seqlen,
56
+ stride_z_head,
57
+ stride_z_hdim,
58
+ stride_out_batch,
59
+ stride_out_seqlen,
60
+ stride_out_head,
61
+ stride_out_hdim,
62
+ stride_dt_batch,
63
+ stride_dt_chunk,
64
+ stride_dt_head,
65
+ stride_dt_csize,
66
+ stride_dA_cs_batch,
67
+ stride_dA_cs_chunk,
68
+ stride_dA_cs_head,
69
+ stride_dA_cs_csize,
70
+ stride_seq_idx_batch,
71
+ stride_seq_idx_seqlen,
72
+ stride_C_batch,
73
+ stride_C_seqlen,
74
+ stride_C_head,
75
+ stride_C_dstate,
76
+ stride_states_batch,
77
+ stride_states_chunk,
78
+ stride_states_head,
79
+ stride_states_hdim,
80
+ stride_states_dstate,
81
+ stride_init_states_batch,
82
+ stride_init_states_head,
83
+ stride_init_states_hdim,
84
+ stride_init_states_dstate,
85
+ stride_D_head,
86
+ # Meta-parameters
87
+ IS_CAUSAL: tl.constexpr,
88
+ HAS_D: tl.constexpr,
89
+ D_HAS_HDIM: tl.constexpr,
90
+ HAS_Z: tl.constexpr,
91
+ HAS_SEQ_IDX: tl.constexpr,
92
+ BLOCK_SIZE_DSTATE: tl.constexpr,
93
+ IS_TRITON_22: tl.constexpr,
94
+ HAS_INITSTATES: tl.constexpr,
95
+ BLOCK_SIZE_M: tl.constexpr = 16,
96
+ BLOCK_SIZE_N: tl.constexpr = 16,
97
+ BLOCK_SIZE_K: tl.constexpr = 16,
98
+ ):
99
+ pid_bc = tl.program_id(axis=1).to(tl.int64)
100
+ pid_c = pid_bc // batch
101
+ pid_b = pid_bc - pid_c * batch
102
+ if not HAS_INITSTATES:
103
+ c_idx = pid_c
104
+ c_off = 0
105
+ else:
106
+ c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
107
+ c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)
108
+
109
+ pid_h = tl.program_id(axis=2)
110
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
111
+ pid_m = tl.program_id(axis=0) // num_pid_n
112
+ pid_n = tl.program_id(axis=0) % num_pid_n
113
+ cb_ptr += (
114
+ pid_b * stride_cb_batch
115
+ + c_idx * stride_cb_chunk
116
+ + (pid_h // nheads_ngroups_ratio) * stride_cb_head
117
+ )
118
+ x_ptr += (
119
+ pid_b * stride_x_batch
120
+ + c_idx * chunk_size * stride_x_seqlen
121
+ + pid_h * stride_x_head
122
+ )
123
+ dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head
124
+ dA_cumsum_ptr += (
125
+ pid_b * stride_dA_cs_batch
126
+ + c_idx * stride_dA_cs_chunk
127
+ + pid_h * stride_dA_cs_head
128
+ )
129
+ C_ptr += (
130
+ pid_b * stride_C_batch
131
+ + c_idx * chunk_size * stride_C_seqlen
132
+ + (pid_h // nheads_ngroups_ratio) * stride_C_head
133
+ )
134
+
135
+ # M-block offsets and prev states
136
+ # - logic in next block may override these if there is an active offset
137
+ offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
138
+ prev_states_ptr = (
139
+ states_ptr
140
+ + pid_b * stride_states_batch
141
+ + c_idx * stride_states_chunk
142
+ + pid_h * stride_states_head
143
+ )
144
+ prev_states_hdim = stride_states_hdim
145
+ prev_states_dstate = stride_states_dstate
146
+
147
+ chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
148
+ if HAS_SEQ_IDX:
149
+ seq_idx_ptr += (
150
+ pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen
151
+ )
152
+
153
+ # - we only need seq_idx_prev to be aligned to chunk boundary
154
+ seq_idx_prev = tl.load(
155
+ seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0
156
+ )
157
+
158
+ if HAS_INITSTATES:
159
+ # if there are init states, we only need seq_idx_m to point
160
+ # what is the current seq_idx
161
+
162
+ # get current seq idx
163
+ if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
164
+ seq_idx_m = tl.load(
165
+ seq_idx_ptr
166
+ + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen,
167
+ )
168
+
169
+ # - recall that in ssd_state_passing, for the case c_off == 0
170
+ # i.e., the very first sequence, we made states_ptr hold its initial state
171
+ # so this edge case is taken care of
172
+ if (
173
+ (c_off == 0)
174
+ and (
175
+ seq_idx_prev != seq_idx_m
176
+ ) # if a seq is changed exactly on boundary
177
+ or (c_off > 0) # implies a new example (pseudo chunk)
178
+ ):
179
+
180
+ # - replace prev_states_ptr with init_states
181
+ prev_states_ptr = (
182
+ initstates_ptr
183
+ + seq_idx_m * stride_init_states_batch
184
+ + pid_h * stride_init_states_head
185
+ )
186
+ prev_states_hdim = stride_init_states_hdim # override strides
187
+ prev_states_dstate = stride_init_states_dstate
188
+
189
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
190
+ dA_cs_m = tl.load(
191
+ dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
192
+ ).to(tl.float32)
193
+
194
+ # - handle chunk state limit
195
+ if HAS_INITSTATES:
196
+
197
+ # have to split this if otherwise compilation will have problems
198
+ dA_cs_m_boundary = 0.0
199
+
200
+ # get the c_idx for the next (logica) chunk
201
+ c_idx_n = tl.load(
202
+ chunk_indices_ptr + (pid_c + 1),
203
+ mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
204
+ other=-1, # to trigger different chunk
205
+ )
206
+
207
+ # - there are things to consider
208
+ # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
209
+ # contribution of past states
210
+ # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
211
+ # encroach into the next sequence, where c_off_n is the offset of the next
212
+ # (logical) chunk.
213
+ # An equivalent check for B is c_idx == c_idx_n, where there is repetition in
214
+ # (logical) chunk indices.
215
+
216
+ if (c_idx == c_idx_n) or c_off > 0:
217
+
218
+ # get the next offset
219
+ c_off_n = tl.load(
220
+ chunk_offsets_ptr + (pid_c + 1),
221
+ mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
222
+ other=chunk_size,
223
+ )
224
+
225
+ # in this case, adjust down the chunk_size_limit
226
+ if c_idx == c_idx_n:
227
+ chunk_size_limit = min(c_off_n, chunk_size_limit)
228
+
229
+ # get the cs at the offset boundary
230
+ # - c_off == 0 is a passthrough
231
+ # - We need dA_cs at the boundary, defined by c_off - no need
232
+ # to increase pointer by pid_m (it is a constant offset,
233
+ # i.e. the same for all blocks)
234
+ dA_cs_m_boundary = tl.load(
235
+ dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
236
+ mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
237
+ other=0.0,
238
+ ).to(tl.float32)
239
+
240
+ if HAS_SEQ_IDX:
241
+ # - handle seq idx when HAS_INITSTATES==False
242
+ if not HAS_INITSTATES:
243
+ seq_idx_m = tl.load(
244
+ seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
245
+ mask=offs_m < chunk_size_limit,
246
+ other=-1,
247
+ )
248
+
249
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
250
+
251
+ # Without the if (pid_c > -1), with Triton 2.1.0, I get
252
+ # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
253
+ # With Triton 2.2.0, this works
254
+ if IS_TRITON_22 or c_idx > -1:
255
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
256
+ offs_k_dstate = tl.arange(
257
+ 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
258
+ )
259
+ C_ptrs = C_ptr + (
260
+ offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
261
+ )
262
+
263
+ prev_states_ptrs = prev_states_ptr + (
264
+ offs_n[None, :] * prev_states_hdim
265
+ + offs_k_dstate[:, None] * prev_states_dstate
266
+ )
267
+ if HAS_SEQ_IDX:
268
+
269
+ if not HAS_INITSTATES:
270
+ # - this is for continuous batching where there is no init states
271
+ scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
272
+ else:
273
+ # - if there is initstates, we will rely on prev_states, no zeroing
274
+ # required.
275
+ scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
276
+ else:
277
+ scale_m = tl.exp(dA_cs_m)
278
+ if BLOCK_SIZE_DSTATE <= 128:
279
+ C = tl.load(
280
+ C_ptrs,
281
+ mask=(offs_m[:, None] < chunk_size_limit)
282
+ & (offs_k_dstate[None, :] < dstate),
283
+ other=0.0,
284
+ )
285
+
286
+ prev_states = tl.load(
287
+ prev_states_ptrs,
288
+ mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
289
+ other=0.0,
290
+ )
291
+ prev_states = prev_states.to(C_ptr.dtype.element_ty)
292
+ acc = tl.dot(C, prev_states) * scale_m[:, None]
293
+ else:
294
+ for k in range(0, dstate, BLOCK_SIZE_K):
295
+ C = tl.load(
296
+ C_ptrs,
297
+ mask=(offs_m[:, None] < chunk_size_limit)
298
+ & (offs_k_dstate[None, :] < dstate - k),
299
+ other=0.0,
300
+ )
301
+ # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
302
+ prev_states = tl.load(
303
+ prev_states_ptrs,
304
+ mask=(offs_k_dstate[:, None] < dstate - k)
305
+ & (offs_n[None, :] < hdim),
306
+ other=0.0,
307
+ )
308
+ prev_states = prev_states.to(C_ptr.dtype.element_ty)
309
+ acc += tl.dot(C, prev_states)
310
+ C_ptrs += BLOCK_SIZE_K
311
+ prev_states_ptrs += BLOCK_SIZE_K
312
+ acc *= scale_m[:, None]
313
+
314
+ offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
315
+ cb_ptrs = cb_ptr + (
316
+ offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
317
+ )
318
+ x_ptrs = x_ptr + (
319
+ offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
320
+ )
321
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
322
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
323
+ K_MAX = (
324
+ chunk_size_limit
325
+ if not IS_CAUSAL
326
+ else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
327
+ )
328
+ for k in range(0, K_MAX, BLOCK_SIZE_K):
329
+ cb = tl.load(
330
+ cb_ptrs,
331
+ mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
332
+ other=0.0,
333
+ ).to(tl.float32)
334
+ dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
335
+ tl.float32
336
+ )
337
+ # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
338
+ # So we don't need masking wrt seq_idx here.
339
+ cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
340
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
341
+ cb *= dt_k
342
+ if IS_CAUSAL:
343
+ mask = offs_m[:, None] >= k + offs_k[None, :]
344
+ cb = tl.where(mask, cb, 0.0)
345
+ cb = cb.to(x_ptr.dtype.element_ty)
346
+ x = tl.load(
347
+ x_ptrs,
348
+ mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
349
+ other=0.0,
350
+ )
351
+ acc += tl.dot(cb, x)
352
+ cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
353
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
354
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
355
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
356
+
357
+ offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
358
+ offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
359
+
360
+ if HAS_D:
361
+ if D_HAS_HDIM:
362
+ D = tl.load(
363
+ D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
364
+ ).to(tl.float32)
365
+ else:
366
+ D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
367
+ x_residual = tl.load(
368
+ x_ptr
369
+ + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
370
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
371
+ other=0.0,
372
+ ).to(tl.float32)
373
+ acc += x_residual * D
374
+
375
+ if HAS_Z:
376
+ out_x_ptr += (
377
+ pid_b * stride_out_batch
378
+ + c_idx * chunk_size * stride_out_seqlen
379
+ + pid_h * stride_out_head
380
+ )
381
+ out_x_ptrs = out_x_ptr + (
382
+ stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]
383
+ )
384
+ tl.store(
385
+ out_x_ptrs,
386
+ acc,
387
+ mask=(offs_out_m[:, None] < chunk_size_limit)
388
+ & (offs_out_n[None, :] < hdim),
389
+ )
390
+
391
+ z_ptr += (
392
+ pid_b * stride_z_batch
393
+ + c_idx * chunk_size * stride_z_seqlen
394
+ + pid_h * stride_z_head
395
+ )
396
+ z_ptrs = z_ptr + (
397
+ stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
398
+ )
399
+ z = tl.load(
400
+ z_ptrs,
401
+ mask=(offs_out_m[:, None] < chunk_size_limit)
402
+ & (offs_out_n[None, :] < hdim),
403
+ other=0.0,
404
+ ).to(tl.float32)
405
+ acc *= z * tl.sigmoid(z)
406
+
407
+ out_ptr += (
408
+ pid_b * stride_out_batch
409
+ + c_idx * chunk_size * stride_out_seqlen
410
+ + pid_h * stride_out_head
411
+ )
412
+ out_ptrs = out_ptr + (
413
+ stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
414
+ )
415
+ tl.store(
416
+ out_ptrs,
417
+ acc,
418
+ mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
419
+ )
420
+
421
+
422
def _chunk_scan_fwd(
    cb,
    x,
    dt,
    dA_cumsum,
    C,
    states,
    D=None,
    z=None,
    seq_idx=None,
    chunk_indices=None,
    chunk_offsets=None,
    initial_states=None,
    out=None,
):
    """Validate inputs and launch ``_chunk_scan_fwd_kernel``.

    The scan result is written in place into ``out``. When ``z`` is given,
    the output before gating by ``z * sigmoid(z)`` is additionally written
    into a freshly allocated ``out_x`` tensor, which is returned.

    Args:
        cb: ``(batch, nchunks, ngroups, chunk_size, chunk_size)`` tensor.
        x: ``(batch, seqlen, nheads, headdim)`` input.
        dt: ``(batch, nheads, nchunks, chunk_size)`` tensor.
        dA_cumsum: same shape as ``dt``.
        C: ``(batch, seqlen, ngroups, dstate)`` tensor; ``nheads`` must be a
            multiple of ``ngroups``.
        states: ``(batch, nchunks, nheads, headdim, dstate)`` tensor.
        D: optional ``(nheads, headdim)`` or ``(nheads,)`` tensor; the kernel
            adds ``D * x`` to the accumulator.
        z: optional gate tensor with the same shape as ``x``.
        seq_idx: optional ``(batch, seqlen)`` sequence-index tensor.
        chunk_indices: logical-chunk index metadata. It is used only when both
            ``seq_idx`` and ``initial_states`` are given, and is ignored
            otherwise.
        chunk_offsets: logical-chunk offset metadata, same length as
            ``chunk_indices`` and subject to the same condition.
        initial_states: optional initial states. When used together with
            ``seq_idx``, ``batch`` must be 1.
        out: output tensor with the same shape as ``x``. It has a ``None``
            default but is required, because the kernel writes into it.

    Returns:
        ``out_x`` (shaped like ``x``) when ``z`` is given, otherwise ``None``.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)

    # Logical-chunk metadata is only consulted when sequences carry both
    # seq_idx and initial states. In that case the way seq_idx crosses chunk
    # boundaries must be handled explicitly. Otherwise the metadata is
    # dropped and the kernel runs over the physical (batch, chunk) grid.
    if seq_idx is not None and initial_states is not None:
        assert batch == 1, "chunk scan only supports initial states with batch 1"
        assert (
            chunk_indices is not None and chunk_offsets is not None
        ), "chunk_indices and chunk_offsets should have been set"
        # The grid is sized by len(chunk_offsets), but the kernel bounds its
        # metadata loads by len(chunk_indices), so the two must agree.
        assert len(chunk_indices) == len(chunk_offsets), (
            "chunk_indices and chunk_offsets must have the same length, got "
            f"{len(chunk_indices)} and {len(chunk_offsets)}"
        )
    else:
        chunk_indices, chunk_offsets = None, None

    # Fail with a clear message instead of an AttributeError on None.
    assert out is not None, "out must be a preallocated tensor with the same shape as x"
    assert out.shape == x.shape

    if z is not None:
        out_x = torch.empty_like(x)
        # The kernel addresses out_x with the stride arguments it receives
        # for out, so both tensors must share a layout.
        assert out_x.stride() == out.stride()
    else:
        out_x = None

    # One program per logical chunk when metadata is present, otherwise one
    # per (batch, physical chunk) pair.
    n_chunk_programs = batch * nchunks if chunk_offsets is None else len(chunk_offsets)

    def grid(META):
        return (
            triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
            * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
            n_chunk_programs,
            nheads,
        )

    # Ranks of x, z (same shape as x), out (same shape as x), cb, C, states
    # and seq_idx are fixed by the asserts above, so unpacking .stride()
    # yields exactly the per-dimension strides.
    z_strides = z.stride() if z is not None else (0, 0, 0, 0)
    seq_idx_strides = seq_idx.stride() if seq_idx is not None else (0, 0)
    initial_states_strides = (
        (
            initial_states.stride(0),
            initial_states.stride(1),
            initial_states.stride(2),
            initial_states.stride(3),
        )
        if initial_states is not None
        else (0, 0, 0, 0)
    )

    _chunk_scan_fwd_kernel[grid](
        cb,
        x,
        z,
        out,
        out_x,
        dt,
        dA_cumsum,
        seq_idx,
        C,
        states,
        D,
        initial_states,
        chunk_indices,
        chunk_offsets,
        len(chunk_indices) if chunk_indices is not None else 0,
        chunk_size,
        headdim,
        dstate,
        batch,
        seqlen,
        nheads // ngroups,
        *cb.stride(),
        *x.stride(),
        *z_strides,
        *out.stride(),
        # dt / dA_cumsum are (batch, nheads, nchunks, chunk_size). Dims 2 and
        # 1 are deliberately passed swapped, as in (batch, chunk, head, csize)
        # order. Presumably this matches the kernel's parameter order.
        dt.stride(0),
        dt.stride(2),
        dt.stride(1),
        dt.stride(3),
        dA_cumsum.stride(0),
        dA_cumsum.stride(2),
        dA_cumsum.stride(1),
        dA_cumsum.stride(3),
        *seq_idx_strides,
        *C.stride(),
        *states.stride(),
        *initial_states_strides,
        D.stride(0) if D is not None else 0,
        # Presumably IS_CAUSAL, HAS_D, D_HAS_HDIM. The kernel signature is
        # not visible here.
        True,
        D is not None,
        D.dim() == 2 if D is not None else True,
        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        HAS_Z=z is not None,
        HAS_SEQ_IDX=seq_idx is not None,
        IS_TRITON_22=TRITON_22,
        HAS_INITSTATES=initial_states is not None,
    )
    return out_x