sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/ops/ssd_combined.py
@@ -0,0 +1,261 @@
+# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
+
+# ruff: noqa: E501
+
+import torch
+import triton
+from einops import rearrange
+from packaging import version
+
+from .ssd_bmm import _bmm_chunk_fwd
+from .ssd_chunk_scan import _chunk_scan_fwd
+from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen
+from .ssd_state_passing import _state_passing_fwd
+
+TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
+
+
+def is_int_pow_2(n):
+    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
+
+
+def _mamba_chunk_scan_combined_fwd(
+    x,
+    dt,
+    A,
+    B,
+    C,
+    chunk_size,
+    D=None,
+    z=None,
+    dt_bias=None,
+    initial_states=None,
+    seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
+    cu_seqlens=None,
+    dt_softplus=False,
+    dt_limit=(0.0, float("inf")),
+    state_dtype=None,
+    out=None,
+):
+    assert is_int_pow_2(chunk_size), "chunk_size must be an integer power of 2"
+    batch, seqlen, nheads, headdim = x.shape
+    _, _, ngroups, dstate = B.shape
+    assert nheads % ngroups == 0
+    assert B.shape == (batch, seqlen, ngroups, dstate)
+    assert dt.shape == (batch, seqlen, nheads)
+    assert A.shape == (nheads,)
+    assert C.shape == B.shape
+    if z is not None:
+        assert z.shape == x.shape
+    if D is not None:
+        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
+    if seq_idx is not None:
+        assert seq_idx.shape == (batch, seqlen)
+    if B.stride(-1) != 1:
+        B = B.contiguous()
+    if C.stride(-1) != 1:
+        C = C.contiguous()
+    if (
+        x.stride(-1) != 1 and x.stride(1) != 1
+    ):  # Either M or K dimension should be contiguous
+        x = x.contiguous()
+    if (
+        z is not None and z.stride(-1) != 1 and z.stride(1) != 1
+    ):  # Either M or K dimension should be contiguous
+        z = z.contiguous()
+    if D is not None and D.stride(-1) != 1:
+        D = D.contiguous()
+    if initial_states is not None:
+        if cu_seqlens is None:
+            assert initial_states.shape == (batch, nheads, headdim, dstate)
+        else:
+            assert initial_states.shape == (
+                len(cu_seqlens) - 1,
+                nheads,
+                headdim,
+                dstate,
+            )
+
+    # This function executes 5 sub-functions for computing mamba
+    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
+    #   which has a minimal implementation to understand the below operations
+    # - as explained by the blog, mamba is a special case of causal attention
+    # - the idea is to chunk the attention matrix and compute each
+    #   submatrix separately using different optimizations.
+    # - see the blog and paper for a visualization of the submatrices
+    #   which we refer to in the comments below
+
+    # 1. Compute chunked cumsum of A * dt
+    # - here dt may go through a softplus activation
+    dA_cumsum, dt = _chunk_cumsum_fwd(
+        dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
+    )
+
+    # 2. Compute the state for each intra-chunk
+    #    (right term of low-rank factorization of off-diagonal blocks; B terms)
+    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
+
+    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+    #    (middle term of factorization of off-diag blocks; A terms)
+    # - for handling chunked prefill, this requires (i) initial_states,
+    #   (ii) seq_idx, (iii) is_cont_batched, and (iv) chunk_offsets to all be specified.
+    # - When a new seq_idx is detected, we will stop passing the prev_state
+    #   and switch accordingly to the init_state corresponding to the new seq_idx.
+    # - We will also make sure that the dA_cumsum is taken only from the start of the
+    #   sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
+    # - this will ensure that states will be updated with the rightmost flushed seq_idx
+    #   of the previous chunk. This implies that the first chunk of states is either 0
+    #   or equal to init_states of the first example.
+    states, final_states = _state_passing_fwd(
+        rearrange(states, "... p n -> ... (p n)"),
+        dA_cumsum,
+        initial_states=(
+            rearrange(initial_states, "... p n -> ... (p n)")
+            if initial_states is not None
+            else None
+        ),
+        seq_idx=seq_idx,
+        chunk_size=chunk_size,
+        out_dtype=state_dtype if state_dtype is not None else C.dtype,
+        is_cont_batched=cu_seqlens is not None,
+        chunk_offsets=chunk_offsets,
+    )
+    states, final_states = (
+        rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
+    )
+
+    # 4. Compute batched matrix multiply for C_j^T B_i terms
+    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
+
+    # 5. Scan and compute the diagonal blocks, taking into
+    #    account past causal states.
+    # - if initial states are provided, then states information will be
+    #   augmented with initial_states.
+    # - to do this properly, we need to account for example changes in
+    #   the continuous batch, therefore we introduce pseudo chunks, which is
+    #   a chunk that is split up each time an example changes.
+    # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
+    #   a seq_idx change, in which case we take states information from
+    #   init_states.
+    out_x = _chunk_scan_fwd(
+        CB,
+        x,
+        dt,
+        dA_cumsum,
+        C,
+        states,
+        D=D,
+        z=z,
+        seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+        initial_states=initial_states,
+        out=out,
+    )
+    if cu_seqlens is None:
+        return out_x, dt, dA_cumsum, states, final_states
+    else:
+        assert (
+            batch == 1
+        ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
+        varlen_states = chunk_state_varlen(
+            B.squeeze(0),
+            x.squeeze(0),
+            dt.squeeze(0),
+            dA_cumsum.squeeze(0),
+            cu_seqlens,
+            states.squeeze(0),
+            initial_states=initial_states,
+        )
+        return out_x, dt, dA_cumsum, states, final_states, varlen_states
+
+
+def mamba_chunk_scan_combined(
+    x,
+    dt,
+    A,
+    B,
+    C,
+    chunk_size,
+    D=None,
+    z=None,
+    dt_bias=None,
+    initial_states=None,
+    seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
+    cu_seqlens=None,
+    dt_softplus=False,
+    dt_limit=(0.0, float("inf")),
+    out=None,
+    return_final_states=False,
+    return_varlen_states=False,
+    state_dtype=None,
+):
+    """
+    Arguments:
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, seqlen, nheads)
+        A: (nheads,)
+        B: (batch, seqlen, ngroups, dstate)
+        C: (batch, seqlen, ngroups, dstate)
+        chunk_size: int
+        D: (nheads, headdim) or (nheads,)
+        z: (batch, seqlen, nheads, headdim)
+        dt_bias: (nheads,)
+        initial_states: (batch, nheads, headdim, dstate)
+        seq_idx: (batch, seqlen)
+        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
+        dt_softplus: Whether to apply softplus to dt
+        out: Preallocated output tensor
+        state_dtype: The data type of the ssm state
+    """
+
+    if not return_varlen_states:
+        cu_seqlens = None
+    else:
+        assert (
+            cu_seqlens is not None
+        ), "cu_seqlens must be provided if return_varlen_states is True"
+    out_x, dt_out, dA_cumsum, states, final_states, *rest = (
+        _mamba_chunk_scan_combined_fwd(
+            x,
+            dt,
+            A,
+            B,
+            C,
+            chunk_size,
+            D=D,
+            z=z,
+            dt_bias=dt_bias,
+            initial_states=initial_states,
+            seq_idx=seq_idx,
+            chunk_indices=chunk_indices,
+            chunk_offsets=chunk_offsets,
+            cu_seqlens=cu_seqlens,
+            dt_softplus=dt_softplus,
+            dt_limit=dt_limit,
+            out=out,
+            state_dtype=state_dtype,
+        )
+    )
+    if not return_varlen_states:
+        if not return_final_states:
+            return
+        else:
+            return final_states
+    else:
+        varlen_states = rest[0]
+        return (
+            varlen_states
+            if not return_final_states
+            else (final_states, varlen_states)
+        )
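
Below is a minimal usage sketch for the new mamba_chunk_scan_combined entry point added above. Shapes follow the docstring; the import path assumes the new sglang/srt/layers/attention/mamba/ops package from the file list, and all tensor values are random placeholders, so treat this as an illustration rather than part of the diff:

    # Sketch only (not part of the diff): call the combined chunk scan once.
    import torch

    from sglang.srt.layers.attention.mamba.ops.ssd_combined import (
        mamba_chunk_scan_combined,
    )

    batch, seqlen, nheads, headdim = 1, 512, 8, 64
    ngroups, dstate, chunk_size = 1, 128, 256  # chunk_size: integer power of 2

    x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    dt = torch.rand(batch, seqlen, nheads, device="cuda", dtype=torch.float32)
    A = -torch.rand(nheads, device="cuda", dtype=torch.float32)  # negative decay rates
    B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
    C = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
    out = torch.empty_like(x)  # the kernels write the scan output into this buffer

    final_states = mamba_chunk_scan_combined(
        x, dt, A, B, C, chunk_size,
        dt_softplus=True,
        out=out,
        return_final_states=True,
    )
    # out: (batch, seqlen, nheads, headdim); final_states: (batch, nheads, headdim, dstate)
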
sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py
@@ -0,0 +1,264 @@
+# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
+
+# ruff: noqa: E501
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _state_passing_fwd_kernel(
+    # Pointers to matrices
+    states_ptr,
+    out_ptr,
+    final_states_ptr,
+    dA_cs_ptr,
+    initstates_ptr,
+    seq_idx_ptr,
+    chunk_offsets_ptr,
+    chunk_meta_num,
+    # Matrix dimensions
+    dim,
+    nchunks,
+    seqlen,
+    chunk_size,
+    # Strides
+    stride_states_batch,
+    stride_states_chunk,
+    stride_states_head,
+    stride_states_dim,
+    stride_out_batch,
+    stride_out_chunk,
+    stride_out_head,
+    stride_out_dim,
+    stride_final_states_batch,
+    stride_final_states_head,
+    stride_final_states_dim,
+    stride_dA_cs_batch,
+    stride_dA_cs_chunk,
+    stride_dA_cs_head,
+    stride_dA_cs_csize,
+    stride_initstates_batch,
+    stride_initstates_head,
+    stride_initstates_dim,
+    stride_seq_idx_batch,
+    stride_seq_idx_seqlen,
+    # Meta-parameters
+    HAS_INITSTATES: tl.constexpr,
+    HAS_SEQ_IDX: tl.constexpr,
+    IS_CONT_BATCHED: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr = 16,
+):
+    pid_b = tl.program_id(axis=1)
+    pid_h = tl.program_id(axis=2)
+    pid_m = tl.program_id(axis=0)
+    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
+    dA_cs_ptr += (
+        pid_b * stride_dA_cs_batch
+        + pid_h * stride_dA_cs_head
+        + (chunk_size - 1) * stride_dA_cs_csize
+    )
+    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+    final_states_ptr += (
+        pid_b * stride_final_states_batch + pid_h * stride_final_states_head
+    )
+    if HAS_INITSTATES:
+        initstates_ptr += pid_h * stride_initstates_head
+        if not IS_CONT_BATCHED:
+            initstates_ptr += pid_b * stride_initstates_batch
+
+    if HAS_SEQ_IDX:
+        seq_idx_ptr += pid_b * stride_seq_idx_batch
+
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    states_ptrs = states_ptr + offs_m * stride_states_dim
+    out_ptrs = out_ptr + offs_m * stride_out_dim
+    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
+
+    # - states will be the past state of the sequence that continues on the current chunk
+    if not HAS_INITSTATES:
+        states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    else:
+        initstates_ptr += offs_m * stride_initstates_dim
+        initstates_ptrs = initstates_ptr
+        # - for cont batches, in the first chunk this will be the first
+        #   sequence's init state
+        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+
+    tl.store(out_ptrs, states, mask=offs_m < dim)
+    out_ptrs += stride_out_chunk
+    prev_seq_idx_chunk_end = 0
+    logical_chunk_idx = 0
+    for c in range(nchunks):
+        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
+        scale_mask = True
+        if HAS_SEQ_IDX:
+            # - the seq to pass forward is the one that is flushed to the right
+            #   boundary.
+            # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
+            seq_idx_chunk_end = tl.load(
+                seq_idx_ptr
+                + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen
+            )
+            if HAS_INITSTATES:
+                if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
+                    # this means in the current chunk the rightmost flushed seq
+                    # has changed.
+                    # - so we do not propagate the state from previous chunk
+                    # - but rather we load that sequence's init state
+                    initstates_ptrs = (
+                        initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
+                    )
+
+                    # - update state with seq_idx_new's init state
+                    states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
+                        tl.float32
+                    )
+
+                    # - we need to consider the cumsum only of the last sequence in the chunk
+                    # - find its starting position (given by c_off of the logical chunk index)
+                    # - and subtract the cumsum just before that position from the total cumsum
+                    # - first, update the logical chunk index (add the number of sequences in the current physical chunk):
+                    # sequence index at the start of the current chunk
+                    seq_idx_chunk_start = tl.load(
+                        seq_idx_ptr
+                        + min(c * chunk_size, seqlen) * stride_seq_idx_seqlen
+                    )
+                    logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
+                    # - load the chunk offset:
+                    c_off = tl.load(
+                        chunk_offsets_ptr + logical_chunk_idx,
+                        mask=logical_chunk_idx < chunk_meta_num,
+                        other=0,
+                    )
+                    # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
+                    if c_off > 0:
+                        # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
+                        dA_cs_boundary = tl.load(
+                            dA_cs_ptr
+                            - (chunk_size - 1) * stride_dA_cs_csize
+                            + (c_off - 1) * stride_dA_cs_csize,
+                            mask=(c_off - 1) > -1 and c_off < chunk_size,
+                            other=0.0,
+                        )
+                        dA_cs -= dA_cs_boundary
+
+                # - increment logical chunk index for every physical chunk
+                logical_chunk_idx += 1
+            else:
+                scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
+            prev_seq_idx_chunk_end = seq_idx_chunk_end
+
+        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
+        states = scale * states + new_states
+        if c < nchunks - 1:
+            tl.store(out_ptrs, states, mask=offs_m < dim)
+        else:
+            tl.store(final_states_ptrs, states, mask=offs_m < dim)
+        states_ptrs += stride_states_chunk
+        dA_cs_ptr += stride_dA_cs_chunk
+        out_ptrs += stride_out_chunk
+
+
+def _state_passing_fwd(
+    states,
+    dA_cumsum,
+    initial_states=None,
+    seq_idx=None,
+    chunk_size=None,
+    out_dtype=None,
+    is_cont_batched=False,
+    chunk_offsets=None,
+):
+    batch, nchunks, nheads, dim = states.shape
+    if chunk_size is None:
+        chunk_size = dA_cumsum.shape[-1]
+    else:
+        assert chunk_size == dA_cumsum.shape[-1]
+    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
+    if initial_states is not None:
+        if is_cont_batched:
+            # - if cu_seqlens is provided, then the initial states
+            #   are used for continuous batching. In which case we
+            #   require seq_idx to be provided
+            assert (
+                seq_idx is not None
+            ), "seq_idx must be provided for continuous batching"
+            # - we also need chunk_offsets to be provided, to account
+            #   for computation of dA_cumsum from the start of the
+            #   sequence
+            assert (
+                chunk_offsets is not None
+            ), "chunk_offsets must be provided for continuous batching"
+        else:
+            # - this is the regular batching case, where initial
+            #   states are used for each example of the batch.
+            assert initial_states.shape == (batch, nheads, dim)
+
+    if seq_idx is not None:
+        seqlen = seq_idx.shape[-1]
+        assert seq_idx.shape == (batch, seqlen)
+    out_dtype = states.dtype if out_dtype is None else out_dtype
+    out = torch.empty(
+        (batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype
+    )
+    final_states = torch.empty(
+        (batch, nheads, dim), device=states.device, dtype=torch.float32
+    )
+    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
+    with torch.cuda.device(states.device.index):
+        _state_passing_fwd_kernel[grid](
+            states,
+            out,
+            final_states,
+            dA_cumsum,
+            initial_states,
+            seq_idx,
+            chunk_offsets,
+            len(chunk_offsets) if chunk_offsets is not None else 0,
+            dim,
+            nchunks,
+            seqlen if seq_idx is not None else 0,
+            chunk_size,
+            states.stride(0),
+            states.stride(1),
+            states.stride(2),
+            states.stride(3),
+            out.stride(0),
+            out.stride(1),
+            out.stride(2),
+            out.stride(3),
+            final_states.stride(0),
+            final_states.stride(1),
+            final_states.stride(2),
+            dA_cumsum.stride(0),
+            dA_cumsum.stride(2),
+            dA_cumsum.stride(1),
+            dA_cumsum.stride(3),
+            *(
+                (
+                    initial_states.stride(0),
+                    initial_states.stride(1),
+                    initial_states.stride(2),
+                )
+                if initial_states is not None
+                else (0, 0, 0)
+            ),
+            *(
+                (seq_idx.stride(0), seq_idx.stride(1))
+                if seq_idx is not None
+                else (0, 0)
+            ),
+            HAS_INITSTATES=initial_states is not None,
+            HAS_SEQ_IDX=seq_idx is not None,
+            IS_CONT_BATCHED=is_cont_batched,
+        )
+    return out, final_states
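
Per head and per state element, the kernel above realizes the chunk-level recurrence state[c + 1] = exp(dA_cs[c]) * state[c] + new_states[c], where dA_cs[c] is the cumulative sum at the end of chunk c, and where the scale is zeroed (or swapped for an init state) at sequence boundaries. For the simple case (no initial_states, no seq_idx), a plain-PyTorch reference can serve as a sanity check against the Triton path; the sketch below and its function name are ours, not part of the diff:

    # Sketch only (not part of the diff): reference semantics of
    # _state_passing_fwd without init states or sequence boundaries.
    import torch

    def state_passing_ref(states, dA_cumsum):
        # states: (batch, nchunks, nheads, dim) per-chunk contributions
        # dA_cumsum: (batch, nheads, nchunks, chunk_size); only the value at the
        # last position of each chunk scales the carried state.
        batch, nchunks, nheads, dim = states.shape
        dA_last = dA_cumsum[..., -1].transpose(1, 2)  # (batch, nchunks, nheads)
        out = torch.zeros(
            batch, nchunks, nheads, dim, device=states.device, dtype=torch.float32
        )
        carry = torch.zeros(
            batch, nheads, dim, device=states.device, dtype=torch.float32
        )
        for c in range(nchunks):
            out[:, c] = carry  # state entering chunk c (zeros for chunk 0)
            carry = (
                torch.exp(dA_last[:, c]).unsqueeze(-1) * carry + states[:, c].float()
            )
        return out, carry  # carry plays the role of final_states: (batch, nheads, dim)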