sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,577 @@
1
+ from typing import Callable, List, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from sglang.srt.configs.mamba_utils import (
7
+ Mamba2CacheParams,
8
+ extra_groups_for_head_shards,
9
+ )
10
+ from sglang.srt.distributed import (
11
+ divide,
12
+ get_tensor_model_parallel_rank,
13
+ get_tensor_model_parallel_world_size,
14
+ )
15
+ from sglang.srt.distributed.utils import divide
16
+ from sglang.srt.layers.attention.mamba.causal_conv1d import (
17
+ causal_conv1d_fn,
18
+ causal_conv1d_update,
19
+ )
20
+ from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
21
+ causal_conv1d_fn as causal_conv1d_fn_triton,
22
+ )
23
+ from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
24
+ causal_conv1d_update as causal_conv1d_update_triton,
25
+ )
26
+ from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
27
+ from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
28
+ from sglang.srt.layers.attention.mamba.ops import (
29
+ mamba_chunk_scan_combined,
30
+ selective_state_update,
31
+ )
32
+ from sglang.srt.layers.linear import (
33
+ ColumnParallelLinear,
34
+ MergedColumnParallelLinear,
35
+ RowParallelLinear,
36
+ )
37
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
38
+ from sglang.srt.mem_cache.memory_pool import MambaPool
39
+ from sglang.srt.model_loader.weight_utils import (
40
+ composed_weight_loader,
41
+ sharded_weight_loader,
42
+ )
43
+ from sglang.srt.utils import set_weight_attrs
44
+
45
# Signature of a custom weight loader: ``loader(param, loaded_weight)`` copies
# the relevant part of ``loaded_weight`` into ``param`` and returns nothing.
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]


def mamba_v2_sharded_weight_loader(
    shard_spec: List[Tuple[int, int, bool]],
    tp_size: int,
    tp_rank: int,
) -> LoaderFunction:
    """Create a weight loader for mamba v2.

    The target parameter is a concatenation (along dim 0) of several
    sub-weights, e.g. x, B and C. Each sub-weight is sharded independently so
    that the local shard on every tensor-parallel rank can still be split into
    x, B, C. It also ensures that all the groups corresponding to a head shard
    are placed together with it (groups may be duplicated across ranks).

    Args:
        shard_spec: One ``(full_dim, extra, duplicate_groups)`` entry per
            concatenated sub-weight, in dim-0 order.

            * ``full_dim`` -- size of the sub-weight before TP sharding,
              including any ``extra`` dims; each rank owns
              ``full_dim // tp_size`` rows of it in ``param``.
            * ``extra`` -- dims added on top of the checkpoint size because of
              replication; the checkpoint tensor holds ``full_dim - extra``
              rows for this sub-weight.
            * ``duplicate_groups`` -- if true, every rank reads its rows from
              the start of the checkpoint sub-weight (shard 0) instead of
              from its own ``tp_rank`` offset.
        tp_size: Tensor-parallel world size.
        tp_rank: Rank of this process within the tensor-parallel group.

    Returns:
        A ``loader(param, loaded_weight)`` callable that writes the rows owned
        by ``tp_rank`` into ``param.data`` in place. For each sub-weight at
        most ``full_dim // tp_size`` rows are copied; if fewer rows are
        available in the checkpoint, the remaining rows of that slot in
        ``param`` are left untouched.
    """

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        # Running offsets into the (sharded) param and into the full
        # checkpoint tensor, respectively.
        boundary, loaded_boundary = 0, 0

        for full_dim, extra, duplicate_groups in shard_spec:
            # Number of rows of this sub-weight owned by each rank.
            shard_size = full_dim // tp_size

            # Which shard of the checkpoint sub-weight to read. With
            # duplicated groups every TP rank reads the same (first) shard.
            # NOTE: currently duplication is only supported when
            # num_groups == 1.
            rank = 0 if duplicate_groups else tp_rank

            # Leftmost index into the checkpoint tensor for this rank.
            loaded_skip = rank * shard_size
            loaded_start_idx = loaded_boundary + loaded_skip

            # Number of rows to copy, clipped so we never read past the end
            # of this sub-weight in the checkpoint (full_dim - extra rows).
            take = min(shard_size, full_dim - extra - loaded_skip)

            # Always shard on dim 0. The ignores work around a mypy
            # limitation with slices: https://github.com/python/mypy/issues/2410
            param.data[
                boundary : (boundary + take), ...  # type: ignore[misc]
            ] = loaded_weight[
                loaded_start_idx : (loaded_start_idx + take)  # type: ignore[misc]
            ]  # type: ignore[misc]

            # Advance past this sub-weight in both tensors.
            boundary += shard_size
            loaded_boundary += full_dim - extra

    return loader
