sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377):
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,206 @@
1
+ import logging
2
+
3
logger = logging.getLogger(__name__)

# Registry of attention-backend factories, keyed by backend name.
# Filled at import time by the `@register_attention_backend(name)` decorator;
# each value is a callable that takes the model runner and returns a backend
# instance. Keys match the names compared against
# `runner.server_args.attention_backend` (e.g. "triton", "trtllm_mha") —
# presumably the lookup key for `--attention-backend`; confirm at the call site.
ATTENTION_BACKENDS = {}
6
+
7
+
8
def register_attention_backend(name):
    """Return a decorator that records a backend factory under ``name``.

    The decorated factory is stored in ``ATTENTION_BACKENDS[name]`` (a later
    registration with the same name replaces the earlier one) and is returned
    unchanged, so it remains directly callable.
    """

    def _register(factory):
        ATTENTION_BACKENDS[name] = factory
        return factory

    return _register
14
+
15
+
16
@register_attention_backend("flashinfer")
def create_flashinfer_backend(runner):
    """Build a FlashInfer attention backend for ``runner``.

    MLA models get ``FlashInferMLAAttnBackend``. Every other model gets
    ``FlashInferAttnBackend``. In the non-MLA case with EAGLE speculative
    decoding, a CUDA stream is stored on the runner as
    ``plan_stream_for_flashinfer``, unless a truthy one is already present.
    """
    import torch

    if runner.use_mla_backend:
        from sglang.srt.layers.attention.flashinfer_mla_backend import (
            FlashInferMLAAttnBackend,
        )

        return FlashInferMLAAttnBackend(runner)

    from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend

    # Only create a plan stream when none (or a falsy one) is attached yet.
    wants_plan_stream = runner.server_args.speculative_algorithm == "EAGLE" and not getattr(
        runner, "plan_stream_for_flashinfer", None
    )
    if wants_plan_stream:
        runner.plan_stream_for_flashinfer = torch.cuda.Stream()
    return FlashInferAttnBackend(runner)
37
+
38
+
39
@register_attention_backend("trtllm_mla")
def create_trtllm_mla_backend(runner):
    """Build the TensorRT-LLM MLA backend; valid only for MLA models.

    Raises:
        ValueError: if ``runner.use_mla_backend`` is falsy.
    """
    if runner.use_mla_backend:
        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

        return TRTLLMMLABackend(runner)
    raise ValueError("trtllm_mla backend can only be used with MLA models.")
46
+
47
+
48
@register_attention_backend("aiter")
def create_aiter_backend(runner):
    """Instantiate ``AiterAttnBackend`` for ``runner``."""
    from sglang.srt.layers.attention.aiter_backend import (
        AiterAttnBackend as backend_cls,
    )

    return backend_cls(runner)
53
+
54
+
55
@register_attention_backend("wave")
def create_wave_backend(runner):
    """Instantiate ``WaveAttnBackend`` for ``runner``."""
    from sglang.srt.layers.attention.wave_backend import (
        WaveAttnBackend as backend_cls,
    )

    return backend_cls(runner)
60
+
61
+
62
@register_attention_backend("ascend")
def create_ascend_backend(runner):
    """Instantiate ``AscendAttnBackend`` for ``runner``."""
    from sglang.srt.layers.attention.ascend_backend import (
        AscendAttnBackend as backend_cls,
    )

    return backend_cls(runner)
67
+
68
+
69
@register_attention_backend("nsa")
def create_nsa_backend(runner):
    """Instantiate ``NativeSparseAttnBackend`` for ``runner``."""
    from sglang.srt.layers.attention.nsa_backend import (
        NativeSparseAttnBackend as backend_cls,
    )

    return backend_cls(runner)
74
+
75
+
76
@register_attention_backend("triton")
def create_triton_backend(runner):
    """Build a Triton-based attention backend for ``runner``.

    Returns ``DoubleSparseAttnBackend`` when
    ``runner.server_args.enable_double_sparsity`` is set, otherwise
    ``TritonAttnBackend``.

    Raises:
        AssertionError: for encoder-decoder (cross-attention) models. Raised
            explicitly rather than via ``assert`` so the configuration check
            is not stripped when Python runs with ``-O``.
    """
    if runner.model_config.is_encoder_decoder:
        raise AssertionError(
            "Cross attention is not supported in the triton attention backend. "
            "Please use `--attention-backend flashinfer`."
        )
    if runner.server_args.enable_double_sparsity:
        from sglang.srt.layers.attention.double_sparsity_backend import (
            DoubleSparseAttnBackend,
        )

        return DoubleSparseAttnBackend(runner)

    from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

    return TritonAttnBackend(runner)
