sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/mamba/ops/ssd_combined.py
@@ -0,0 +1,262 @@
+# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
+
+# ruff: noqa: E501
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from packaging import version
+
+from .ssd_bmm import _bmm_chunk_fwd
+from .ssd_chunk_scan import _chunk_scan_fwd
+from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen
+from .ssd_state_passing import _state_passing_fwd
+
+TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
+
+
+def is_int_pow_2(n):
+    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
+
+
+def _mamba_chunk_scan_combined_fwd(
+    x,
+    dt,
+    A,
+    B,
+    C,
+    chunk_size,
+    D=None,
+    z=None,
+    dt_bias=None,
+    initial_states=None,
+    seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
+    cu_seqlens=None,
+    dt_softplus=False,
+    dt_limit=(0.0, float("inf")),
+    state_dtype=None,
+    out=None,
+):
+    assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
+    batch, seqlen, nheads, headdim = x.shape
+    _, _, ngroups, dstate = B.shape
+    assert nheads % ngroups == 0
+    assert B.shape == (batch, seqlen, ngroups, dstate)
+    assert dt.shape == (batch, seqlen, nheads)
+    assert A.shape == (nheads,)
+    assert C.shape == B.shape
+    if z is not None:
+        assert z.shape == x.shape
+    if D is not None:
+        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
+    if seq_idx is not None:
+        assert seq_idx.shape == (batch, seqlen)
+    if B.stride(-1) != 1:
+        B = B.contiguous()
+    if C.stride(-1) != 1:
+        C = C.contiguous()
+    if (
+        x.stride(-1) != 1 and x.stride(1) != 1
+    ):  # Either M or K dimension should be contiguous
+        x = x.contiguous()
+    if (
+        z is not None and z.stride(-1) != 1 and z.stride(1) != 1
+    ):  # Either M or K dimension should be contiguous
+        z = z.contiguous()
+    if D is not None and D.stride(-1) != 1:
+        D = D.contiguous()
+    if initial_states is not None:
+        if cu_seqlens is None:
+            assert initial_states.shape == (batch, nheads, headdim, dstate)
+        else:
+            assert initial_states.shape == (
+                len(cu_seqlens) - 1,
+                nheads,
+                headdim,
+                dstate,
+            )
+
+    # This function executes 5 sub-functions for computing mamba
+    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
+    #   which has a minimal implementation to understand the below operations
+    # - as explained by the blog, mamba is a special case of causal attention
+    # - the idea is to chunk the attention matrix and compute each
+    #   submatrix separately using different optimizations.
+    # - see the blog and paper for a visualization of the submatrices
+    #   which we refer to in the comments below
+
+    # 1. Compute chunked cumsum of A * dt
+    # - here dt may go through a softplus activation
+    dA_cumsum, dt = _chunk_cumsum_fwd(
+        dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
+    )
+
+    # 2. Compute the state for each intra-chunk
+    # (right term of low-rank factorization of off-diagonal blocks; B terms)
+    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
+
+    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+    # (middle term of factorization of off-diag blocks; A terms)
+    # - for handling chunked prefill, this requires i) initial_states,
+    #   ii) seq_idx, iii) is_cont_batched, and iv) chunk_offsets to all be specified.
+    # - When a new seq_idx is detected, we will stop passing the prev_state
+    #   and switch accordingly to the init_state corresponding to the new seq_idx.
+    # - We will also make sure that the dA_cumsum is taken only from the start of the
+    #   sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
+    # - this will ensure that states will be updated with the rightmost flushed seq_idx
+    #   of the previous chunk. This implies that the first chunk of states is either 0
+    #   or equal to init_states of the first example.
+    states, final_states = _state_passing_fwd(
+        rearrange(states, "... p n -> ... (p n)"),
+        dA_cumsum,
+        initial_states=(
+            rearrange(initial_states, "... p n -> ... (p n)")
+            if initial_states is not None
+            else None
+        ),
+        seq_idx=seq_idx,
+        chunk_size=chunk_size,
+        out_dtype=state_dtype if state_dtype is not None else C.dtype,
+        is_cont_batched=cu_seqlens is not None,
+        chunk_offsets=chunk_offsets,
+    )
+    states, final_states = (
+        rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
+    )
+
+    # 4. Compute batched matrix multiply for C_j^T B_i terms
+    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
+
+    # 5. Scan and compute the diagonal blocks, taking into
+    #    account past causal states.
+    # - if initial states are provided, then states information will be
+    #   augmented with initial_states.
+    # - to do this properly, we need to account for example changes in
+    #   the continuous batch, therefore we introduce pseudo chunks, which is
+    #   a chunk that is split up each time an example changes.
+    # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
+    #   a seq_idx change, in which case we take states information from
+    #   init_states.
+    out_x = _chunk_scan_fwd(
+        CB,
+        x,
+        dt,
+        dA_cumsum,
+        C,
+        states,
+        D=D,
+        z=z,
+        seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+        initial_states=initial_states,
+        out=out,
+    )
+    if cu_seqlens is None:
+        return out_x, dt, dA_cumsum, states, final_states
+    else:
+        assert (
+            batch == 1
+        ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
+        varlen_states = chunk_state_varlen(
+            B.squeeze(0),
+            x.squeeze(0),
+            dt.squeeze(0),
+            dA_cumsum.squeeze(0),
+            cu_seqlens,
+            states.squeeze(0),
+            initial_states=initial_states,
+        )
+        return out_x, dt, dA_cumsum, states, final_states, varlen_states
+
+
+def mamba_chunk_scan_combined(
+    x,
+    dt,
+    A,
+    B,
+    C,
+    chunk_size,
+    D=None,
+    z=None,
+    dt_bias=None,
+    initial_states=None,
+    seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
+    cu_seqlens=None,
+    dt_softplus=False,
+    dt_limit=(0.0, float("inf")),
+    out=None,
+    return_final_states=False,
+    return_varlen_states=False,
+    state_dtype=None,
+):
+    """
+    Arguments:
+        x: (batch, seqlen, nheads, headdim)
+        dt: (batch, seqlen, nheads)
+        A: (nheads)
+        B: (batch, seqlen, ngroups, dstate)
+        C: (batch, seqlen, ngroups, dstate)
+        chunk_size: int
+        D: (nheads, headdim) or (nheads,)
+        z: (batch, seqlen, nheads, headdim)
+        dt_bias: (nheads,)
+        initial_states: (batch, nheads, headdim, dstate)
+        seq_idx: (batch, seqlen)
+        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
+        dt_softplus: Whether to apply softplus to dt
+        out: Preallocated output tensor
+        state_dtype: The data type of the ssm state
+    """
+
+    if not return_varlen_states:
+        cu_seqlens = None
+    else:
+        assert (
+            cu_seqlens is not None
+        ), "cu_seqlens must be provided if return_varlen_states is True"
+    out_x, dt_out, dA_cumsum, states, final_states, *rest = (
+        _mamba_chunk_scan_combined_fwd(
+            x,
+            dt,
+            A,
+            B,
+            C,
+            chunk_size,
+            D=D,
+            z=z,
+            dt_bias=dt_bias,
+            initial_states=initial_states,
+            seq_idx=seq_idx,
+            chunk_indices=chunk_indices,
+            chunk_offsets=chunk_offsets,
+            cu_seqlens=cu_seqlens,
+            dt_softplus=dt_softplus,
+            dt_limit=dt_limit,
+            out=out,
+            state_dtype=state_dtype,
+        )
+    )
+    if not return_varlen_states:
+        if not return_final_states:
+            return
+        else:
+            return final_states
+    else:
+        varlen_states = rest[0]
+        return (
+            (varlen_states)
+            if not return_final_states
+            else (final_states, varlen_states)
+        )
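
For orientation, here is a minimal smoke-test sketch of the public entry point added above. It is not part of the package: the shapes are arbitrary illustrative values chosen to satisfy the driver's assertions (chunk_size a power of 2, nheads divisible by ngroups), and it assumes a CUDA device with Triton available and the module path introduced in this release.

import torch

from sglang.srt.layers.attention.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined,
)

# Illustrative shapes only.
batch, seqlen, nheads, headdim = 1, 512, 8, 64
ngroups, dstate, chunk_size = 1, 128, 256

device = "cuda"
x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16)
dt = torch.rand(batch, seqlen, nheads, device=device, dtype=torch.float32)
A = -torch.rand(nheads, device=device, dtype=torch.float32)  # per-head decay, A < 0
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16)
C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16)
out = torch.empty_like(x)  # results are written into the preallocated `out`

final_states = mamba_chunk_scan_combined(
    x, dt, A, B, C, chunk_size,
    dt_softplus=True,
    out=out,
    return_final_states=True,
)
# out: (batch, seqlen, nheads, headdim)
# final_states: (batch, nheads, headdim, dstate), one SSM state per head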

sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py
@@ -0,0 +1,275 @@
+# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
+
+# ruff: noqa: E501
+
+import torch
+import triton
+import triton.language as tl
+
+
+# @triton.autotune(
+#     configs=[
+#         triton.Config({"BLOCK_SIZE": 64}),
+#         triton.Config({"BLOCK_SIZE": 128}),
+#         triton.Config({"BLOCK_SIZE": 256}),
+#         triton.Config({"BLOCK_SIZE": 512}),
+#         triton.Config({"BLOCK_SIZE": 1024}),
+#         triton.Config({"BLOCK_SIZE": 2048}),
+#     ],
+#     key=["dim"],
+# )
+@triton.jit
+def _state_passing_fwd_kernel(
+    # Pointers to matrices
+    states_ptr,
+    out_ptr,
+    final_states_ptr,
+    dA_cs_ptr,
+    initstates_ptr,
+    seq_idx_ptr,
+    chunk_offsets_ptr,
+    chunk_meta_num,
+    # Matrix dimensions
+    dim,
+    nchunks,
+    seqlen,
+    chunk_size,
+    # Strides
+    stride_states_batch,
+    stride_states_chunk,
+    stride_states_head,
+    stride_states_dim,
+    stride_out_batch,
+    stride_out_chunk,
+    stride_out_head,
+    stride_out_dim,
+    stride_final_states_batch,
+    stride_final_states_head,
+    stride_final_states_dim,
+    stride_dA_cs_batch,
+    stride_dA_cs_chunk,
+    stride_dA_cs_head,
+    stride_dA_cs_csize,
+    stride_initstates_batch,
+    stride_initstates_head,
+    stride_initstates_dim,
+    stride_seq_idx_batch,
+    stride_seq_idx_seqlen,
+    # Meta-parameters
+    HAS_INITSTATES: tl.constexpr,
+    HAS_SEQ_IDX: tl.constexpr,
+    IS_CONT_BATCHED: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr = 16,
+):
+    pid_b = tl.program_id(axis=1)
+    pid_h = tl.program_id(axis=2)
+    pid_m = tl.program_id(axis=0)
+    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
+    dA_cs_ptr += (
+        pid_b * stride_dA_cs_batch
+        + pid_h * stride_dA_cs_head
+        + (chunk_size - 1) * stride_dA_cs_csize
+    )
+    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
+    final_states_ptr += (
+        pid_b * stride_final_states_batch + pid_h * stride_final_states_head
+    )
+    if HAS_INITSTATES:
+        initstates_ptr += pid_h * stride_initstates_head
+        if not IS_CONT_BATCHED:
+            initstates_ptr += pid_b * stride_initstates_batch
+
+    if HAS_SEQ_IDX:
+        seq_idx_ptr += pid_b * stride_seq_idx_batch
+
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    states_ptrs = states_ptr + offs_m * stride_states_dim
+    out_ptrs = out_ptr + offs_m * stride_out_dim
+    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
+
+    # - states will be the past state of the sequence that continues on the current chunk
+    if not HAS_INITSTATES:
+        states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    else:
+        initstates_ptr += offs_m * stride_initstates_dim
+        initstates_ptrs = initstates_ptr
+        # - for cont batches, the first chunk will use the first sequence's
+        #   init state
+        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+
+    tl.store(out_ptrs, states, mask=offs_m < dim)
+    out_ptrs += stride_out_chunk
+    prev_seq_idx_chunk_end = 0
+    logical_chunk_idx = 0
+    for c in range(nchunks):
+        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
+        scale_mask = True
+        if HAS_SEQ_IDX:
+            # - the seq to pass forward is the one that is flushed to the right
+            #   boundary.
+            # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
+            seq_idx_chunk_end = tl.load(
+                seq_idx_ptr
+                + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen
+            )
+            if HAS_INITSTATES:
+                if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
+                    # this means in the current chunk the rightmost flushed seq
+                    # has changed.
+                    # - so we do not propagate the state from previous chunk
+                    # - but rather we load that sequence's init state
+                    initstates_ptrs = (
+                        initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
+                    )
+
+                    # - update state with seq_idx_new's init state
+                    states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
+                        tl.float32
+                    )
+
+                    # - we need to consider the cumsum only of the last sequence in the chunk
+                    # - find its starting position (given by c_off of the logical chunk index)
+                    # - and subtract the cumsum just before that position from the total cumsum
+                    # - first, update the logical chunk index (add the number of sequences in the current physical chunk):
+                    #   sequence index at the start of the current chunk
+                    seq_idx_chunk_start = tl.load(
+                        seq_idx_ptr
+                        + min(c * chunk_size, seqlen) * stride_seq_idx_seqlen
+                    )
+                    logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
+                    # - load the chunk offset:
+                    c_off = tl.load(
+                        chunk_offsets_ptr + logical_chunk_idx,
+                        mask=logical_chunk_idx < chunk_meta_num,
+                        other=0,
+                    )
+                    # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
+                    if c_off > 0:
+                        # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
+                        dA_cs_boundary = tl.load(
+                            dA_cs_ptr
+                            - (chunk_size - 1) * stride_dA_cs_csize
+                            + (c_off - 1) * stride_dA_cs_csize,
+                            mask=(c_off - 1) > -1 and c_off < chunk_size,
+                            other=0.0,
+                        )
+                        dA_cs -= dA_cs_boundary
+
+                # - increment logical chunk index for every physical chunk
+                logical_chunk_idx += 1
+            else:
+                scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
+            prev_seq_idx_chunk_end = seq_idx_chunk_end
+
+        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
+        states = scale * states + new_states
+        if c < nchunks - 1:
+            tl.store(out_ptrs, states, mask=offs_m < dim)
+        else:
+            tl.store(final_states_ptrs, states, mask=offs_m < dim)
+        states_ptrs += stride_states_chunk
+        dA_cs_ptr += stride_dA_cs_chunk
+        out_ptrs += stride_out_chunk
+
+
+def _state_passing_fwd(
+    states,
+    dA_cumsum,
+    initial_states=None,
+    seq_idx=None,
+    chunk_size=None,
+    out_dtype=None,
+    is_cont_batched=False,
+    chunk_offsets=None,
+):
+    batch, nchunks, nheads, dim = states.shape
+    if chunk_size is None:
+        chunk_size = dA_cumsum.shape[-1]
+    else:
+        assert chunk_size == dA_cumsum.shape[-1]
+    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
+    if initial_states is not None:
+        if is_cont_batched:
+            # - if cu_seqlens is provided, then the initial states
+            #   are used for continuous batching. In which case we
+            #   require seq_idx to be provided
+            assert (
+                seq_idx is not None
+            ), "seq_idx must be provided for continuous batching"
+            # - we also need chunk_offsets to be provided, to account
+            #   for computation of dA_cumsum from the start of the
+            #   sequence
+            assert (
+                chunk_offsets is not None
+            ), "chunk_offsets must be provided for continuous batching"
+        else:
+            # - this is the regular batching case, where initial
+            #   states are used for each example of the batch.
+            assert initial_states.shape == (batch, nheads, dim)
+
+    if seq_idx is not None:
+        seqlen = seq_idx.shape[-1]
+        assert seq_idx.shape == (batch, seqlen)
+    out_dtype = states.dtype if out_dtype is None else out_dtype
+    out = torch.empty(
+        (batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype
+    )
+    final_states = torch.empty(
+        (batch, nheads, dim), device=states.device, dtype=torch.float32
+    )
+    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
+    with torch.cuda.device(states.device.index):
+        _state_passing_fwd_kernel[grid](
+            states,
+            out,
+            final_states,
+            dA_cumsum,
+            initial_states,
+            seq_idx,
+            chunk_offsets,
+            len(chunk_offsets) if chunk_offsets is not None else 0,
+            dim,
+            nchunks,
+            seqlen if seq_idx is not None else 0,
+            chunk_size,
+            states.stride(0),
+            states.stride(1),
+            states.stride(2),
+            states.stride(3),
+            out.stride(0),
+            out.stride(1),
+            out.stride(2),
+            out.stride(3),
+            final_states.stride(0),
+            final_states.stride(1),
+            final_states.stride(2),
+            dA_cumsum.stride(0),
+            dA_cumsum.stride(2),
+            dA_cumsum.stride(1),
+            dA_cumsum.stride(3),
+            *(
+                (
+                    initial_states.stride(0),
+                    initial_states.stride(1),
+                    initial_states.stride(2),
+                )
+                if initial_states is not None
+                else (0, 0, 0)
+            ),
+            *(
+                (seq_idx.stride(0), seq_idx.stride(1))
+                if seq_idx is not None
+                else (0, 0)
+            ),
+            HAS_INITSTATES=initial_states is not None,
+            HAS_SEQ_IDX=seq_idx is not None,
+            IS_CONT_BATCHED=is_cont_batched,
+        )
+    return out, final_states
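
To make the kernel's control flow concrete, the following pure-PyTorch reference (an illustration written for this diff, not code from the package) computes the same recurrence as `_state_passing_fwd_kernel` on its simple path (`HAS_SEQ_IDX=False`, one sequence per batch element): the state entering each chunk is stored to `out`, then advanced by `state = exp(dA_cs) * state + local_state`.

import torch

def state_passing_reference(states, dA_cumsum, initial_states=None):
    """Sketch of _state_passing_fwd without seq_idx / continuous batching.

    states:     (batch, nchunks, nheads, dim)       per-chunk local states
    dA_cumsum:  (batch, nheads, nchunks, chunk_size)
    Returns out (the state entering each chunk) and final_states.
    """
    batch, nchunks, nheads, dim = states.shape
    state = (
        initial_states.float().clone()
        if initial_states is not None
        else states.new_zeros(batch, nheads, dim, dtype=torch.float32)
    )
    out = torch.empty_like(states)
    for c in range(nchunks):
        out[:, c] = state  # state at the *start* of chunk c
        # decay accumulated over chunk c: the dA cumsum at the chunk's last position
        scale = torch.exp(dA_cumsum[:, :, c, -1].float()).unsqueeze(-1)
        state = scale * state + states[:, c].float()
    return out, state  # state == final_states

The Triton kernel parallelizes this loop over (dim blocks, batch, heads) and adds the seq_idx and initial-state switching needed for chunked prefill and continuous batching.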