105
+
106
+
107
+ class MambaMixer2(torch.nn.Module):
108
+ """
109
+ Compute ∆, A, B, C, and D the state space parameters and compute
110
+ the `contextualized_states`. A, D are input independent
111
+ (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
112
+ for why A isn't selective) ∆, B, C are input-dependent
113
+ (this is a key difference between Mamba and the linear time
114
+ invariant S4, and is why Mamba is called
115
+ **selective** state spaces)
116
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ cache_params: Mamba2CacheParams,
121
+ hidden_size: int,
122
+ use_conv_bias: bool,
123
+ use_bias: bool,
124
+ n_groups: int = 1,
125
+ rms_norm_eps: float = 1e-5,
126
+ activation: str = "silu",
127
+ use_rms_norm: bool = True,
128
+ quant_config: Optional[QuantizationConfig] = None,
129
+ prefix: str = "",
130
+ ):
131
+ super().__init__()
132
+
133
+ # For TP, the sharding plan is as follows:
134
+ # - for the conv modules, since
135
+ # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size,
136
+ # we shard intermediate_size and n_groups
137
+ # - since intermediate_size = n_heads * head_dim, sharding on
138
+ # intermediate_size is achieved by sharding on n_heads.
139
+ # - IF, world_size divides groups, then sharding
140
+ # (n_groups / world_size, n_heads / world_size)
141
+ # also maintains the invariant n_heads % n_groups == 0
142
+ # - HOWEVER IF, world_size DOES NOT divide groups, then we need
143
+ # to allocate extra space in the shard, such that groups
144
+ # may be replicated to follow the head shard.
145
+ # - NOTE: currently for the world size DOES NOT divide groups
146
+ # case, we only support the case when n_groups == 1
147
+ self.tp_size = get_tensor_model_parallel_world_size()
148
+ self.tp_rank = get_tensor_model_parallel_rank()
149
+
150
+ self.num_heads = num_heads = cache_params.shape.num_heads
151
+ self.head_dim = cache_params.shape.head_dim
152
+
153
+ assert (
154
+ num_heads % self.tp_size == 0
155
+ ), "Tensor parallel world size must divide num heads."
156
+
157
+ assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
158
+ "If tensor parallel world size does not divide num_groups, "
159
+ "then num_groups must equal 1."
160
+ )
161
+
162
+ assert (
163
+ (n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
164
+ ), (
165
+ "Tensor parallel currently supported for quantized models only "
166
+ "if tensor parallel world size divides num groups."
167
+ )
168
+
169
+ self.ssm_state_size = cache_params.shape.ssm_state_size
170
+ self.activation = activation
171
+
172
+ conv_kernel_size = cache_params.shape.conv_kernel
173
+ self.intermediate_size = intermediate_size = (
174
+ cache_params.shape.intermediate_size
175
+ )
176
+ self.n_groups = n_groups
177
+ if n_groups % self.tp_size != 0:
178
+ # - for TP we shard conv_dim by sharding on n_groups,
179
+ # - but if n_groups cannot divide tp_size, we need to
180
+ # extend some extra groups
181
+ groups = extra_groups_for_head_shards(n_groups, self.tp_size)
182
+ self.n_groups = n_groups + groups
183
+ self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
184
+ self.conv_dim = cache_params.shape.conv_dim
185
+
186
+ if n_groups % self.tp_size == 0:
187
+ self.conv1d = MergedColumnParallelLinear(
188
+ input_size=conv_kernel_size,
189
+ output_sizes=[
190
+ intermediate_size,
191
+ self.groups_ssm_state_size,
192
+ self.groups_ssm_state_size,
193
+ ],
194
+ bias=use_conv_bias,
195
+ quant_config=None,
196
+ prefix=f"{prefix}.conv1d",
197
+ )
198
+
199
+ self.in_proj = MergedColumnParallelLinear(
200
+ input_size=hidden_size,
201
+ output_sizes=[
202
+ intermediate_size,
203
+ intermediate_size,
204
+ self.groups_ssm_state_size,
205
+ self.groups_ssm_state_size,
206
+ self.num_heads,
207
+ ],
208
+ bias=use_bias,
209
+ quant_config=quant_config,
210
+ prefix=f"{prefix}.in_proj",
211
+ )
212
+ else:
213
+ # This is the n_groups == 1 case,
214
+ # where we need to duplicate groups if TP>1.
215
+
216
+ self.conv1d = ColumnParallelLinear(
217
+ input_size=conv_kernel_size,
218
+ output_size=self.conv_dim,
219
+ bias=use_conv_bias,
220
+ quant_config=None,
221
+ prefix=f"{prefix}.conv1d",
222
+ )
223
+
224
+ self.in_proj = ColumnParallelLinear(
225
+ input_size=hidden_size,
226
+ output_size=intermediate_size + self.conv_dim + self.num_heads,
227
+ bias=use_bias,
228
+ quant_config=quant_config,
229
+ prefix=f"{prefix}.in_proj",
230
+ )
231
+
232
+ # - because in_proj is a concatenation of 3 weights, we
233
+ # need to interleave them before sharding
234
+ # - use the custom weight loader mamba_v2_sharded_weight_loader
235
+ # for conv1d.bias, covn1d.weight and in_proj.weight
236
+ # - need to set these settings, to assign the groups
237
+ # to the head shards
238
+ group_shard_settings = (
239
+ self.groups_ssm_state_size, # expected model size
240
+ (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
241
+ n_groups == 1, # if there was only one group
242
+ )
243
+ intermediate_settings = (intermediate_size, 0, False)
244
+ head_settings = (self.num_heads, 0, False)
245
+
246
+ # - the weight already has a "weight_loader" attribute
247
+ # which set_weight_attrs will raise if we do not
248
+ # delete before trying to override it
249
+ # - ditto for the other two weights below
250
+ delattr(self.conv1d.bias, "weight_loader")
251
+ set_weight_attrs(
252
+ self.conv1d.bias,
253
+ {
254
+ "weight_loader": mamba_v2_sharded_weight_loader(
255
+ [
256
+ intermediate_settings,
257
+ group_shard_settings,
258
+ group_shard_settings,
259
+ ],
260
+ self.tp_size,
261
+ self.tp_rank,
262
+ )
263
+ },
264
+ )
265
+
266
+ delattr(self.conv1d.weight, "weight_loader")
267
+ set_weight_attrs(
268
+ self.conv1d.weight,
269
+ {
270
+ "weight_loader": mamba_v2_sharded_weight_loader(
271
+ [
272
+ intermediate_settings,
273
+ group_shard_settings,
274
+ group_shard_settings,
275
+ ],
276
+ self.tp_size,
277
+ self.tp_rank,
278
+ )
279
+ },
280
+ )
281
+
282
+ if quant_config is None:
283
+ # - quant layers do not have a weight loader
284
+ delattr(self.in_proj.weight, "weight_loader")
285
+ set_weight_attrs(
286
+ self.in_proj.weight,
287
+ {
288
+ "weight_loader": mamba_v2_sharded_weight_loader(
289
+ [
290
+ intermediate_settings, # for gate
291
+ intermediate_settings,
292
+ group_shard_settings,
293
+ group_shard_settings,
294
+ head_settings, # for dt
295
+ ],
296
+ self.tp_size,
297
+ self.tp_rank,
298
+ )
299
+ },
300
+ )
301
+
302
+ # unsqueeze to fit conv1d weights shape into the linear weights shape.
303
+ # Can't do this in `weight_loader` since it already exists in
304
+ # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
305
+ # and `set_weight_attrs` doesn't allow to override it
306
+ self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
307
+
308
+ # - these are TPed by heads to reduce the size of the
309
+ # temporal shape
310
+ self.A = nn.Parameter(
311
+ torch.empty(
312
+ divide(num_heads, self.tp_size),
313
+ dtype=torch.float32,
314
+ )
315
+ )
316
+ self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
317
+ self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
318
+ self.use_rms_norm = use_rms_norm
319
+
320
+ set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
321
+ a_weight_loader = composed_weight_loader(
322
+ sharded_weight_loader(0), lambda x: -torch.exp(x.float())
323
+ )
324
+ set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
325
+ set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
326
+
327
+ self.out_proj = RowParallelLinear(
328
+ intermediate_size,
329
+ hidden_size,
330
+ bias=use_bias,
331
+ input_is_parallel=True,
332
+ quant_config=quant_config,
333
+ prefix=f"{prefix}.out_proj",
334
+ reduce_results=False,
335
+ )
336
+
337
+ self.norm = Mixer2RMSNormGated(
338
+ intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
339
+ )
340
+
341
+ self.prefix = prefix
342
+
343
def forward(
    self,
    *,
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_cache: MambaPool.State,
    metadata: Mamba2Metadata,
    mup_vector: Optional[torch.Tensor] = None,
    use_triton_causal_conv: bool = False,
):
    """Run the Mamba2 mixer over a batch that mixes prefill and decode tokens.

    Token layout: the first ``metadata.num_prefill_tokens`` rows belong to
    prefill requests, followed by ``metadata.num_decodes`` rows, one per
    decode request. Prefill rows go through the chunked-scan path and decode
    rows through the single-step state-update path. Both paths write into
    one shared preallocated buffer, so no merge copy is needed.

    Args:
        hidden_states: Input activations. They are fed to ``in_proj`` and
            must yield exactly ``num_prefill_tokens + num_decodes`` rows
            (asserted below).
        output: Destination tensor. ``output[:num_tokens]`` is overwritten
            with the ``out_proj`` result.
        layer_cache: Per-layer cache. ``.conv`` is handed to the causal-conv
            kernels and ``.temporal`` holds the SSM states. Prefill slots of
            ``.temporal`` are assigned explicitly here; the decode kernel
            receives the whole tensor plus slot indices (presumably updates
            in place; confirm against the kernel).
        metadata: Batch metadata computed once per iteration at the model
            level and shared by all Mamba layers. ``mixed_metadata`` must be
            set whenever prefill requests are present.
        mup_vector: Optional multiplier applied elementwise to the
            ``in_proj`` output.
        use_triton_causal_conv: Select the Triton causal-conv kernels
            instead of the default ones.

    Returns:
        None. Results are written into ``output`` in place.
    """
    # Cache slots and per-request boundaries for this iteration.
    cache_slots = metadata.mamba_cache_indices
    conv_cache = layer_cache.conv
    ssm_cache = layer_cache.temporal
    cu_seqlens = metadata.query_start_loc

    # 1. Gated MLP's input projection.
    proj, _ = self.in_proj(hidden_states)
    if mup_vector is not None:
        proj = proj * mup_vector

    # Per-rank (tensor-parallel-local) sizes used throughout.
    tp = self.tp_size
    inner_local = self.intermediate_size // tp
    heads_local = self.num_heads // tp
    groups_local = self.n_groups // tp
    bc_local = self.groups_ssm_state_size // tp
    head_dim = self.head_dim

    gate, xbc, dt = torch.split(
        proj,
        [inner_local, self.conv_dim // tp, heads_local],
        dim=-1,
    )

    # __init__ unsqueezed conv1d.weight along dim 1; drop that singleton dim
    # to get the 2-D (out, kernel) layout the conv kernels take.
    weight = self.conv1d.weight
    conv_weights = weight.view(weight.size(0), weight.size(2))

    def split_xbc(t):
        # Split conv output channels into the (x, B, C) streams.
        return torch.split(t, [inner_local, bc_local, bc_local], dim=-1)

    n_prefill_reqs = metadata.num_prefills  # request count
    n_decode = metadata.num_decodes  # token count, equal to request count
    n_prefill_toks = metadata.num_prefill_tokens  # token count
    has_prefill = n_prefill_reqs > 0
    has_decode = n_decode > 0
    n_tokens = n_prefill_toks + n_decode
    assert n_tokens == proj.shape[0]

    # Prefill rows precede decode rows. Tokens split along dim 0, cache
    # slots split per request.
    token_split = [n_prefill_toks, n_decode]
    xbc_p, xbc_d = torch.split(xbc, token_split, dim=0)
    dt_p, dt_d = torch.split(dt, token_split, dim=0)
    slots_p, slots_d = torch.split(
        cache_slots, [n_prefill_reqs, n_decode], dim=0
    )
    cu_seqlens_p = cu_seqlens[: n_prefill_reqs + 1] if has_prefill else None

    # Shared output buffer. Both branches write straight into their slice,
    # which avoids a concatenation copy afterwards.
    ssm_out = torch.empty(
        [proj.shape[0], (self.num_heads * head_dim) // tp],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    ssm_out_p, ssm_out_d = torch.split(ssm_out, token_split, dim=0)

    if has_prefill:
        mixed = metadata.mixed_metadata
        assert mixed is not None
        init_mask = mixed.has_initial_states
        prep_initial = mixed.prep_initial_states

        # 2. Causal conv over the prefill tokens. The kernel is fed the
        #    transposed (channels, tokens) view and updates conv_cache at
        #    the slots given by `cache_indices`.
        channels_first = xbc_p.transpose(0, 1)
        if use_triton_causal_conv:
            conv_fn = causal_conv1d_fn_triton
        else:
            conv_fn = causal_conv1d_fn
        xbc_p = conv_fn(
            channels_first,
            conv_weights,
            self.conv1d.bias,
            activation=self.activation,
            conv_states=conv_cache,
            has_initial_state=init_mask,
            cache_indices=slots_p,
            query_start_loc=cu_seqlens_p,
            seq_lens_cpu=mixed.extend_seq_lens_cpu,
        ).transpose(0, 1)[:n_prefill_toks]

        x_p, B_p, C_p = split_xbc(xbc_p)

        # 3. Chunked SSM scan. Requests that continue from a cached state
        #    start from it; the rest start from zeros.
        init_states = None
        if init_mask is not None and prep_initial:
            init_states = torch.where(
                init_mask[:, None, None, None],
                ssm_cache[slots_p],
                0,
            )

        # The scan writes its output into ssm_out_p through `out=`.
        varlen_state = mamba_chunk_scan_combined(
            x_p.view(1, n_prefill_toks, heads_local, head_dim),
            dt_p.unsqueeze(0),
            self.A,
            B_p.view(1, n_prefill_toks, groups_local, -1),
            C_p.view(1, n_prefill_toks, groups_local, -1),
            chunk_size=mixed.chunk_size,
            D=self.D,
            z=None,
            dt_bias=self.dt_bias,
            seq_idx=mixed.seq_idx,
            chunk_indices=mixed.chunk_indices,
            chunk_offsets=mixed.chunk_offsets,
            cu_seqlens=cu_seqlens_p,
            initial_states=init_states,
            return_varlen_states=True,
            return_final_states=False,
            dt_softplus=True,
            dt_limit=(0.0, float("inf")),
            out=ssm_out_p.view(1, n_prefill_toks, -1, head_dim),
            state_dtype=ssm_cache.dtype,
        )

        # Store each prefill request's final state back into its cache slot.
        ssm_cache[slots_p] = varlen_state

    if has_decode:
        # 2. Single-step causal conv update for the decode tokens.
        if use_triton_causal_conv:
            conv_update = causal_conv1d_update_triton
        else:
            conv_update = causal_conv1d_update
        xbc_d = conv_update(
            xbc_d,
            conv_cache,
            conv_weights,
            self.conv1d.bias,
            self.activation,
            conv_state_indices=slots_d,
        )

        x_d, B_d, C_d = split_xbc(xbc_d)

        # 3. Single-step SSM update. Per-head params are broadcast across
        #    head_dim (and A also across ssm_state_size) via expand views.
        A_d = (
            self.A[:, None, ...][:, :, None]
            .expand(-1, head_dim, self.ssm_state_size)
            .to(dtype=torch.float32)
        )
        dt_d = dt_d[:, :, None].expand(-1, -1, head_dim)
        dt_bias_d = self.dt_bias[:, None, ...].expand(-1, head_dim)
        D_d = self.D[:, None, ...].expand(-1, head_dim)
        B_d = B_d.view(-1, groups_local, B_d.shape[1] // groups_local)
        C_d = C_d.view(-1, groups_local, C_d.shape[1] // groups_local)
        x_d = x_d.view(-1, heads_local, head_dim)

        # Cache rows are selected by slots_d; output goes into ssm_out_d.
        selective_state_update(
            ssm_cache,
            x_d,
            dt_d,
            A_d,
            B_d,
            C_d,
            D_d,
            z=None,
            dt_bias=dt_bias_d,
            dt_softplus=True,
            state_batch_indices=slots_d,
            out=ssm_out_d.view(n_decode, -1, head_dim),
        )

    # 4. Gated RMSNorm. Per the original note, the norm applies SiLU to the
    #    gate internally before normalizing.
    normed = self.norm(ssm_out, gate[:n_tokens])

    # 5. Output projection, written into the caller's buffer.
    output[:n_tokens], _ = self.out_proj(normed)
574
+
575
@property
def mamba_type(self) -> str:
    """Return the constant string tag ``"mamba2"`` for this mixer."""
    tag = "mamba2"
    return tag