92
+
93
+
94
@register_attention_backend("torch_native")
def create_torch_native_backend(runner):
    """Instantiate ``TorchNativeAttnBackend`` for ``runner``."""
    from sglang.srt.layers.attention.torch_native_backend import (
        TorchNativeAttnBackend as backend_cls,
    )

    return backend_cls(runner)
99
+
100
+
101
@register_attention_backend("flex_attention")
def create_flex_attention_backend(runner):
    """Instantiate ``TorchFlexAttnBackend`` for ``runner``."""
    from sglang.srt.layers.attention.torch_flex_backend import (
        TorchFlexAttnBackend as backend_cls,
    )

    return backend_cls(runner)
106
+
107
+
108
@register_attention_backend("flashmla")
def create_flashmla_backend(runner):
    """Instantiate ``FlashMLABackend`` for ``runner``."""
    from sglang.srt.layers.attention.flashmla_backend import (
        FlashMLABackend as backend_cls,
    )

    return backend_cls(runner)
113
+
114
+
115
@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
    """Build the FlashAttention v3 backend (``FlashAttentionBackend``).

    Accepted configurations, by the major CUDA compute capability of the
    current device:
      * major 9: any model;
      * major 8: non-MLA models only.

    Raises:
        AssertionError: for any other device/model combination. Raised
            explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    import torch

    # Query the capability once instead of once per clause.
    major = torch.cuda.get_device_capability()[0]
    supported = major == 9 or (major == 8 and not runner.use_mla_backend)
    if not supported:
        # The old message ("SM>=80 and SM<=90") was misleading for MLA models
        # on SM 8.x, which are rejected even though SM80 satisfies ">=80".
        raise AssertionError(
            "FlashAttention v3 Backend requires SM 8.x (non-MLA models only) "
            "or SM 9.x. "
            "Please use `--attention-backend flashinfer`."
        )
    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

    return FlashAttentionBackend(runner)
128
+
129
+
130
@register_attention_backend("fa4")
def create_flashattention_v4_backend(runner):
    """Build ``FlashAttentionBackend`` with ``fa_impl_ver=4``; MLA models only.

    Raises:
        AssertionError: if ``runner.use_mla_backend`` is falsy. Raised
            explicitly so the check is not stripped under ``python -O``.
    """
    if not runner.use_mla_backend:
        raise AssertionError(
            "FlashAttention v4 Support is at an early stage, only MLA model supported now"
        )
    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

    return FlashAttentionBackend(runner, fa_impl_ver=4)
138
+
139
+
140
@register_attention_backend("cutlass_mla")
def create_cutlass_mla_backend(runner):
    """Instantiate ``CutlassMLABackend`` for ``runner``."""
    from sglang.srt.layers.attention.cutlass_mla_backend import (
        CutlassMLABackend as backend_cls,
    )

    return backend_cls(runner)
145
+
146
+
147
@register_attention_backend("trtllm_mha")
def create_trtllm_mha_backend(runner):
    """Build the TensorRT-LLM MHA backend; valid only for non-MLA models.

    Raises:
        ValueError: if ``runner.use_mla_backend`` is truthy.
    """
    if not runner.use_mla_backend:
        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend

        return TRTLLMHAAttnBackend(runner)
    raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
154
+
155
+
156
@register_attention_backend("intel_amx")
def create_intel_amx_backend(runner):
    """Instantiate ``IntelAMXAttnBackend`` for ``runner``."""
    from sglang.srt.layers.attention.intel_amx_backend import (
        IntelAMXAttnBackend as backend_cls,
    )

    return backend_cls(runner)
161
+
162
+
163
@register_attention_backend("dual_chunk_flash_attn")
def create_dual_chunk_flash_attn_backend(runner):
    """Instantiate ``DualChunkFlashAttentionBackend`` for ``runner``."""
    from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
        DualChunkFlashAttentionBackend as backend_cls,
    )

    return backend_cls(runner)
170
+
171
+
172
def attn_backend_wrapper(runner, full_attn_backend):
    """Wrap ``full_attn_backend`` for models that need a composite backend.

    Lets special architectures (currently hybrid GDN models) reuse an existing
    attention backend without changing that backend's code.

    For a hybrid GDN model (``runner.is_hybrid_gdn``), returns a
    ``HybridLinearAttnBackend``. It is built from ``full_attn_backend``, a new
    ``MambaAttnBackend(runner)`` and
    ``runner.model_config.hf_config.full_attention_layer_ids``. Presumably the
    full backend serves the listed layers and the Mamba backend serves the
    rest; confirm in hybrid_linear_attn_backend. For every other model,
    ``full_attn_backend`` is returned unchanged.

    Raises:
        AssertionError: if a hybrid GDN model is combined with an MLA backend,
            or an unsupported ``--attention-backend`` is chosen on Blackwell
            GPUs or NPUs. These are raised explicitly rather than via
            ``assert`` so the checks are not stripped under ``python -O``.
    """
    if runner.is_hybrid_gdn and runner.use_mla_backend:
        raise AssertionError("hybrid_gdn can only be used with non-MLA models.")

    if not runner.is_hybrid_gdn:
        return full_attn_backend

    from sglang.srt.utils import is_blackwell, is_npu

    if is_blackwell() and runner.server_args.attention_backend not in (
        "triton",
        "trtllm_mha",
    ):
        raise AssertionError(
            "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
        )
    if is_npu() and runner.server_args.attention_backend != "ascend":
        raise AssertionError(
            "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
        )
    # Plain string: the original f-string had no placeholders.
    logger.info("Using hybrid linear attention backend for hybrid GDN models.")
    from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
        HybridLinearAttnBackend,
        MambaAttnBackend,
    )

    linear_attn_backend = MambaAttnBackend(runner)
    full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
    return HybridLinearAttnBackend(
        full_attn_backend, linear_attn_backend, full_attn_layers
    )
@@ -6,9 +6,10 @@ from typing import TYPE_CHECKING, Optional, Union
6
6
  import torch
7
7
 
8
8
  if TYPE_CHECKING:
9
+ from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
9
10
  from sglang.srt.layers.radix_attention import RadixAttention
10
11
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
11
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
12
+ from sglang.srt.speculative.spec_info import SpecInput
12
13
 
13
14
 
14
15
  class AttentionBackend(ABC):
@@ -31,7 +32,7 @@ class AttentionBackend(ABC):
31
32
  seq_lens: torch.Tensor,
32
33
  encoder_lens: Optional[torch.Tensor],
33
34
  forward_mode: ForwardMode,
34
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
35
+ spec_info: Optional[SpecInput],
35
36
  ):
36
37
  """Init the metadata for a forward pass for capturing a cuda graph."""
37
38
  raise NotImplementedError()
@@ -44,7 +45,7 @@ class AttentionBackend(ABC):
44
45
  seq_lens_sum: int,
45
46
  encoder_lens: Optional[torch.Tensor],
46
47
  forward_mode: ForwardMode,
47
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
48
+ spec_info: Optional[SpecInput],
48
49
  seq_lens_cpu: Optional[torch.Tensor],
49
50
  ):
50
51
  """Init the metadata for a forward pass for replaying a cuda graph."""
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
115
116
  def support_triton(self):
116
117
  """Check if the current backend supports triton."""
117
118
  return True
119
+
120
def get_indexer_metadata(
    self,
    layer_id: int,
    forward_batch: ForwardBatch,
) -> Optional[BaseIndexerMetadata]:
    """Return indexer metadata for ``layer_id``.

    The base implementation returns ``None``, which signals that this
    backend does not support an indexer.
    """
    indexer_metadata = None
    return indexer_metadata
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
20
20
  if TYPE_CHECKING:
21
21
  from sglang.srt.layers.radix_attention import RadixAttention
22
22
  from sglang.srt.model_executor.model_runner import ModelRunner
23
- from sglang.srt.speculative.spec_info import SpecInfo
23
+ from sglang.srt.speculative.spec_info import SpecInput
24
24
 
25
25
  _is_cuda = is_cuda()
26
26
  if _is_cuda:
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
151
151
  seq_lens: torch.Tensor,
152
152
  encoder_lens: Optional[torch.Tensor],
153
153
  forward_mode: ForwardMode,
154
- spec_info: Optional[SpecInfo],
154
+ spec_info: Optional[SpecInput],
155
155
  ):
156
156
  if forward_mode.is_decode_or_idle():
157
157
  if spec_info is None:
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
190
190
  seq_lens_sum: int,
191
191
  encoder_lens: Optional[torch.Tensor],
192
192
  forward_mode: ForwardMode,
193
- spec_info: Optional[SpecInfo],
193
+ spec_info: Optional[SpecInput],
194
194
  seq_lens_cpu: Optional[torch.Tensor],
195
195
  ):
196
196
 
@@ -1537,7 +1537,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
1537
1537
  query_inter,
1538
1538
  key_cache,
1539
1539
  value_cache,
1540
- block_table[:, : decode_meta.max_seq_len_inter],
1540
+ block_table,
1541
1541
  decode_meta.seq_lens_inter,
1542
1542
  softmax_scale,
1543
1543
  causal=False,
@@ -0,0 +1,242 @@
1
+ # Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ import warnings
6
+ from typing import Optional
7
+
8
+ import torch
9
+ from einops import rearrange
10
+
11
+ from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
12
+ from sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o
13
+ from sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import (
14
+ chunk_scaled_dot_kkt_fwd,
15
+ )
16
+ from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
17
+ from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
18
+ from sglang.srt.layers.attention.fla.solve_tril import solve_tril
19
+ from sglang.srt.layers.attention.fla.utils import (
20
+ SUPPRESS_LEVEL,
21
+ autocast_custom_fwd,
22
+ input_guard,
23
+ )
24
+ from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd
25
+
26
+
27
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Run the forward pass of the chunked gated delta rule.

    The computation is a fixed pipeline of kernel helpers:

    1. ``chunk_local_cumsum`` replaces the gate ``g`` by its chunk-local
       cumulative sum (chunk size 64).
    2. ``chunk_scaled_dot_kkt_fwd`` builds ``A`` in float32, then
       ``solve_tril`` solves it and stores the result in ``k.dtype``.
    3. ``recompute_w_u_fwd`` derives the WY representation ``w``/``u``
       (``u`` plays the role of the new values).
    4. ``chunk_gated_delta_rule_fwd_h`` computes the chunk states ``h``, the
       updated values ``v_new`` and the final state.
    5. ``chunk_fwd_o`` computes the output ``o``.

    Args:
        q, k, v, g, beta: input tensors, presumably in the ``[B, T, H, ...]``
            layout documented on ``chunk_gated_delta_rule`` (TODO confirm).
        scale: attention scale passed to ``chunk_fwd_o``.
        initial_state: initial recurrent state passed to the state kernel.
        output_final_state: whether the state kernel should emit a final state.
        cu_seqlens: optional cumulative sequence lengths for varlen inputs.

    Returns:
        A 7-tuple ``(g, o, A, final_state, w, h, v_new)`` where ``g`` is the
        cumulative-summed gate. When ``SUPPRESS_LEVEL < 3`` the last three
        entries are ``None``; otherwise the intermediates ``w``, ``h`` and
        ``v_new`` are returned too.
    """
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # Obtain the WY representation; ``u`` is effectively the new ``v``.
    A = chunk_scaled_dot_kkt_fwd(
        k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        # Intermediates are not returned at low suppress levels.
        return g, o, A, final_state, None, None, None
    # The complementary case needs no second test: every path now returns.
    return g, o, A, final_state, w, h, v_new
74
+
75
+
76
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    """``torch.autograd.Function`` wrapper around ``chunk_gated_delta_rule_fwd``.

    Only ``forward`` is defined here; no ``backward`` is implemented in this
    class and nothing is saved on ``ctx``. It is presumably used for
    inference only; confirm before relying on it under autograd.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        """Run the chunked gated-delta-rule forward pass.

        Args:
            ctx: autograd context required by the ``Function`` API; unused.
            q, k, v, g, beta, scale, initial_state, output_final_state,
            cu_seqlens: forwarded to ``chunk_gated_delta_rule_fwd`` (``q`` and
                ``k`` possibly normalized first, see below).
            use_qk_l2norm_in_kernel: if True, ``q`` and ``k`` are passed
                through ``l2norm_fwd`` before the chunked computation.

        Returns:
            ``(o, final_state)``. ``o`` is cast to ``q.dtype``, i.e. the dtype
            of ``q`` after the optional normalization.
        """
        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        # Only the output and the final state are needed here; the remaining
        # tuple entries (g, A, w, h, v_new) are discarded.
        _, o, _, final_state, _, _, _ = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        return o.to(q.dtype), final_state
113
+
114
+
115
+ @torch.compiler.disable
116
+ def chunk_gated_delta_rule(
117
+ q: torch.Tensor,
118
+ k: torch.Tensor,
119
+ v: torch.Tensor,
120
+ g: torch.Tensor,
121
+ beta: torch.Tensor,
122
+ scale: float = None,
123
+ initial_state: torch.Tensor = None,
124
+ output_final_state: bool = False,
125
+ cu_seqlens: Optional[torch.LongTensor] = None,
126
+ head_first: bool = False,
127
+ use_qk_l2norm_in_kernel: bool = False,
128
+ ):
129
+ r"""
130
+ Args:
131
+ q (torch.Tensor):
132
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
133
+ k (torch.Tensor):
134
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
135
+ v (torch.Tensor):
136
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
137
+ g (torch.Tensor):
138
+ (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
139
+ beta (torch.Tensor):
140
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
141
+ scale (Optional[int]):
142
+ Scale factor for the RetNet attention scores.
143
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
144
+ initial_state (Optional[torch.Tensor]):
145
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
146
+ For equal-length input sequences, `N` equals the batch size `B`.
147
+ Default: `None`.
148
+ output_final_state (Optional[bool]):
149
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
150
+ cu_seqlens (torch.LongTensor):
151
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
152
+ consistent with the FlashAttention API.
153
+ head_first (Optional[bool]):
154
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
155
+ Default: `False`.
156
+
157
+ Returns:
158
+ o (torch.Tensor):
159
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
160
+ final_state (torch.Tensor):
161
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
162
+
163
+ Examples::
164
+ >>> import torch
165
+ >>> import torch.nn.functional as F
166
+ >>> from einops import rearrange
167
+ >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
168
+ # inputs with equal lengths
169
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
170
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
171
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
172
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
173
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
174
+ >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
175
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
176
+ >>> o, ht = chunk_gated_delta_rule(
177
+ q, k, v, g, beta,
178
+ initial_state=h0,
179
+ output_final_state=True
180
+ )
181
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
182
+ >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
183
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
184
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
185
+ >>> o_var, ht_var = chunk_gated_delta_rule(
186
+ q, k, v, g, beta,
187
+ initial_state=h0,
188
+ output_final_state=True,
189
+ cu_seqlens=cu_seqlens
190
+ )
191
+ """
192
+ assert q.dtype == k.dtype == v.dtype
193
+ assert (
194
+ q.dtype != torch.float32
195
+ ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
196
+ assert (
197
+ len(beta.shape) == 3
198
+ ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
199
+
200
+ if head_first:
201
+ raise DeprecationWarning(
202
+ "head_first is deprecated and will be removed in a future version. "
203
+ "Please use head_first=False for now instead."
204
+ )
205
+ q, k, v, beta, g = map(
206
+ lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
207
+ )
208
+ # if not head_first and q.shape[1] < q.shape[2]:
209
+ # warnings.warn(
210
+ # f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
211
+ # "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
212
+ # "when head_first=False was specified. "
213
+ # "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
214
+ # )
215
+ if cu_seqlens is not None:
216
+ if q.shape[0] != 1:
217
+ raise ValueError(
218
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
219
+ f"Please flatten variable-length inputs before processing."
220
+ )
221
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
222
+ raise ValueError(
223
+ f"The number of initial states is expected to be equal to the number of input sequences, "
224
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
225
+ )
226
+ if scale is None:
227
+ scale = k.shape[-1] ** -0.5
228
+ o, final_state = ChunkGatedDeltaRuleFunction.apply(
229
+ q,
230
+ k,
231
+ v,
232
+ g,
233
+ beta,
234
+ scale,
235
+ initial_state,
236
+ output_final_state,
237
+ cu_seqlens,
238
+ use_qk_l2norm_in_kernel,
239
+ )
240
+ if head_first:
241
+ o = rearrange(o, "b t h ... -> b h t ...")
242
+ return o, final_state