sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  151. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/mamba_utils.py
@@ -0,0 +1,81 @@
+ # Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
+ from sglang.srt.distributed.utils import divide
+
+
+ class MambaStateShapeCalculator:
+
+     @classmethod
+     def linear_attention_state_shape(
+         cls,
+         num_heads: int,
+         tp_size: int,
+         head_dim: int,
+     ) -> tuple[tuple[int, int, int], ...]:
+
+         state_shape = (num_heads // tp_size, head_dim, head_dim)
+         return (state_shape,)
+
+     @classmethod
+     def mamba1_state_shape(
+         cls,
+         tp_world_size: int,
+         intermediate_size: int,
+         state_size: int,
+         conv_kernel: int,
+     ) -> tuple[tuple[int, int], tuple[int, int]]:
+         conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
+
+         temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
+
+         conv_state_shape = conv_state_shape[1], conv_state_shape[0]
+
+         return conv_state_shape, temporal_state_shape
+
+     @classmethod
+     def mamba2_state_shape(
+         cls,
+         tp_world_size: int,
+         intermediate_size: int,
+         n_groups: int,
+         num_heads: int,
+         head_dim: int,
+         state_size: int,
+         conv_kernel: int,
+     ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+         # if n_groups is not divisible by world_size, need to extend the shards
+         # to ensure all groups needed by a head are sharded along with it
+         n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
+         # heads and n_groups are TP-ed
+         conv_dim = intermediate_size + 2 * n_groups * state_size
+
+         # contiguous along 'dim' axis
+         conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
+
+         # These are not TP-ed as they depend on A, dt_bias, D
+         # - they are typically small
+         # e.g., (num_heads, head_dim, state_size) = (128, 64, 128)
+         temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
+         return conv_state_shape, temporal_state_shape
+
+     @classmethod
+     def short_conv_state_shape(
+         cls,
+         tp_world_size: int,
+         intermediate_size: int,
+         conv_kernel: int,
+     ) -> tuple[tuple[int, int]]:
+         conv_dim = divide(intermediate_size, tp_world_size)
+         conv_state_shape = (conv_kernel - 1, conv_dim)
+         return (conv_state_shape,)
+
+     @classmethod
+     def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
+         """Compute the increase in group numbers to account for
+         replication in order to accompany the head shards."""
+
+         # in the case ngroups % tp_size == 0, this will be zero
+         if ngroups % tp_size == 0:
+             return 0
+
+         # for n_groups == 1, this is exactly tp_size - n_groups
+         return tp_size - ngroups
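
A quick usage sketch for the shape helpers in this hunk (which lines up with entry 110 in the file list, sglang/srt/layers/attention/mamba/mamba_utils.py, +81 lines); the dimensions below are illustrative rather than taken from any shipped model config:

    from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator

    # Hypothetical Mamba2-style layer sharded over 2 tensor-parallel ranks.
    conv_shape, ssm_shape = MambaStateShapeCalculator.mamba2_state_shape(
        tp_world_size=2,
        intermediate_size=4096,
        n_groups=1,        # padded to 2 so each head shard gets its own group
        num_heads=128,
        head_dim=64,
        state_size=128,
        conv_kernel=4,
    )
    # conv_dim = 4096 + 2 * 2 * 128 = 4608, so:
    # conv_shape == (3, 2304)      # (conv_kernel - 1, conv_dim // tp_world_size)
    # ssm_shape  == (64, 64, 128)  # (num_heads // tp, head_dim, state_size)

With n_groups=1 and tp_world_size=2, extra_groups_for_head_shards returns 1, so the group count is padded to 2 before conv_dim is computed.
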
sglang/srt/layers/attention/mamba/ops/__init__.py
@@ -0,0 +1,2 @@
+ from .mamba_ssm import selective_state_update
+ from .ssd_combined import mamba_chunk_scan_combined
sglang/srt/layers/attention/mamba/ops/layernorm_gated.py
@@ -0,0 +1,172 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ # Copyright (c) 2024, Tri Dao.
+ # Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
+ @triton.jit
+ def _layer_norm_fwd_1pass_kernel(
+     X,  # pointer to the input
+     Y,  # pointer to the output
+     W,  # pointer to the weights
+     B,  # pointer to the biases
+     Z,  # pointer to the other branch
+     Mean,  # pointer to the mean
+     Rstd,  # pointer to the 1/std
+     stride_x_row: tl.int64,
+     stride_y_row: tl.int64,
+     stride_z_row: tl.int64,
+     M: tl.int64,  # number of rows in X
+     N: tl.int64,  # number of columns in X
+     eps,  # epsilon to avoid division by zero
+     BLOCK_N: tl.constexpr,
+     HAS_BIAS: tl.constexpr,
+     HAS_Z: tl.constexpr,
+     NORM_BEFORE_GATE: tl.constexpr,
+     IS_RMS_NORM: tl.constexpr,
+ ):
+     # Map the program id to the row of X and Y it should compute.
+     row = tl.program_id(0)
+     group = tl.program_id(1)
+     X += row * stride_x_row + group * N
+     Y += row * stride_y_row + group * N
+     if HAS_Z:
+         Z += row * stride_z_row + group * N
+     if not IS_RMS_NORM:
+         Mean += group * M
+     Rstd += group * M
+     W += group * N
+     if HAS_BIAS:
+         B += group * N
+     # Compute mean and variance
+     cols = tl.arange(0, BLOCK_N)
+     x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+     if HAS_Z and not NORM_BEFORE_GATE:
+         z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
+         x *= z * tl.sigmoid(z)
+     if not IS_RMS_NORM:
+         mean = tl.sum(x, axis=0) / N
+         tl.store(Mean + row, mean)
+         xbar = tl.where(cols < N, x - mean, 0.0)
+         var = tl.sum(xbar * xbar, axis=0) / N
+     else:
+         xbar = tl.where(cols < N, x, 0.0)
+         var = tl.sum(xbar * xbar, axis=0) / N
+     rstd = 1 / tl.sqrt(var + eps)
+     tl.store(Rstd + row, rstd)
+     # Normalize and apply linear transformation
+     mask = cols < N
+     w = tl.load(W + cols, mask=mask).to(tl.float32)
+     if HAS_BIAS:
+         b = tl.load(B + cols, mask=mask).to(tl.float32)
+     x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+     y = x_hat * w + b if HAS_BIAS else x_hat * w
+     if HAS_Z and NORM_BEFORE_GATE:
+         z = tl.load(Z + cols, mask=mask).to(tl.float32)
+         y *= z * tl.sigmoid(z)
+     # Write output
+     tl.store(Y + cols, y, mask=mask)
+
+
+ def _layer_norm_fwd(
+     x,
+     weight,
+     bias,
+     eps,
+     z=None,
+     out=None,
+     group_size=None,
+     norm_before_gate=True,
+     is_rms_norm=False,
+ ):
+     M, N = x.shape
+     if group_size is None:
+         group_size = N
+     assert N % group_size == 0
+     ngroups = N // group_size
+     assert x.stride(-1) == 1
+     if z is not None:
+         assert z.stride(-1) == 1
+         assert z.shape == (M, N)
+     assert weight.shape == (N,)
+     assert weight.stride(-1) == 1
+     if bias is not None:
+         assert bias.stride(-1) == 1
+         assert bias.shape == (N,)
+     # allocate output
+     if out is not None:
+         assert out.shape == x.shape
+     else:
+         out = torch.empty_like(x)
+     assert out.stride(-1) == 1
+     mean = (
+         torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
+         if not is_rms_norm
+         else None
+     )
+     rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
+     if group_size > BLOCK_N:
+         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK_N // 256, 1), 8)
+     grid = (M, ngroups)
+     with torch.cuda.device(x.device.index):
+         _layer_norm_fwd_1pass_kernel[grid](
+             x,
+             out,
+             weight,
+             bias,
+             z,
+             mean,
+             rstd,
+             x.stride(0),
+             out.stride(0),
+             z.stride(0) if z is not None else 0,
+             M,
+             group_size,
+             eps,
+             BLOCK_N=BLOCK_N,
+             NORM_BEFORE_GATE=norm_before_gate,
+             IS_RMS_NORM=is_rms_norm,
+             num_warps=num_warps,
+         )
+     return out, mean, rstd
+
+
+ def rms_norm_gated(
+     x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
+ ):
+     x_shape_og = x.shape
+     # reshape input data into 2D tensor
+     x = x.reshape(-1, x.shape[-1])
+     if x.stride(-1) != 1:
+         x = x.contiguous()
+     if z is not None:
+         assert z.shape == x_shape_og
+         z = z.reshape(-1, z.shape[-1])
+         if z.stride(-1) != 1:
+             z = z.contiguous()
+     weight = weight.contiguous()
+     if bias is not None:
+         bias = bias.contiguous()
+     y, _, _ = _layer_norm_fwd(
+         x,
+         weight,
+         bias,
+         eps,
+         z=z,
+         group_size=group_size,
+         norm_before_gate=norm_before_gate,
+         is_rms_norm=True,
+     )
+
+     return y.reshape(x_shape_og)
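
A minimal call sketch for rms_norm_gated in this hunk (entry 112 in the file list, sglang/srt/layers/attention/mamba/ops/layernorm_gated.py); it assumes a CUDA device, since the kernel launches under torch.cuda.device, and the shapes are illustrative:

    import torch

    from sglang.srt.layers.attention.mamba.ops.layernorm_gated import rms_norm_gated

    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    z = torch.randn_like(x)  # gating branch
    w = torch.ones(4096, device="cuda", dtype=torch.float16)

    # norm_before_gate=False takes the "HAS_Z and not NORM_BEFORE_GATE" path in the
    # kernel: x is multiplied by silu(z) before the RMS statistics are computed.
    y = rms_norm_gated(x, w, bias=None, z=z, eps=1e-6, norm_before_gate=False)
    assert y.shape == x.shape

Passing norm_before_gate=True instead normalizes first and multiplies by silu(z) afterwards (the "HAS_Z and NORM_BEFORE_GATE" path).
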
sglang/srt/layers/attention/mamba/ops/mamba_ssm.py
@@ -0,0 +1,442 @@
+ # Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
+ # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
+
+ import torch
+ import triton
+ import triton.language as tl
+ from packaging import version
+
+ from sglang.srt import _custom_ops as ops
+
+ PAD_SLOT_ID = -1
+
+ TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
+
+ if TRITON3:
+
+     @triton.jit
+     def softplus(dt):
+         dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
+         return dt
+
+ else:
+
+     @triton.jit
+     def softplus(dt):
+         dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+         return dt
+
+
+ @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
+ @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
+ @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
+ @triton.heuristics(
+     {
+         "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
+         is not None
+     }
+ )
+ @triton.heuristics(
+     {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
+ )
+ @triton.jit
+ def _selective_scan_update_kernel(
+     # Pointers to matrices
+     state_ptr,
+     x_ptr,
+     dt_ptr,
+     dt_bias_ptr,
+     A_ptr,
+     B_ptr,
+     C_ptr,
+     D_ptr,
+     z_ptr,
+     out_ptr,
+     state_batch_indices_ptr,
+     pad_slot_id,
+     # Matrix dimensions
+     batch,
+     nheads,
+     dim,
+     dstate,
+     nheads_ngroups_ratio,
+     # Strides
+     stride_state_batch,
+     stride_state_head,
+     stride_state_dim,
+     stride_state_dstate,
+     stride_x_batch,
+     stride_x_head,
+     stride_x_dim,
+     stride_dt_batch,
+     stride_dt_head,
+     stride_dt_dim,
+     stride_dt_bias_head,
+     stride_dt_bias_dim,
+     stride_A_head,
+     stride_A_dim,
+     stride_A_dstate,
+     stride_B_batch,
+     stride_B_group,
+     stride_B_dstate,
+     stride_C_batch,
+     stride_C_group,
+     stride_C_dstate,
+     stride_D_head,
+     stride_D_dim,
+     stride_z_batch,
+     stride_z_head,
+     stride_z_dim,
+     stride_out_batch,
+     stride_out_head,
+     stride_out_dim,
+     # Meta-parameters
+     DT_SOFTPLUS: tl.constexpr,
+     TIE_HDIM: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     HAS_DT_BIAS: tl.constexpr,
+     HAS_D: tl.constexpr,
+     HAS_Z: tl.constexpr,
+     HAS_STATE_BATCH_INDICES: tl.constexpr,
+     BLOCK_SIZE_DSTATE: tl.constexpr,
+ ):
+     pid_m = tl.program_id(axis=0)
+     pid_b = tl.program_id(axis=1)
+     pid_h = tl.program_id(axis=2)
+
+     # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
+     # is taken from state_batch_indices_ptr. Otherwise, the state coordinate
+     # is the same as the batch id.
+     if HAS_STATE_BATCH_INDICES:
+         state_batch_indices_ptr += pid_b
+         state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
+         state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
+     else:
+         state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
+     x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
+     dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
+     if HAS_DT_BIAS:
+         dt_bias_ptr += pid_h * stride_dt_bias_head
+     A_ptr += pid_h * stride_A_head
+     B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
+     C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
+     if HAS_Z:
+         z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
+     out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+
+     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
+     state_ptrs = state_ptr + (
+         offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
+     )
+     x_ptrs = x_ptr + offs_m * stride_x_dim
+     dt_ptrs = dt_ptr + offs_m * stride_dt_dim
+     if HAS_DT_BIAS:
+         dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
+     if HAS_D:
+         D_ptr += pid_h * stride_D_head
+     A_ptrs = A_ptr + (
+         offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
+     )
+     B_ptrs = B_ptr + offs_n * stride_B_dstate
+     C_ptrs = C_ptr + offs_n * stride_C_dstate
+     if HAS_D:
+         D_ptrs = D_ptr + offs_m * stride_D_dim
+     if HAS_Z:
+         z_ptrs = z_ptr + offs_m * stride_z_dim
+     out_ptrs = out_ptr + offs_m * stride_out_dim
+     mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
+     if HAS_STATE_BATCH_INDICES:
+         mask &= state_batch_idx != pad_slot_id
+     state = tl.load(state_ptrs, mask=mask, other=0.0)
+
+     x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+     if not TIE_HDIM:
+         dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+         if HAS_DT_BIAS:
+             dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+         if DT_SOFTPLUS:
+             dt = softplus(dt)
+         A = tl.load(
+             A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
+         ).to(tl.float32)
+         dA = tl.exp(A * dt[:, None])
+     else:
+         dt = tl.load(dt_ptr).to(tl.float32)
+         if HAS_DT_BIAS:
+             dt += tl.load(dt_bias_ptr).to(tl.float32)
+         if DT_SOFTPLUS:
+             dt = softplus(dt)
+         A = tl.load(A_ptr).to(tl.float32)
+         dA = tl.exp(A * dt)  # scalar, not a matrix
+
+     B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
+     C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
+     if HAS_D:
+         D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+     if HAS_Z:
+         z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+
+     dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
+     state = state * dA + dB * x[:, None]
+
+     mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
+     if HAS_STATE_BATCH_INDICES:
+         mask &= state_batch_idx != pad_slot_id
+     tl.store(state_ptrs, state, mask=mask)
+     out = tl.sum(state * C[None, :], axis=1)
+     if HAS_D:
+         out += x * D
+     if HAS_Z:
+         out *= z * tl.sigmoid(z)
+     tl.store(out_ptrs, out, mask=offs_m < dim)
+
+
+ def selective_state_update(
+     state,
+     x,
+     dt,
+     A,
+     B,
+     C,
+     D=None,
+     z=None,
+     dt_bias=None,
+     dt_softplus=False,
+     state_batch_indices=None,
+     pad_slot_id=PAD_SLOT_ID,
+     out=None,
+ ):
+     """
+     Arguments:
+         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+         x: (batch, dim) or (batch, nheads, dim)
+         dt: (batch, dim) or (batch, nheads, dim)
+         A: (dim, dstate) or (nheads, dim, dstate)
+         B: (batch, dstate) or (batch, ngroups, dstate)
+         C: (batch, dstate) or (batch, ngroups, dstate)
+         D: (dim,) or (nheads, dim)
+         z: (batch, dim) or (batch, nheads, dim)
+         dt_bias: (dim,) or (nheads, dim)
+         pad_slot_id: int
+             if state_batch_indices is passed, lets the kernel identify padded
+             entries that will not be processed,
+             for example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id]
+             in this case, the kernel will not process entries at
+             indices 0 and 3
+         out: Preallocated ssm output tensor. Assumed to have the same shape as x.
+             Updated in place.
+     """
+     if state.dim() == 3:
+         state = state.unsqueeze(1)
+     if x.dim() == 2:
+         x = x.unsqueeze(1)
+     if dt.dim() == 2:
+         dt = dt.unsqueeze(1)
+     if A.dim() == 2:
+         A = A.unsqueeze(0)
+     if B.dim() == 2:
+         B = B.unsqueeze(1)
+     if C.dim() == 2:
+         C = C.unsqueeze(1)
+     if D is not None and D.dim() == 1:
+         D = D.unsqueeze(0)
+     if z is not None and z.dim() == 2:
+         z = z.unsqueeze(1)
+     if dt_bias is not None and dt_bias.dim() == 1:
+         dt_bias = dt_bias.unsqueeze(0)
+     if out.dim() == 2:
+         out = out.unsqueeze(1)
+
+     _, nheads, dim, dstate = state.shape
+     batch = x.shape[0]
+
+     assert x.shape == (batch, nheads, dim)
+     assert dt.shape == x.shape
+     assert A.shape == (nheads, dim, dstate)
+     ngroups = B.shape[1]
+     assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+     assert B.shape == (batch, ngroups, dstate)
+     assert C.shape == B.shape
+     if D is not None:
+         assert D.shape == (nheads, dim)
+     if z is not None:
+         assert z.shape == x.shape
+     if dt_bias is not None:
+         assert dt_bias.shape == (nheads, dim)
+     if state_batch_indices is not None:
+         assert state_batch_indices.shape == (batch,)
+     assert out.shape == x.shape
+
+     grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
+     z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
+     # We don't want autotune since it will overwrite the state
+     # We instead tune by hand.
+     BLOCK_SIZE_M, num_warps = (
+         (32, 4)
+         if dstate <= 16
+         else (
+             (16, 4)
+             if dstate <= 32
+             else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else (4, 8)))
+         )
+     )
+     tie_hdim = (
+         A.stride(-1) == 0
+         and A.stride(-2) == 0
+         and dt.stride(-1) == 0
+         and dt_bias.stride(-1) == 0
+     )
+     with torch.cuda.device(x.device.index):
+         _selective_scan_update_kernel[grid](
+             state,
+             x,
+             dt,
+             dt_bias,
+             A,
+             B,
+             C,
+             D,
+             z,
+             out,
+             state_batch_indices,
+             pad_slot_id,
+             batch,
+             nheads,
+             dim,
+             dstate,
+             nheads // ngroups,
+             state.stride(0),
+             state.stride(1),
+             state.stride(2),
+             state.stride(3),
+             x.stride(0),
+             x.stride(1),
+             x.stride(2),
+             dt.stride(0),
+             dt.stride(1),
+             dt.stride(2),
+             *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
+             A.stride(0),
+             A.stride(1),
+             A.stride(2),
+             B.stride(0),
+             B.stride(1),
+             B.stride(2),
+             C.stride(0),
+             C.stride(1),
+             C.stride(2),
+             *(D.stride(0), D.stride(1)) if D is not None else 0,
+             z_strides[0],
+             z_strides[1],
+             z_strides[2],
+             out.stride(0),
+             out.stride(1),
+             out.stride(2),
+             dt_softplus,
+             tie_hdim,
+             BLOCK_SIZE_M,
+             num_warps=num_warps,
+         )
+
+
+ def selective_scan_fn(
+     u,
+     ssm_states,
+     delta,
+     A,
+     B,
+     C,
+     D=None,
+     z=None,
+     delta_bias=None,
+     delta_softplus=False,
+     query_start_loc=None,
+     cache_indices=None,
+     has_initial_state=None,
+     pad_slot_id=PAD_SLOT_ID,
+ ) -> torch.Tensor:
+     """
+     u: (dim, total_length) for varlen or (batch, dim, seqlen)
+         applies changes in place.
+     ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+         applies changes in place.
+     delta: (dim, total_length) for varlen or (batch, dim, seqlen)
+     A: (dim, dstate)
+     B: (ngroups, dstate, total_length) for varlen or
+         (batch, ngroups, dstate, seqlen)
+     C: (ngroups, dstate, total_length) for varlen or
+         (batch, ngroups, dstate, seqlen)
+     D: (dim,)
+     z: (dim, total_length) for varlen or (batch, dim, seqlen)
+     delta_bias: (dim,)
+     query_start_loc: (batch + 1) int32
+         The cumulative sequence lengths of the sequences in
+         the batch, used to index into the sequences; prepended with 0.
+         for example: query_start_loc = torch.Tensor([0, 10, 16, 17]),
+         x.shape=(dim, 17)
+     cache_indices: (batch) int32
+         A tensor where each cell is the corresponding
+         input and output ssm_state index
+     has_initial_state: (batch) bool
+         A tensor populated with ones and zeros,
+         indicating whether the ssm_state at the corresponding index should be
+         used as an initial state. Not providing the argument assumes
+         there is no initial state
+     pad_slot_id: int
+         if cache_indices is passed, lets the kernel identify padding entries
+         that will not be processed,
+         for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+         in this case, the kernel will not process entries at indices 0 and 3
+     returns
+         output: (dim, total_length) for varlen or (batch, dim, seqlen)
+         supports in-place replacement
+     """
+     if u.stride(-1) != 1:
+         u = u.contiguous()
+     if delta.stride(-1) != 1:
+         delta = delta.contiguous()
+     if D is not None:
+         D = D.contiguous()
+     if B.stride(-1) != 1:
+         B = B.contiguous()
+     if C.stride(-1) != 1:
+         C = C.contiguous()
+     if z is not None and z.stride(-1) != 1:
+         z = z.contiguous()
+     if B.dim() == 3 and query_start_loc is None:
+         B = B.unsqueeze(1)
+     if B.dim() == 2 and query_start_loc is not None:
+         B = B.unsqueeze(0)
+     if C.dim() == 3 and query_start_loc is None:
+         C = C.unsqueeze(1)
+     if C.dim() == 2 and query_start_loc is not None:
+         C = C.unsqueeze(0)
+
+     ops.selective_scan_fwd(
+         u,
+         delta,
+         A,
+         B,
+         C,
+         D,
+         z,
+         delta_bias,
+         delta_softplus,
+         query_start_loc,
+         cache_indices,
+         has_initial_state,
+         ssm_states,
+         pad_slot_id,
+     )
+
+     if z is None:
+         return delta  # output written in place to delta
+     else:
+         return z  # output written in place to z
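
A minimal decode-step sketch for selective_state_update in this hunk (entry 113 in the file list, sglang/srt/layers/attention/mamba/ops/mamba_ssm.py). The shapes are illustrative; note that out must be preallocated (the function calls out.dim() unconditionally), and that the launch expands the dt_bias and D strides inline, so this sketch supplies both tensors rather than leaving them as None:

    import torch

    from sglang.srt.layers.attention.mamba.ops.mamba_ssm import selective_state_update

    batch, nheads, dim, dstate = 4, 8, 64, 128
    dev = "cuda"

    state = torch.zeros(batch, nheads, dim, dstate, device=dev)  # updated in place
    x = torch.randn(batch, nheads, dim, device=dev)
    dt = torch.rand(batch, nheads, dim, device=dev)
    A = -torch.rand(nheads, dim, dstate, device=dev)  # negative => decaying state
    B = torch.randn(batch, 1, dstate, device=dev)  # ngroups = 1
    C = torch.randn(batch, 1, dstate, device=dev)
    D = torch.randn(nheads, dim, device=dev)
    dt_bias = torch.rand(nheads, dim, device=dev)
    out = torch.empty_like(x)  # kernel writes the per-step output here

    selective_state_update(
        state, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True, out=out
    )
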