sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482) hide show
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,562 @@
1
+ # Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
2
+
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
5
+
6
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
7
+ # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
8
+
9
+ # ruff: noqa: E501,SIM102
10
+
11
+ import torch
12
+ import triton
13
+ import triton.language as tl
14
+ from packaging import version
15
+
16
+ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
17
+
18
+
19
+ @triton.jit
20
+ def _chunk_scan_fwd_kernel(
21
+ # Pointers to matrices
22
+ cb_ptr,
23
+ x_ptr,
24
+ z_ptr,
25
+ out_ptr,
26
+ out_x_ptr,
27
+ dt_ptr,
28
+ dA_cumsum_ptr,
29
+ seq_idx_ptr,
30
+ C_ptr,
31
+ states_ptr,
32
+ D_ptr,
33
+ initstates_ptr,
34
+ chunk_indices_ptr,
35
+ chunk_offsets_ptr,
36
+ chunk_meta_num,
37
+ # Matrix dimensions
38
+ chunk_size,
39
+ hdim,
40
+ dstate,
41
+ batch,
42
+ seqlen,
43
+ nheads_ngroups_ratio,
44
+ # Strides
45
+ stride_cb_batch,
46
+ stride_cb_chunk,
47
+ stride_cb_head,
48
+ stride_cb_csize_m,
49
+ stride_cb_csize_k,
50
+ stride_x_batch,
51
+ stride_x_seqlen,
52
+ stride_x_head,
53
+ stride_x_hdim,
54
+ stride_z_batch,
55
+ stride_z_seqlen,
56
+ stride_z_head,
57
+ stride_z_hdim,
58
+ stride_out_batch,
59
+ stride_out_seqlen,
60
+ stride_out_head,
61
+ stride_out_hdim,
62
+ stride_dt_batch,
63
+ stride_dt_chunk,
64
+ stride_dt_head,
65
+ stride_dt_csize,
66
+ stride_dA_cs_batch,
67
+ stride_dA_cs_chunk,
68
+ stride_dA_cs_head,
69
+ stride_dA_cs_csize,
70
+ stride_seq_idx_batch,
71
+ stride_seq_idx_seqlen,
72
+ stride_C_batch,
73
+ stride_C_seqlen,
74
+ stride_C_head,
75
+ stride_C_dstate,
76
+ stride_states_batch,
77
+ stride_states_chunk,
78
+ stride_states_head,
79
+ stride_states_hdim,
80
+ stride_states_dstate,
81
+ stride_init_states_batch,
82
+ stride_init_states_head,
83
+ stride_init_states_hdim,
84
+ stride_init_states_dstate,
85
+ stride_D_head,
86
+ # Meta-parameters
87
+ IS_CAUSAL: tl.constexpr,
88
+ HAS_D: tl.constexpr,
89
+ D_HAS_HDIM: tl.constexpr,
90
+ HAS_Z: tl.constexpr,
91
+ HAS_SEQ_IDX: tl.constexpr,
92
+ BLOCK_SIZE_DSTATE: tl.constexpr,
93
+ IS_TRITON_22: tl.constexpr,
94
+ HAS_INITSTATES: tl.constexpr,
95
+ BLOCK_SIZE_M: tl.constexpr = 16,
96
+ BLOCK_SIZE_N: tl.constexpr = 16,
97
+ BLOCK_SIZE_K: tl.constexpr = 16,
98
+ ):
99
+ pid_bc = tl.program_id(axis=1).to(tl.int64)
100
+ pid_c = pid_bc // batch
101
+ pid_b = pid_bc - pid_c * batch
102
+ if not HAS_INITSTATES:
103
+ c_idx = pid_c
104
+ c_off = 0
105
+ else:
106
+ c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
107
+ c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)
108
+
109
+ pid_h = tl.program_id(axis=2)
110
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
111
+ pid_m = tl.program_id(axis=0) // num_pid_n
112
+ pid_n = tl.program_id(axis=0) % num_pid_n
113
+ cb_ptr += (
114
+ pid_b * stride_cb_batch
115
+ + c_idx * stride_cb_chunk
116
+ + (pid_h // nheads_ngroups_ratio) * stride_cb_head
117
+ )
118
+ x_ptr += (
119
+ pid_b * stride_x_batch
120
+ + c_idx * chunk_size * stride_x_seqlen
121
+ + pid_h * stride_x_head
122
+ )
123
+ dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head
124
+ dA_cumsum_ptr += (
125
+ pid_b * stride_dA_cs_batch
126
+ + c_idx * stride_dA_cs_chunk
127
+ + pid_h * stride_dA_cs_head
128
+ )
129
+ C_ptr += (
130
+ pid_b * stride_C_batch
131
+ + c_idx * chunk_size * stride_C_seqlen
132
+ + (pid_h // nheads_ngroups_ratio) * stride_C_head
133
+ )
134
+
135
+ # M-block offsets and prev states
136
+ # - logic in next block may override these if there is an active offset
137
+ offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
138
+ prev_states_ptr = (
139
+ states_ptr
140
+ + pid_b * stride_states_batch
141
+ + c_idx * stride_states_chunk
142
+ + pid_h * stride_states_head
143
+ )
144
+ prev_states_hdim = stride_states_hdim
145
+ prev_states_dstate = stride_states_dstate
146
+
147
+ chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
148
+ if HAS_SEQ_IDX:
149
+ seq_idx_ptr += (
150
+ pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen
151
+ )
152
+
153
+ # - we only need seq_idx_prev to be aligned to chunk boundary
154
+ seq_idx_prev = tl.load(
155
+ seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0
156
+ )
157
+
158
+ if HAS_INITSTATES:
159
+ # if there are init states, we only need seq_idx_m to point
160
+ # what is the current seq_idx
161
+
162
+ # get current seq idx
163
+ if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
164
+ seq_idx_m = tl.load(
165
+ seq_idx_ptr
166
+ + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen,
167
+ )
168
+
169
+ # - recall that in ssd_state_passing, for the case c_off == 0
170
+ # i.e., the very first sequence, we made states_ptr hold its initial state
171
+ # so this edge case is taken care of
172
+ if (
173
+ (c_off == 0)
174
+ and (
175
+ seq_idx_prev != seq_idx_m
176
+ ) # if a seq is changed exactly on boundary
177
+ or (c_off > 0) # implies a new example (pseudo chunk)
178
+ ):
179
+
180
+ # - replace prev_states_ptr with init_states
181
+ prev_states_ptr = (
182
+ initstates_ptr
183
+ + seq_idx_m * stride_init_states_batch
184
+ + pid_h * stride_init_states_head
185
+ )
186
+ prev_states_hdim = stride_init_states_hdim # override strides
187
+ prev_states_dstate = stride_init_states_dstate
188
+
189
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
190
+ dA_cs_m = tl.load(
191
+ dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
192
+ ).to(tl.float32)
193
+
194
+ # - handle chunk state limit
195
+ if HAS_INITSTATES:
196
+
197
+ # have to split this if otherwise compilation will have problems
198
+ dA_cs_m_boundary = 0.0
199
+
200
+ # get the c_idx for the next (logical) chunk
201
+ c_idx_n = tl.load(
202
+ chunk_indices_ptr + (pid_c + 1),
203
+ mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
204
+ other=-1, # to trigger different chunk
205
+ )
206
+
207
+ # - there are things to consider
208
+ # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
209
+ # contribution of past states
210
+ # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
211
+ # encroach into the next sequence, where c_off_n is the offset of the next
212
+ # (logical) chunk.
213
+ # An equivalent check for B is c_idx == c_idx_n, where there is repetition in
214
+ # (logical) chunk indices.
215
+
216
+ if (c_idx == c_idx_n) or c_off > 0:
217
+
218
+ # get the next offset
219
+ c_off_n = tl.load(
220
+ chunk_offsets_ptr + (pid_c + 1),
221
+ mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
222
+ other=chunk_size,
223
+ )
224
+
225
+ # in this case, adjust down the chunk_size_limit
226
+ if c_idx == c_idx_n:
227
+ chunk_size_limit = min(c_off_n, chunk_size_limit)
228
+
229
+ # get the cs at the offset boundary
230
+ # - c_off == 0 is a passthrough
231
+ # - We need dA_cs at the boundary, defined by c_off - no need
232
+ # to increase pointer by pid_m (it is a constant offset,
233
+ # i.e. the same for all blocks)
234
+ dA_cs_m_boundary = tl.load(
235
+ dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
236
+ mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
237
+ other=0.0,
238
+ ).to(tl.float32)
239
+
240
+ if HAS_SEQ_IDX:
241
+ # - handle seq idx when HAS_INITSTATES==False
242
+ if not HAS_INITSTATES:
243
+ seq_idx_m = tl.load(
244
+ seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
245
+ mask=offs_m < chunk_size_limit,
246
+ other=-1,
247
+ )
248
+
249
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
250
+
251
+ # Without the if (pid_c > -1), with Triton 2.1.0, I get
252
+ # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mma layout conversion"' failed.
253
+ # With Triton 2.2.0, this works
254
+ if IS_TRITON_22 or c_idx > -1:
255
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
256
+ offs_k_dstate = tl.arange(
257
+ 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
258
+ )
259
+ C_ptrs = C_ptr + (
260
+ offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
261
+ )
262
+
263
+ prev_states_ptrs = prev_states_ptr + (
264
+ offs_n[None, :] * prev_states_hdim
265
+ + offs_k_dstate[:, None] * prev_states_dstate
266
+ )
267
+ if HAS_SEQ_IDX:
268
+
269
+ if not HAS_INITSTATES:
270
+ # - this is for continuous batching where there is no init states
271
+ scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
272
+ else:
273
+ # - if there is initstates, we will rely on prev_states, no zeroing
274
+ # required.
275
+ scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
276
+ else:
277
+ scale_m = tl.exp(dA_cs_m)
278
+ if BLOCK_SIZE_DSTATE <= 128:
279
+ C = tl.load(
280
+ C_ptrs,
281
+ mask=(offs_m[:, None] < chunk_size_limit)
282
+ & (offs_k_dstate[None, :] < dstate),
283
+ other=0.0,
284
+ )
285
+
286
+ prev_states = tl.load(
287
+ prev_states_ptrs,
288
+ mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
289
+ other=0.0,
290
+ )
291
+ prev_states = prev_states.to(C_ptr.dtype.element_ty)
292
+ acc = tl.dot(C, prev_states) * scale_m[:, None]
293
+ else:
294
+ for k in range(0, dstate, BLOCK_SIZE_K):
295
+ C = tl.load(
296
+ C_ptrs,
297
+ mask=(offs_m[:, None] < chunk_size_limit)
298
+ & (offs_k_dstate[None, :] < dstate - k),
299
+ other=0.0,
300
+ )
301
+ # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
302
+ prev_states = tl.load(
303
+ prev_states_ptrs,
304
+ mask=(offs_k_dstate[:, None] < dstate - k)
305
+ & (offs_n[None, :] < hdim),
306
+ other=0.0,
307
+ )
308
+ prev_states = prev_states.to(C_ptr.dtype.element_ty)
309
+ acc += tl.dot(C, prev_states)
310
+ C_ptrs += BLOCK_SIZE_K
311
+ prev_states_ptrs += BLOCK_SIZE_K
312
+ acc *= scale_m[:, None]
313
+
314
+ offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
315
+ cb_ptrs = cb_ptr + (
316
+ offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
317
+ )
318
+ x_ptrs = x_ptr + (
319
+ offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
320
+ )
321
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
322
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
323
+ K_MAX = (
324
+ chunk_size_limit
325
+ if not IS_CAUSAL
326
+ else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
327
+ )
328
+ for k in range(0, K_MAX, BLOCK_SIZE_K):
329
+ cb = tl.load(
330
+ cb_ptrs,
331
+ mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
332
+ other=0.0,
333
+ ).to(tl.float32)
334
+ dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
335
+ tl.float32
336
+ )
337
+ # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
338
+ # So we don't need masking wrt seq_idx here.
339
+ cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
340
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
341
+ cb *= dt_k
342
+ if IS_CAUSAL:
343
+ mask = offs_m[:, None] >= k + offs_k[None, :]
344
+ cb = tl.where(mask, cb, 0.0)
345
+ cb = cb.to(x_ptr.dtype.element_ty)
346
+ x = tl.load(
347
+ x_ptrs,
348
+ mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
349
+ other=0.0,
350
+ )
351
+ acc += tl.dot(cb, x)
352
+ cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
353
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
354
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
355
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
356
+
357
+ offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
358
+ offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
359
+
360
+ if HAS_D:
361
+ if D_HAS_HDIM:
362
+ D = tl.load(
363
+ D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
364
+ ).to(tl.float32)
365
+ else:
366
+ D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
367
+ x_residual = tl.load(
368
+ x_ptr
369
+ + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
370
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
371
+ other=0.0,
372
+ ).to(tl.float32)
373
+ acc += x_residual * D
374
+
375
+ if HAS_Z:
376
+ out_x_ptr += (
377
+ pid_b * stride_out_batch
378
+ + c_idx * chunk_size * stride_out_seqlen
379
+ + pid_h * stride_out_head
380
+ )
381
+ out_x_ptrs = out_x_ptr + (
382
+ stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]
383
+ )
384
+ tl.store(
385
+ out_x_ptrs,
386
+ acc,
387
+ mask=(offs_out_m[:, None] < chunk_size_limit)
388
+ & (offs_out_n[None, :] < hdim),
389
+ )
390
+
391
+ z_ptr += (
392
+ pid_b * stride_z_batch
393
+ + c_idx * chunk_size * stride_z_seqlen
394
+ + pid_h * stride_z_head
395
+ )
396
+ z_ptrs = z_ptr + (
397
+ stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
398
+ )
399
+ z = tl.load(
400
+ z_ptrs,
401
+ mask=(offs_out_m[:, None] < chunk_size_limit)
402
+ & (offs_out_n[None, :] < hdim),
403
+ other=0.0,
404
+ ).to(tl.float32)
405
+ acc *= z * tl.sigmoid(z)
406
+
407
+ out_ptr += (
408
+ pid_b * stride_out_batch
409
+ + c_idx * chunk_size * stride_out_seqlen
410
+ + pid_h * stride_out_head
411
+ )
412
+ out_ptrs = out_ptr + (
413
+ stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
414
+ )
415
+ tl.store(
416
+ out_ptrs,
417
+ acc,
418
+ mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
419
+ )
420
+
421
+
422
def _chunk_scan_fwd(
    cb,
    x,
    dt,
    dA_cumsum,
    C,
    states,
    D=None,
    z=None,
    seq_idx=None,
    chunk_indices=None,
    chunk_offsets=None,
    initial_states=None,
    out=None,
):
    """Validate the inputs and launch ``_chunk_scan_fwd_kernel``.

    The result is written in place into ``out``; this function does not
    allocate the main output itself.

    Args:
        cb: tensor of shape (batch, nchunks, ngroups, chunk_size, chunk_size).
        x: tensor of shape (batch, seqlen, nheads, headdim).
        dt: tensor of shape (batch, nheads, nchunks, chunk_size).
        dA_cumsum: tensor with the same shape as ``dt``.
        C: tensor of shape (batch, seqlen, ngroups, dstate).
        states: tensor of shape (batch, nchunks, nheads, headdim, dstate).
        D: optional tensor of shape (nheads, headdim) or (nheads,).
        z: optional tensor with the same shape as ``x``. When given, the
            kernel multiplies the result by ``z * sigmoid(z)`` before
            storing it to ``out``.
        seq_idx: optional tensor of shape (batch, seqlen).
        chunk_indices: chunk metadata; only honoured when both ``seq_idx``
            and ``initial_states`` are given, otherwise reset to ``None``.
        chunk_offsets: same rules as ``chunk_indices``.
        initial_states: optional tensor whose first four strides are passed
            to the kernel. Together with ``seq_idx`` it requires
            ``batch == 1``.
        out: preallocated output tensor with the same shape as ``x``.
            Required, despite the ``None`` default.

    Returns:
        ``None`` when ``z`` is ``None``; otherwise a new tensor shaped like
        ``x`` that receives the kernel result before the ``z`` gating.

    Raises:
        ValueError: if ``out`` is ``None``.
        AssertionError: if any shape or metadata check fails.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert states.shape == (batch, nchunks, nheads, headdim, dstate)

    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)

    if seq_idx is not None and initial_states is not None:
        # with initial states, we need to take care of how
        # seq_idx crosses the boundaries
        assert batch == 1, "chunk scan only supports initial states with batch 1"
        assert (
            chunk_indices is not None and chunk_offsets is not None
        ), "chunk_indices and chunk_offsets should have been set"
    else:
        # Chunk metadata is only meaningful when both seq_idx and
        # initial_states are present; drop whatever the caller passed.
        # NOTE(review): if initial_states is given without seq_idx, the
        # kernel is still launched with HAS_INITSTATES=True but without
        # chunk metadata - confirm that callers never take this path.
        chunk_indices, chunk_offsets = None, None

    # `out` defaults to None only for signature compatibility; it is read
    # unconditionally below, so fail with a clear message instead of an
    # AttributeError on `None.shape`.
    if out is None:
        raise ValueError(
            "out must be a preallocated tensor with the same shape as x"
        )
    assert out.shape == x.shape

    if z is not None:
        out_x = torch.empty_like(x)
        # The kernel addresses out_x with the strides of `out`.
        assert out_x.stride() == out.stride()
    else:
        out_x = None

    def grid(META):
        # axis 0: (chunk_size blocks) x (headdim blocks)
        # axis 1: one program per (batch, chunk), or per chunk_offsets entry
        # axis 2: heads
        return (
            triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
            * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
            batch * nchunks if chunk_offsets is None else len(chunk_offsets),
            nheads,
        )

    z_strides = (
        (z.stride(0), z.stride(1), z.stride(2), z.stride(3))
        if z is not None
        else (0, 0, 0, 0)
    )
    _chunk_scan_fwd_kernel[grid](
        # tensors
        cb,
        x,
        z,
        out,
        out_x,
        dt,
        dA_cumsum,
        seq_idx,
        C,
        states,
        D,
        initial_states,
        chunk_indices,
        chunk_offsets,
        len(chunk_indices) if chunk_indices is not None else 0,
        # sizes
        chunk_size,
        headdim,
        dstate,
        batch,
        seqlen,
        nheads // ngroups,
        # strides
        cb.stride(0),
        cb.stride(1),
        cb.stride(2),
        cb.stride(3),
        cb.stride(4),
        x.stride(0),
        x.stride(1),
        x.stride(2),
        x.stride(3),
        *z_strides,
        out.stride(0),
        out.stride(1),
        out.stride(2),
        out.stride(3),
        # dt and dA_cumsum are (batch, nheads, nchunks, chunk_size); their
        # strides are deliberately passed with dims 1 and 2 swapped.
        dt.stride(0),
        dt.stride(2),
        dt.stride(1),
        dt.stride(3),
        dA_cumsum.stride(0),
        dA_cumsum.stride(2),
        dA_cumsum.stride(1),
        dA_cumsum.stride(3),
        *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
        C.stride(0),
        C.stride(1),
        C.stride(2),
        C.stride(3),
        states.stride(0),
        states.stride(1),
        states.stride(2),
        states.stride(3),
        states.stride(4),
        *(
            (
                initial_states.stride(0),
                initial_states.stride(1),
                initial_states.stride(2),
                initial_states.stride(3),
            )
            if initial_states is not None
            else (0, 0, 0, 0)
        ),
        D.stride(0) if D is not None else 0,
        # meta-parameters: IS_CAUSAL, HAS_D, D_HAS_HDIM
        True,
        D is not None,
        D.dim() == 2 if D is not None else True,
        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        HAS_Z=z is not None,
        HAS_SEQ_IDX=seq_idx is not None,
        IS_TRITON_22=TRITON_22,
        HAS_INITSTATES=initial_states is not None,
    )
    return out_x