sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/mamba.py (new file)
@@ -0,0 +1,629 @@
+ from typing import Callable, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.custom_op import CustomOp
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_gather,
+     tensor_model_parallel_all_reduce,
+ )
+ from sglang.srt.distributed.utils import divide
+ from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
+ from sglang.srt.layers.attention.mamba.causal_conv1d import (
+     causal_conv1d_fn,
+     causal_conv1d_update,
+ )
+ from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
+ from sglang.srt.layers.attention.mamba.ops import (
+     mamba_chunk_scan_combined,
+     selective_state_update,
+ )
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     MergedColumnParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import (
+     composed_weight_loader,
+     sharded_weight_loader,
+ )
+ from sglang.srt.utils import set_weight_attrs
+
+ LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
+
+
+ def mamba_v2_sharded_weight_loader(
+     shard_spec: List[Tuple[int, int, float]],
+     tp_size: int,
+     tp_rank: int,
+ ) -> LoaderFunction:
+     """Create a weight loader for mamba v2. This ensures that the projections
+     are correctly sharded so that they can be split into x, B, C. It also
+     ensures that all the groups corresponding to a head shard are placed
+     together with it.
+     """
+
+     def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
+
+         # - track boundary of (sharded) param, and loaded_weight, respectively
+         boundary, loaded_boundary = 0, 0
+
+         # - iterate over the shard specs
+         for full_dim, extra, duplicate_groups in shard_spec:
+             # - full dim is the model dim (before TP).
+             # - extra > 0, means there is expected overall increase
+             #   of dimensions. This is so because of replication.
+             # - ratio is used to map the tp_rank to the actual shard
+             #   rank. This is useful when there is replication of
+             #   groups to accompany head shards.
+
+             # - size of the loaded shard
+             shard_size = full_dim // tp_size
+
+             # - compute the rank into the loaded shard.
+             # - if there is replication, different TP shards will
+             #   take from the same rank.
+             # NOTE: currently we only support duplication
+             # in the case where num_groups == 1
+             rank = 0 if duplicate_groups else tp_rank
+
+             # - leftmost boundary index into loaded weight.
+             loaded_skip = rank * shard_size
+             loaded_start_idx = loaded_boundary + loaded_skip
+
+             # - take these many dims from the loaded weight.
+             take = min(shard_size, full_dim - extra - loaded_skip)
+
+             # - always shard on dim 0
+             # - the ignore is for a mundane mypy error as it does not
+             #   seem to handle slices well.
+             # https://github.com/python/mypy/issues/2410
+             param.data[
+                 boundary : (boundary + take), ...  # type: ignore[misc]
+             ] = loaded_weight[
+                 loaded_start_idx : (loaded_start_idx + take)  # type: ignore[misc]
+             ]  # type: ignore[misc]
+
+             # move indexing boundaries
+             boundary += shard_size
+             loaded_boundary += full_dim - extra
+
+     return loader
+
+
+ class Mixer2RMSNormGated(CustomOp):
+
+     def __init__(
+         self,
+         full_hidden_size: int,
+         full_n_groups: int,
+         use_rms_norm: bool = True,
+         eps: float = 1e-6,
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.tp_rank = get_tensor_model_parallel_rank()
+         self.full_hidden_size = full_hidden_size
+         self.group_size = full_hidden_size // full_n_groups
+         self.per_rank_hidden_size = full_hidden_size // self.tp_size
+         self.n_groups = full_hidden_size // self.group_size
+
+         self.variance_epsilon = eps
+         self.use_rms_norm = use_rms_norm
+         if self.use_rms_norm:
+             # Register norm weight only if we're actually applying RMSNorm
+             self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
+             set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
+         else:
+             # Avoid checkpoint mismatch by skipping unused parameter
+             self.register_parameter("weight", None)
+         assert (
+             self.full_hidden_size % self.tp_size == 0
+         ), "Tensor parallel world size must divide hidden size."
+
+     def forward_native(
+         self,
+         x: torch.Tensor,
+         gate: torch.Tensor,
+     ):
+         # Three tensor-parallel cases:
+         #   1. n_groups is 1
+         #      In this case we parallelize along the reduction dim.
+         #      Each rank computes a local sum of squares followed by AllReduce
+         #   2. tp_size divides n_groups
+         #      Each rank only reduces within its local group(s).
+         #      No collective ops necessary.
+         #   3. The general case can be pretty complicated so we AllGather
+         #      the input and then redundantly compute the RMSNorm.
+         input_dtype = x.dtype
+         x = x * nn.functional.silu(gate.to(torch.float32))
+         if not self.use_rms_norm:
+             return x.to(input_dtype)
+
+         if self.n_groups == 1:
+             if self.tp_size > 1:
+                 # Compute local sum and then reduce to obtain global sum
+                 local_sums = x.pow(2).sum(dim=-1, keepdim=True)
+                 global_sums = tensor_model_parallel_all_reduce(local_sums)
+                 # Calculate the variance
+                 count = self.tp_size * x.shape[-1]
+                 variance = global_sums / count
+
+             else:
+                 variance = x.pow(2).mean(-1, keepdim=True)
+             x = x * torch.rsqrt(variance + self.variance_epsilon)
+         else:
+             redundant_tp: bool = self.n_groups % self.tp_size != 0
+             if redundant_tp:
+                 # To handle the general case, redundantly apply the variance
+                 x = tensor_model_parallel_all_gather(x, -1)
+
+             *prefix_dims, hidden_dim = x.shape
+             group_count = hidden_dim // self.group_size
+             x_grouped = x.view(*prefix_dims, group_count, self.group_size)
+             variance = x_grouped.pow(2).mean(-1, keepdim=True)
+             x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
+             x = x_grouped.view(*prefix_dims, hidden_dim)
+
+             if redundant_tp:
+                 start = self.per_rank_hidden_size * self.tp_rank
+                 end = start + self.per_rank_hidden_size
+                 x = x[..., start:end]
+
+         return self.weight * x.to(input_dtype)
+
+     def forward_cuda(
+         self,
+         x: torch.Tensor,
+         gate: torch.Tensor,
+     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         input_dtype = x.dtype
+         if not self.use_rms_norm:
+             # Keep gate in float32 for numerical stability during silu
+             return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
+
+         if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
+             return self.forward_native(x, gate)
+
+         return layernorm_fn(
+             x,
+             self.weight.data,
+             bias=None,
+             z=gate,
+             eps=self.variance_epsilon,
+             norm_before_gate=False,
+         )
+
+
+ class MambaMixer2(torch.nn.Module):
+     """
+     Compute ∆, A, B, C, and D the state space parameters and compute
+     the `contextualized_states`. A, D are input independent
+     (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
+     for why A isn't selective) ∆, B, C are input-dependent
+     (this is a key difference between Mamba and the linear time
+     invariant S4, and is why Mamba is called
+     **selective** state spaces)
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         ssm_state_size: int,
+         conv_kernel_size: int,
+         intermediate_size: int,
+         use_conv_bias: bool,
+         use_bias: bool,
+         chunk_size: int,
+         layer_id: int,
+         n_groups: int = 1,
+         num_heads: int = 128,
+         head_dim: int = 64,
+         rms_norm_eps: float = 1e-5,
+         activation: str = "silu",
+         use_rms_norm: bool = True,
+         model_config: Optional[ModelConfig] = None,
+         # cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+
+         # For TP, the sharding plan is as follows:
+         # - for the conv modules, since
+         #   conv_dim = intermediate_size * 2 * n_groups * ssm_state_size,
+         #   we shard intermediate_size and n_groups
+         # - since intermediate_size = n_heads * head_dim, sharding on
+         #   intermediate_size is achieved by sharding on n_heads.
+         # - IF, world_size divides groups, then sharding
+         #   (n_groups / world_size, n_heads / world_size)
+         #   also maintains the invariant n_heads % n_groups == 0
+         # - HOWEVER IF, world_size DOES NOT divide groups, then we need
+         #   to allocate extra space in the shard, such that groups
+         #   may be replicated to follow the head shard.
+         # - NOTE: currently for the world size DOES NOT divide groups
+         #   case, we only support the case when n_groups == 1
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.tp_rank = get_tensor_model_parallel_rank()
+
+         assert (
+             num_heads % self.tp_size == 0
+         ), "Tensor parallel world size must divide num heads."
+
+         assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
+             "If tensor parallel world size does not divide num_groups, "
+             "then num_groups must equal 1."
+         )
+
+         self.ssm_state_size = ssm_state_size
+         self.conv_kernel_size = conv_kernel_size
+         self.activation = activation
+         self.layer_id = layer_id
+
+         self.intermediate_size = intermediate_size
+         self.head_dim = head_dim
+         self.num_heads = num_heads
+         self.chunk_size = chunk_size
+
+         self.n_groups = n_groups
+         if n_groups % self.tp_size != 0:
+             # - for TP we shard conv_dim by sharding on n_groups,
+             # - but if n_groups cannot divide tp_size, we need to
+             #   extend some extra groups
+             groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
+                 n_groups, self.tp_size
+             )
+             self.n_groups = n_groups + groups
+
+         self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
+         self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
+
+         self.conv1d = MergedColumnParallelLinear(
+             input_size=conv_kernel_size,
+             output_sizes=[
+                 intermediate_size,
+                 self.groups_ssm_state_size,
+                 self.groups_ssm_state_size,
+             ],
+             bias=use_conv_bias,
+             quant_config=None,
+             prefix=f"{prefix}.conv1d",
+         )
+
+         self.in_proj = MergedColumnParallelLinear(
+             input_size=hidden_size,
+             output_sizes=[
+                 intermediate_size,
+                 intermediate_size,
+                 self.groups_ssm_state_size,
+                 self.groups_ssm_state_size,
+                 self.num_heads,
+             ],
+             bias=use_bias,
+             prefix=f"{prefix}.in_proj",
+         )
+         if n_groups % self.tp_size != 0:
+             # This is the n_groups == 1 case,
+             # where we need to duplicate groups if TP>1.
+
+             # - because in_proj is a concatenation of 3 weights, we
+             #   need to interleave them before sharding
+             # - use the custom weight loader mamba_v2_sharded_weight_loader
+             #   for conv1d.bias, conv1d.weight and in_proj.weight
+             # - need to set these settings, to assign the groups
+             #   to the head shards
+             group_shard_settings = (
+                 self.groups_ssm_state_size,  # expected model size
+                 (self.n_groups - n_groups) * self.ssm_state_size,  # extra dims assigned
+                 n_groups == 1,  # if there was only one group
+             )
+             intermediate_settings = (intermediate_size, 0, False)
+             head_settings = (self.num_heads, 0, False)
+
+             # - the weight already has a "weight_loader" attribute
+             #   which set_weight_attrs will raise if we do not
+             #   delete before trying to override it
+             # - ditto for the other two weights below
+             delattr(self.conv1d.bias, "weight_loader")
+             set_weight_attrs(
+                 self.conv1d.bias,
+                 {
+                     "weight_loader": mamba_v2_sharded_weight_loader(
+                         [
+                             intermediate_settings,
+                             group_shard_settings,
+                             group_shard_settings,
+                         ],
+                         self.tp_size,
+                         self.tp_rank,
+                     )
+                 },
+             )
+
+             delattr(self.conv1d.weight, "weight_loader")
+             set_weight_attrs(
+                 self.conv1d.weight,
+                 {
+                     "weight_loader": mamba_v2_sharded_weight_loader(
+                         [
+                             intermediate_settings,
+                             group_shard_settings,
+                             group_shard_settings,
+                         ],
+                         self.tp_size,
+                         self.tp_rank,
+                     )
+                 },
+             )
+
+             if quant_config is None:
+                 # - quant layers do not have a weight loader
+                 delattr(self.in_proj.weight, "weight_loader")
+                 set_weight_attrs(
+                     self.in_proj.weight,
+                     {
+                         "weight_loader": mamba_v2_sharded_weight_loader(
+                             [
+                                 intermediate_settings,  # for gate
+                                 intermediate_settings,
+                                 group_shard_settings,
+                                 group_shard_settings,
+                                 head_settings,  # for dt
+                             ],
+                             self.tp_size,
+                             self.tp_rank,
+                         )
+                     },
+                 )
+
+         # unsqueeze to fit conv1d weights shape into the linear weights shape.
+         # Can't do this in `weight_loader` since it already exists in
+         # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
+         # and `set_weight_attrs` doesn't allow to override it
+         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+
+         # - these are TPed by heads to reduce the size of the
+         #   temporal shape
+         self.A = nn.Parameter(
+             torch.empty(
+                 divide(num_heads, self.tp_size),
+                 dtype=torch.float32,
+             )
+         )
+         self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
+         self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
+         self.use_rms_norm = use_rms_norm
+
+         set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
+         a_weight_loader = composed_weight_loader(
+             sharded_weight_loader(0), lambda x: -torch.exp(x.float())
+         )
+         set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
+         set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
+
+         self.out_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=use_bias,
+             input_is_parallel=True,
+             quant_config=quant_config,
+             prefix=f"{prefix}.out_proj",
+             reduce_results=False,
+         )
+
+         self.norm = Mixer2RMSNormGated(
+             intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
+         )
+
+         # The tuple is (conv_state, ssm_state)
+         self.kv_cache = (torch.tensor([]), torch.tensor([]))
+
+         self.model_config = model_config
+         self.prefix = prefix
+
+     def forward_native(
+         self,
+         hidden_states: torch.Tensor,
+         output: torch.Tensor,
+         mup_vector: Optional[torch.Tensor] = None,
+     ):
+         pass
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         output: torch.Tensor,
+         forward_batch: ForwardBatch,
+         mup_vector: Optional[torch.Tensor] = None,
+     ):
+         # attn_backend_list[-1] gives access to MambaAttnBackend
+         mamba_backend = forward_batch.attn_backend.attn_backend_list[-1]
+         attn_metadata = mamba_backend.forward_metadata
+         state_indices_tensor = attn_metadata.mamba_cache_indices
+         chunk_size = self.chunk_size
+
+         conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params(
+             self.layer_id
+         )
+
+         assert (
+             ssm_state.size(1) == self.ssm_state_size
+         ), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
+
+         query_start_loc = attn_metadata.query_start_loc
+
+         chunk_size = self.chunk_size
+
+         # TODO: properly support this
+         prep_initial_states = False
+
+         # 1. Gated MLP's linear projection
+         projected_states, _ = self.in_proj(hidden_states)
+
+         if mup_vector is not None:
+             projected_states = projected_states * mup_vector
+
+         gate, hidden_states_B_C, dt = torch.split(
+             projected_states,
+             [
+                 self.intermediate_size // self.tp_size,
+                 self.conv_dim // self.tp_size,
+                 self.num_heads // self.tp_size,
+             ],
+             dim=-1,
+         )
+         conv_weights = self.conv1d.weight.view(
+             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+         )
+
+         # - get hidden_states, B and C after depthwise convolution.
+         split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+             hidden_states_B_C,
+             [
+                 self.intermediate_size // self.tp_size,
+                 self.groups_ssm_state_size // self.tp_size,
+                 self.groups_ssm_state_size // self.tp_size,
+             ],
+             dim=-1,
+         )
+
+         preallocated_ssm_out = torch.empty(
+             [
+                 projected_states.shape[0],
+                 (self.num_heads * self.head_dim) // self.tp_size,
+             ],
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+
+         # Process prefill requests
+         if forward_batch.forward_mode.is_extend():
+             # 2. Convolution sequence transformation
+             # - "cache_indices" updates the conv_state cache in positions
+             #   pointed to by "state_indices_tensor"
+             num_prefill_tokens = forward_batch.extend_num_tokens or 0
+             has_initial_states = forward_batch.extend_prefix_lens > 0
+             cache_indices = attn_metadata.mamba_cache_indices
+
+             x = hidden_states_B_C.transpose(
+                 0, 1
+             )  # this is the form that causal-conv sees
+             hidden_states_B_C = causal_conv1d_fn(
+                 x,
+                 conv_weights,
+                 self.conv1d.bias,
+                 activation=self.activation,
+                 conv_states=conv_state,
+                 has_initial_state=has_initial_states,
+                 cache_indices=cache_indices,
+                 query_start_loc=query_start_loc,
+                 seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+             ).transpose(0, 1)
+
+             hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
+
+             # 3. State Space Model sequence transformation
+             initial_states = None
+
+             if has_initial_states is not None and prep_initial_states:
+                 initial_states = torch.where(
+                     has_initial_states[:, None, None, None],
+                     ssm_state[state_indices_tensor],
+                     0,
+                 )
+
+             # NOTE: final output is an in-place update of out tensor
+             varlen_state = mamba_chunk_scan_combined(
+                 hidden_states.view(
+                     1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
+                 ),
+                 dt.unsqueeze(0),
+                 self.A,
+                 B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
+                 C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
+                 chunk_size=chunk_size,
+                 D=self.D,
+                 z=None,
+                 dt_bias=self.dt_bias,
+                 cu_seqlens=query_start_loc,
+                 initial_states=initial_states,
+                 return_varlen_states=True,
+                 return_final_states=False,
+                 dt_softplus=True,
+                 dt_limit=(0.0, float("inf")),
+                 out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim),
+                 state_dtype=ssm_state.dtype,
+             )
+
+             # update ssm states
+             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
+             ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1)
+         elif forward_batch.forward_mode.is_decode():
+             num_decodes = len(query_start_loc) - 1
+             # 2. Convolution sequence transformation
+             hidden_states_B_C = causal_conv1d_update(
+                 hidden_states_B_C,
+                 conv_state,
+                 conv_weights,
+                 self.conv1d.bias,
+                 self.activation,
+                 conv_state_indices=state_indices_tensor,
+             )
+
+             hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
+
+             # 3. State Space Model sequence transformation
+             n_groups = self.n_groups // self.tp_size
+             A = (
+                 self.A[:, None, ...][:, :, None]
+                 .expand(-1, self.head_dim, self.ssm_state_size)
+                 .to(dtype=torch.float32)
+             )
+             dt = dt[:, :, None].expand(-1, -1, self.head_dim)
+             dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
+             D = self.D[:, None, ...].expand(-1, self.head_dim)
+             B = B.view(-1, n_groups, B.shape[1] // n_groups)
+             C = C.view(-1, n_groups, C.shape[1] // n_groups)
+             hidden_states = hidden_states.view(
+                 -1, self.num_heads // self.tp_size, self.head_dim
+             )
+
+             # - the hidden is reshaped into (bs, num_heads, head_dim)
+             # - mamba_cache_params.ssm_state's slots will be selected
+             #   using state_indices_tensor_d
+             # NOTE: final output is an in-place update of out tensor
+             selective_state_update(
+                 ssm_state.permute(0, 3, 2, 1),
+                 hidden_states,
+                 dt,
+                 A,
+                 B,
+                 C,
+                 D,
+                 z=None,
+                 dt_bias=dt_bias,
+                 dt_softplus=True,
+                 state_batch_indices=state_indices_tensor,
+                 out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim),
+             )
+         elif forward_batch.forward_mode.is_idle():
+             preallocated_ssm_out = preallocated_ssm_out
+
+         # 4. gated MLP
+         # GatedRMSNorm internally applying SiLU to the gate
+         # SiLU is applied internally before normalization, unlike standard
+         # norm usage
+         hidden_states = self.norm(preallocated_ssm_out, gate)
+
+         # 5. Final linear projection
+         output[:], _ = self.out_proj(hidden_states)
+
+     @property
+     def mamba_type(self) -> str:
+         return "mamba2"
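
The new MambaMixer2 layer above relies on a fixed split arithmetic: in_proj's merged output is sliced into gate, the convolution input, and dt, and the convolution output is then split back into x, B, and C, all divided evenly across tensor-parallel ranks. The standalone sketch below (not part of the package; the function name mamba2_split_sizes and the hyperparameter values are made up for illustration) reproduces that arithmetic for the common case where tp_size divides n_groups.

# Standalone sketch (not shipped in sglang): per-rank split sizes mirroring
# MambaMixer2.forward. All dimensions below are hypothetical example values.

def mamba2_split_sizes(intermediate_size, n_groups, ssm_state_size, num_heads, tp_size):
    """Return (outer, inner) per-rank split sizes.

    outer: how in_proj's merged output splits into [gate, conv input, dt].
    inner: how the conv output splits into [x, B, C].
    Assumes tp_size divides n_groups (the padded-groups path is handled
    separately in the layer itself).
    """
    groups_ssm_state_size = n_groups * ssm_state_size
    conv_dim = intermediate_size + 2 * groups_ssm_state_size
    outer = [
        intermediate_size // tp_size,  # gate
        conv_dim // tp_size,           # hidden_states_B_C fed to causal_conv1d
        num_heads // tp_size,          # dt (one value per head)
    ]
    inner = [
        intermediate_size // tp_size,      # x
        groups_ssm_state_size // tp_size,  # B
        groups_ssm_state_size // tp_size,  # C
    ]
    return outer, inner


if __name__ == "__main__":
    # Hypothetical dims: 64 heads of dim 64, 8 groups with state size 128, TP=4.
    outer, inner = mamba2_split_sizes(
        intermediate_size=64 * 64,
        n_groups=8,
        ssm_state_size=128,
        num_heads=64,
        tp_size=4,
    )
    print(outer)  # [1024, 1536, 16]
    print(inner)  # [1024, 256, 256]

Because both splits are plain integer divisions of the merged projection widths, every rank can compute its slice boundaries locally, without any communication.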