sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/ops/mamba_ssm.py
@@ -0,0 +1,442 @@
+ # Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
+ # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
+
+ import torch
+ import triton
+ import triton.language as tl
+ from packaging import version
+
+ from sglang.srt import _custom_ops as ops
+
+ PAD_SLOT_ID = -1
+
+ TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
+
+ if TRITON3:
+
+     @triton.jit
+     def softplus(dt):
+         dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
+         return dt
+
+ else:
+
+     @triton.jit
+     def softplus(dt):
+         dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+         return dt
+
+
+ @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
+ @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
+ @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
+ @triton.heuristics(
+     {
+         "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
+         is not None
+     }
+ )
+ @triton.heuristics(
+     {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
+ )
+ @triton.jit
+ def _selective_scan_update_kernel(
+     # Pointers to matrices
+     state_ptr,
+     x_ptr,
+     dt_ptr,
+     dt_bias_ptr,
+     A_ptr,
+     B_ptr,
+     C_ptr,
+     D_ptr,
+     z_ptr,
+     out_ptr,
+     state_batch_indices_ptr,
+     pad_slot_id,
+     # Matrix dimensions
+     batch,
+     nheads,
+     dim,
+     dstate,
+     nheads_ngroups_ratio,
+     # Strides
+     stride_state_batch,
+     stride_state_head,
+     stride_state_dim,
+     stride_state_dstate,
+     stride_x_batch,
+     stride_x_head,
+     stride_x_dim,
+     stride_dt_batch,
+     stride_dt_head,
+     stride_dt_dim,
+     stride_dt_bias_head,
+     stride_dt_bias_dim,
+     stride_A_head,
+     stride_A_dim,
+     stride_A_dstate,
+     stride_B_batch,
+     stride_B_group,
+     stride_B_dstate,
+     stride_C_batch,
+     stride_C_group,
+     stride_C_dstate,
+     stride_D_head,
+     stride_D_dim,
+     stride_z_batch,
+     stride_z_head,
+     stride_z_dim,
+     stride_out_batch,
+     stride_out_head,
+     stride_out_dim,
+     # Meta-parameters
+     DT_SOFTPLUS: tl.constexpr,
+     TIE_HDIM: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     HAS_DT_BIAS: tl.constexpr,
+     HAS_D: tl.constexpr,
+     HAS_Z: tl.constexpr,
+     HAS_STATE_BATCH_INDICES: tl.constexpr,
+     BLOCK_SIZE_DSTATE: tl.constexpr,
+ ):
+     pid_m = tl.program_id(axis=0)
+     pid_b = tl.program_id(axis=1)
+     pid_h = tl.program_id(axis=2)
+
+     # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
+     # is taken from state_batch_indices_ptr. Otherwise, the state coordinate
+     # is the same as the batch id.
+     if HAS_STATE_BATCH_INDICES:
+         state_batch_indices_ptr += pid_b
+         state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
+         state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
+     else:
+         state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
+     x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
+     dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
+     if HAS_DT_BIAS:
+         dt_bias_ptr += pid_h * stride_dt_bias_head
+     A_ptr += pid_h * stride_A_head
+     B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
+     C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
+     if HAS_Z:
+         z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
+     out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+
+     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
+     state_ptrs = state_ptr + (
+         offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
+     )
+     x_ptrs = x_ptr + offs_m * stride_x_dim
+     dt_ptrs = dt_ptr + offs_m * stride_dt_dim
+     if HAS_DT_BIAS:
+         dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
+     if HAS_D:
+         D_ptr += pid_h * stride_D_head
+     A_ptrs = A_ptr + (
+         offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
+     )
+     B_ptrs = B_ptr + offs_n * stride_B_dstate
+     C_ptrs = C_ptr + offs_n * stride_C_dstate
+     if HAS_D:
+         D_ptrs = D_ptr + offs_m * stride_D_dim
+     if HAS_Z:
+         z_ptrs = z_ptr + offs_m * stride_z_dim
+     out_ptrs = out_ptr + offs_m * stride_out_dim
+     mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
+     if HAS_STATE_BATCH_INDICES:
+         mask &= state_batch_idx != pad_slot_id
+     state = tl.load(state_ptrs, mask=mask, other=0.0)
+
+     x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+     if not TIE_HDIM:
+         dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+         if HAS_DT_BIAS:
+             dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+         if DT_SOFTPLUS:
+             dt = softplus(dt)
+         A = tl.load(
+             A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
+         ).to(tl.float32)
+         dA = tl.exp(A * dt[:, None])
+     else:
+         dt = tl.load(dt_ptr).to(tl.float32)
+         if HAS_DT_BIAS:
+             dt += tl.load(dt_bias_ptr).to(tl.float32)
+         if DT_SOFTPLUS:
+             dt = softplus(dt)
+         A = tl.load(A_ptr).to(tl.float32)
+         dA = tl.exp(A * dt)  # scalar, not a matrix
+
+     B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
+     C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
+     if HAS_D:
+         D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+     if HAS_Z:
+         z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+
+     dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
+     state = state * dA + dB * x[:, None]
+
+     mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
+     if HAS_STATE_BATCH_INDICES:
+         mask &= state_batch_idx != pad_slot_id
+     tl.store(state_ptrs, state, mask=mask)
+     out = tl.sum(state * C[None, :], axis=1)
+     if HAS_D:
+         out += x * D
+     if HAS_Z:
+         out *= z * tl.sigmoid(z)
+     tl.store(out_ptrs, out, mask=offs_m < dim)
+
+
+ def selective_state_update(
+     state,
+     x,
+     dt,
+     A,
+     B,
+     C,
+     D=None,
+     z=None,
+     dt_bias=None,
+     dt_softplus=False,
+     state_batch_indices=None,
+     pad_slot_id=PAD_SLOT_ID,
+     out=None,
+ ):
+     """
+     Argument:
+         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+         x: (batch, dim) or (batch, nheads, dim)
+         dt: (batch, dim) or (batch, nheads, dim)
+         A: (dim, dstate) or (nheads, dim, dstate)
+         B: (batch, dstate) or (batch, ngroups, dstate)
+         C: (batch, dstate) or (batch, ngroups, dstate)
+         D: (dim,) or (nheads, dim)
+         z: (batch, dim) or (batch, nheads, dim)
+         dt_bias: (dim,) or (nheads, dim)
+         pad_slot_id: int
+             if state_batch_indices is passed, lets the kernel identify padded
+             entries that will not be processed,
+             for example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id]
+             in this case, the kernel will not process entries at
+             indices 0 and 3
+         out: Preallocated ssm output tensor. Assumed to have the same shape as x.
+             Updated in place.
+     """
+     if state.dim() == 3:
+         state = state.unsqueeze(1)
+     if x.dim() == 2:
+         x = x.unsqueeze(1)
+     if dt.dim() == 2:
+         dt = dt.unsqueeze(1)
+     if A.dim() == 2:
+         A = A.unsqueeze(0)
+     if B.dim() == 2:
+         B = B.unsqueeze(1)
+     if C.dim() == 2:
+         C = C.unsqueeze(1)
+     if D is not None and D.dim() == 1:
+         D = D.unsqueeze(0)
+     if z is not None and z.dim() == 2:
+         z = z.unsqueeze(1)
+     if dt_bias is not None and dt_bias.dim() == 1:
+         dt_bias = dt_bias.unsqueeze(0)
+     if out.dim() == 2:
+         out = out.unsqueeze(1)
+
+     _, nheads, dim, dstate = state.shape
+     batch = x.shape[0]
+
+     assert x.shape == (batch, nheads, dim)
+     assert dt.shape == x.shape
+     assert A.shape == (nheads, dim, dstate)
+     ngroups = B.shape[1]
+     assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+     assert B.shape == (batch, ngroups, dstate)
+     assert C.shape == B.shape
+     if D is not None:
+         assert D.shape == (nheads, dim)
+     if z is not None:
+         assert z.shape == x.shape
+     if dt_bias is not None:
+         assert dt_bias.shape == (nheads, dim)
+     if state_batch_indices is not None:
+         assert state_batch_indices.shape == (batch,)
+     assert out.shape == x.shape
+
+     grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
+     z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
+     # We don't want autotune since it will overwrite the state
+     # We instead tune by hand.
+     BLOCK_SIZE_M, num_warps = (
+         (32, 4)
+         if dstate <= 16
+         else (
+             (16, 4)
+             if dstate <= 32
+             else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
+         )
+     )
+     tie_hdim = (
+         A.stride(-1) == 0
+         and A.stride(-2) == 0
+         and dt.stride(-1) == 0
+         and dt_bias.stride(-1) == 0
+     )
+     with torch.cuda.device(x.device.index):
+         _selective_scan_update_kernel[grid](
+             state,
+             x,
+             dt,
+             dt_bias,
+             A,
+             B,
+             C,
+             D,
+             z,
+             out,
+             state_batch_indices,
+             pad_slot_id,
+             batch,
+             nheads,
+             dim,
+             dstate,
+             nheads // ngroups,
+             state.stride(0),
+             state.stride(1),
+             state.stride(2),
+             state.stride(3),
+             x.stride(0),
+             x.stride(1),
+             x.stride(2),
+             dt.stride(0),
+             dt.stride(1),
+             dt.stride(2),
+             *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
+             A.stride(0),
+             A.stride(1),
+             A.stride(2),
+             B.stride(0),
+             B.stride(1),
+             B.stride(2),
+             C.stride(0),
+             C.stride(1),
+             C.stride(2),
+             *(D.stride(0), D.stride(1)) if D is not None else 0,
+             z_strides[0],
+             z_strides[1],
+             z_strides[2],
+             out.stride(0),
+             out.stride(1),
+             out.stride(2),
+             dt_softplus,
+             tie_hdim,
+             BLOCK_SIZE_M,
+             num_warps=num_warps,
+         )
+
+
+ def selective_scan_fn(
+     u,
+     ssm_states,
+     delta,
+     A,
+     B,
+     C,
+     D=None,
+     z=None,
+     delta_bias=None,
+     delta_softplus=False,
+     query_start_loc=None,
+     cache_indices=None,
+     has_initial_state=None,
+     pad_slot_id=PAD_SLOT_ID,
+ ) -> torch.Tensor:
+     """
+     u: (dim, total_length) for varlen or (batch, dim, seqlen)
+         applies changes in place.
+     ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+         applies changes in place.
+     delta: (dim, total_length) for varlen or (batch, dim, seqlen)
+     A: (dim, dstate)
+     B: (ngroups, dstate, total_length) for varlen or
+         (batch, ngroups, dstate, seqlen)
+     C: (ngroups, dstate, total_length) for varlen or
+         (batch, ngroups, dstate, seqlen)
+     D: (dim,)
+     z: (dim, total_length) for varlen or (batch, dim, seqlen)
+     delta_bias: (dim,)
+     query_start_loc: (batch + 1) int32
+         The cumulative sequence lengths of the sequences in
+         the batch, used to index into sequence. Prepended with 0.
+         for example: query_start_loc = torch.Tensor([0, 10, 16, 17]),
+         x.shape=(dim, 17)
+     cache_indices: (batch) int32
+         A tensor in which each cell is the corresponding
+         input and output ssm_state index
+     has_initial_state: (batch) bool
+         A tensor populated with ones and zeros,
+         indicating whether the ssm_state at the corresponding index should
+         be used as the initial state. Not providing this argument assumes
+         there is no initial state
+     pad_slot_id: int
+         if cache_indices is passed, lets the kernel identify padding entries
+         that will not be processed,
+         for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+         in this case, the kernel will not process entries at indices 0 and 3
+     returns
+         output: (dim, total_length) for varlen or (batch, dim, seqlen)
+         supports inplace replacement
+     """
+     if u.stride(-1) != 1:
+         u = u.contiguous()
+     if delta.stride(-1) != 1:
+         delta = delta.contiguous()
+     if D is not None:
+         D = D.contiguous()
+     if B.stride(-1) != 1:
+         B = B.contiguous()
+     if C.stride(-1) != 1:
+         C = C.contiguous()
+     if z is not None and z.stride(-1) != 1:
+         z = z.contiguous()
+     if B.dim() == 3 and query_start_loc is None:
+         B = B.unsqueeze(1)
+     if B.dim() == 2 and query_start_loc is not None:
+         B = B.unsqueeze(0)
+     if C.dim() == 3 and query_start_loc is None:
+         C = C.unsqueeze(1)
+     if C.dim() == 2 and query_start_loc is not None:
+         C = C.unsqueeze(0)
+
+     ops.selective_scan_fwd(
+         u,
+         delta,
+         A,
+         B,
+         C,
+         D,
+         z,
+         delta_bias,
+         delta_softplus,
+         query_start_loc,
+         cache_indices,
+         has_initial_state,
+         ssm_states,
+         pad_slot_id,
+     )
+
+     if z is None:
+         return delta  # output written inplace to delta
+     else:
+         return z  # output written inplace to z
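
For orientation, a minimal usage sketch of the selective_state_update entry point added above follows. This is an editor's illustration, not part of the diff: the concrete tensor sizes and the single-group layout (ngroups = 1) are assumptions, and running it requires a CUDA device with Triton installed.

    import torch
    from sglang.srt.layers.attention.mamba.ops.mamba_ssm import selective_state_update

    batch, nheads, dim, dstate = 2, 4, 64, 16
    dev = "cuda"

    # One decoding step: state is updated in place, out receives the SSM output.
    state = torch.zeros(batch, nheads, dim, dstate, device=dev)
    x = torch.randn(batch, nheads, dim, device=dev)
    dt = torch.randn(batch, nheads, dim, device=dev)
    A = -torch.rand(nheads, dim, dstate, device=dev)
    B = torch.randn(batch, 1, dstate, device=dev)  # ngroups = 1
    C = torch.randn(batch, 1, dstate, device=dev)
    D = torch.rand(nheads, dim, device=dev)
    dt_bias = torch.rand(nheads, dim, device=dev)
    out = torch.empty_like(x)

    selective_state_update(
        state, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True, out=out
    )
    # state <- state * exp(A * dt') + (B * dt') * x, with dt' = softplus(dt + dt_bias)
    # out   <- sum(state * C, dim=-1) + D * x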
sglang/srt/layers/attention/mamba/ops/ssd_bmm.py
@@ -0,0 +1,214 @@
+ # Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_bmm.py
+
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
+ # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
+
+ # ruff: noqa: E501,SIM102
+
+ import math
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def _bmm_chunk_fwd_kernel(
+     # Pointers to matrices
+     a_ptr,
+     b_ptr,
+     out_ptr,
+     seq_idx_ptr,
+     # Matrix dimensions
+     seqlen,
+     chunk_size,
+     K,
+     ngroups,
+     stride_a_batch,
+     stride_a_seqlen,
+     stride_a_head,
+     stride_ak,
+     stride_b_batch,
+     stride_b_seqlen,
+     stride_b_head,
+     stride_bk,
+     stride_out_batch,
+     stride_out_chunk,
+     stride_out_head,
+     stride_outm,
+     stride_outn,
+     stride_seq_idx_batch,
+     stride_seq_idx_seqlen,
+     # Meta-parameters
+     IS_CAUSAL: tl.constexpr,
+     dot_dtype: tl.constexpr,
+     HAS_SEQ_IDX: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr = 16,
+     BLOCK_SIZE_N: tl.constexpr = 16,
+     BLOCK_SIZE_K: tl.constexpr = 16,
+ ):
+     pid_b = tl.program_id(axis=1)
+     pid_ch = tl.program_id(axis=2).to(tl.int64)
+     pid_c = pid_ch // ngroups
+     pid_h = pid_ch - pid_c * ngroups
+     num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
+     pid_m = tl.program_id(axis=0) // num_pid_n
+     pid_n = tl.program_id(axis=0) % num_pid_n
+     if IS_CAUSAL:
+         if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
+             return
+     a_ptr += (
+         pid_b * stride_a_batch
+         + pid_c * chunk_size * stride_a_seqlen
+         + pid_h * stride_a_head
+     )
+     b_ptr += (
+         pid_b * stride_b_batch
+         + pid_c * chunk_size * stride_b_seqlen
+         + pid_h * stride_b_head
+     )
+     if HAS_SEQ_IDX:
+         seq_idx_ptr += (
+             pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
+         )
+
+     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
+     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
+     chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
+
+     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         a = tl.load(
+             a_ptrs,
+             mask=(offs_m[:, None] < chunk_size_limit)
+             & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+             other=0.0,
+         ).to(dot_dtype)
+         b = tl.load(
+             b_ptrs,
+             mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
+             & (offs_n[None, :] < chunk_size_limit),
+             other=0.0,
+         ).to(dot_dtype)
+         acc += tl.dot(a, b)
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     if HAS_SEQ_IDX:
+         chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
+         seq_idx_m = tl.load(
+             seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
+             mask=offs_m < chunk_size_limit,
+             other=-1,
+         )
+         seq_idx_n = tl.load(
+             seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
+             mask=offs_n < chunk_size_limit,
+             other=-2,
+         )
+         acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
+     out = acc.to(out_ptr.dtype.element_ty)
+
+     out_ptr += (
+         pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
+     )
+     out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
+     tl.store(
+         out_ptrs,
+         out,
+         mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
+     )
+
+
+ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
+     """
+     Argument:
+         a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
+         b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
+         seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
+         causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
+             guaranteed to be correct.
+     Return:
+         out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
+     """
+     # Check constraints.
+     has_groups = a.dim() == 4
+     if not has_groups:
+         batch, seqlen, k = a.shape
+     else:
+         batch, seqlen, ngroups, k = a.shape
+     assert b.shape == a.shape
+     if seq_idx is not None:
+         assert seq_idx.shape == (batch, seqlen)
+     if a.stride(-1) != 1 and a.stride(1) != 1:
+         a = a.contiguous()
+     if b.stride(-1) != 1 and b.stride(1) != 1:
+         b = b.contiguous()
+     nchunks = math.ceil(seqlen / chunk_size)
+     # Allocates output.
+     out_dtype = a.dtype if output_dtype is None else output_dtype
+     out = torch.empty(
+         (
+             (batch, nchunks, chunk_size, chunk_size)
+             if not has_groups
+             else (batch, nchunks, ngroups, chunk_size, chunk_size)
+         ),
+         device=a.device,
+         dtype=out_dtype,
+     )
+     dot_dtype = (
+         tl.bfloat16
+         if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
+         else (
+             tl.float16
+             if a.dtype == torch.float16 or b.dtype == torch.float16
+             else tl.float32
+         )
+     )
+     grid = lambda META: (
+         triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
+         * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
+         batch,
+         nchunks if not has_groups else nchunks * ngroups,
+     )
+     with torch.cuda.device(a.device.index):
+         _bmm_chunk_fwd_kernel[grid](
+             a,
+             b,
+             out,
+             seq_idx,
+             seqlen,
+             chunk_size,
+             k,
+             ngroups if has_groups else 1,
+             a.stride(0),
+             a.stride(1),
+             0 if not has_groups else a.stride(2),
+             a.stride(-1),
+             b.stride(0),
+             b.stride(1),
+             0 if not has_groups else b.stride(2),
+             b.stride(-1),
+             out.stride(0),
+             out.stride(1),
+             0 if not has_groups else out.stride(2),
+             out.stride(-2),
+             out.stride(-1),
+             *(
+                 (seq_idx.stride(0), seq_idx.stride(1))
+                 if seq_idx is not None
+                 else (0, 0)
+             ),
+             causal,
+             dot_dtype,
+             HAS_SEQ_IDX=seq_idx is not None,
+         )
+     return out
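
Likewise, a minimal sketch of how the _bmm_chunk_fwd helper above can be exercised (editor's illustration, not part of the diff; the sizes are arbitrary assumptions, and a CUDA device with Triton is required):

    import torch
    from sglang.srt.layers.attention.mamba.ops.ssd_bmm import _bmm_chunk_fwd

    batch, seqlen, k, chunk_size = 2, 96, 64, 32
    a = torch.randn(batch, seqlen, k, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(batch, seqlen, k, device="cuda", dtype=torch.bfloat16)

    out = _bmm_chunk_fwd(a, b, chunk_size)
    # out has shape (batch, nchunks=3, chunk_size, chunk_size);
    # out[i, c, m, n] = dot(a[i, c * chunk_size + m], b[i, c * chunk_size + n])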