sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
@@ -7,15 +7,19 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+import triton
+import triton.language as tl
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
+    get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
 )
 
 _is_cuda = is_cuda()
@@ -24,15 +28,22 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_xpu = is_xpu()
 
 if _is_cuda:
-    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+    from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+else:
+    FusedSetKVBufferArg = None
+
 if _use_aiter:
     from aiter.rotary_embedding import get_rope as aiter_get_rope
 
 if is_npu():
     import torch_npu
 
+    NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
+    NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
+
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -101,9 +112,11 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)
 
-        if (
-            not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
-        ) and not (_is_cpu and _is_cpu_amx_available):
+        if dtype == torch.float32 or (
+            (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
+            and not (_is_cpu and _is_cpu_amx_available)
+            and not (_is_xpu)
+        ):
             from vllm._custom_ops import rotary_embedding
 
             self.vllm_rotary_embedding = rotary_embedding
@@ -142,8 +155,13 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
@@ -172,12 +190,17 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-npu implementation of forward()."""
-        import os
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for npu implementation"
 
         if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
         else:
             rotary_mode = "half"
             if self.is_neox_style:
@@ -202,7 +225,12 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
+
         positions = torch.add(positions, offsets) if offsets is not None else positions
         if _is_cpu_amx_available:
             return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -214,7 +242,9 @@ class RotaryEmbedding(CustomOp):
                 self.is_neox_style,
             )
         else:
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
 
     def forward_cuda(
         self,
@@ -222,9 +252,13 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
+        if (
+            _is_cuda
+            and (self.head_size in [64, 128, 256, 512])
+            and self.dtype != torch.float32
+        ):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -260,6 +294,17 @@ class RotaryEmbedding(CustomOp):
         s += f", base={self.base}, is_neox_style={self.is_neox_style}"
         return s
 
+    def forward_xpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # TODO: make a wrapper, and XPU will implement this kernel later.
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
+        return self.forward_native(positions, query, key, offsets)
+
 
 class LinearScalingRotaryEmbedding(RotaryEmbedding):
     """RotaryEmbedding extended with linear scaling.
@@ -782,27 +827,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
-        # and generalization to more scenarios will be supported in the future.
-        if query.shape[1] * query.shape[2] > 4096:
-            return self.forward_native(positions, query, key, offsets)
-        num_tokens = query.shape[0]
-        rotary_mode = "half" if self.is_neox_style else "interleave"
+        num_tokens, num_q_heads, _ = query.shape
+        num_k_heads = key.shape[1]
+
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # Reshape to [batchsize, head_dim, seq, rotary_dim]
+        cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim :]
             key_pass = key[..., self.rotary_dim :]
 
-        query_rot, key_rot = torch_npu.npu_mrope(
-            torch.add(positions, offsets) if offsets is not None else positions,
-            query_rot.reshape(num_tokens, -1),
-            key_rot.reshape(num_tokens, -1),
-            self.cos_sin_cache,
-            self.rotary_dim,
-            mrope_section=[0, 0, 0],
-            rotary_mode=rotary_mode,
+        query_rot = torch_npu.npu_interleave_rope(
+            query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
+        key_rot = torch_npu.npu_interleave_rope(
+            key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
+            cos,
+            sin,
         )
         query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
         key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
@@ -978,6 +1029,199 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
         return cache
 
 
+def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
+    """Apply interleaved MRoPE to 3D rotary embeddings.
+    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+    interleaved [THTHWHTHW...TT], preserving frequency continuity.
+    """
+    x_t = x[0].clone()
+    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
+    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
+    return x_t
+
+
+@triton.jit
+def _triton_mrope_forward(
+    q_ptr,
+    k_ptr,
+    cos,
+    sin,
+    num_tokens,
+    n_qh: tl.constexpr,
+    n_kh: tl.constexpr,
+    hd: tl.constexpr,
+    rd: tl.constexpr,
+    pad_n_qh: tl.constexpr,
+    pad_n_kh: tl.constexpr,
+    pad_hd: tl.constexpr,
+    mrope_section_t: tl.constexpr,
+    mrope_section_h: tl.constexpr,
+    mrope_section_w: tl.constexpr,
+    is_interleaved: tl.constexpr,
+):
+    # Adapted from
+    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
+    # This version supports flatten input tensors from vllm
+    # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
+    # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary
+    pid = tl.program_id(0)
+    # locate start address
+    q_ptr = q_ptr + pid * (n_qh * hd)
+    k_ptr = k_ptr + pid * (n_kh * hd)
+
+    # ####################################################################
+    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
+    # m of this program instance
+    # ####################################################################
+    # Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
+
+    # Updated stride calculation for half head_dim
+    half_rd = rd // 2
+    t_cos = cos + pid * half_rd
+    h_cos = t_cos + num_tokens * half_rd
+    w_cos = h_cos + num_tokens * half_rd
+    t_sin = sin + pid * half_rd
+    h_sin = t_sin + num_tokens * half_rd
+    w_sin = h_sin + num_tokens * half_rd
+
+    # Updated offsets for half head_dim
+    cos_offsets = tl.arange(0, pad_hd // 2)
+    if is_interleaved:
+        h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
+        w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
+        t_mask = ~(h_mask | w_mask)
+    else:
+        t_end = mrope_section_t
+        h_end = t_end + mrope_section_h
+        t_mask = cos_offsets < mrope_section_t
+        h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
+        w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
+
+    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
+    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
+    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
+    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
+    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
+    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
+
+    cos_row = t_cos_row + h_cos_row + w_cos_row
+    sin_row = t_sin_row + h_sin_row + w_sin_row
+
+    # ####################################################################
+    # Load the left and right half of q and k for the current
+    # program instance (i.e. for the current token) separately
+    # ####################################################################
+    # left half of the head
+    first_half_q_offsets = (
+        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    )
+    first_half_k_offsets = (
+        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    )
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+    )
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+    )
+
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    # right half of the head
+    second_half_q_offsets = first_half_q_offsets + (rd // 2)
+    second_half_k_offsets = first_half_k_offsets + (rd // 2)
+    second_q_mask = first_q_mask
+    second_k_mask = first_k_mask
+
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+    # Since cos and sin are now half-size,
+    # we use the same cos_row and sin_row for both halves
+    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+
+
+def triton_mrope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    mrope_section: list[int],
+    head_size: int,
+    rotary_dim: int,
+    mrope_interleaved: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """The mrope triton kernel.
+
+    Args:
+        q: [num_tokens, num_heads * head_size]
+        k: [num_tokens, num_kv_heads * head_size]
+        cos: [3, num_tokens, head_size // 2]
+            (T/H/W positions with multimodal inputs)
+        sin: [3, num_tokens, head_size // 2]
+            (T/H/W positions with multimodal inputs)
+        mrope_section: [t, h, w]
+        head_size: int
+    """
+    n_row, n_q_head_head_dim = q.shape
+    assert (
+        n_q_head_head_dim % head_size == 0
+    ), f"q shape {n_q_head_head_dim} must be divisible by head_size {head_size}"
+    n_q_head = n_q_head_head_dim // head_size
+    assert (
+        k.shape[1] % head_size == 0
+    ), f"k shape {k.shape[1]} must be divisible by head_size {head_size}"
+    n_kv_head = k.shape[1] // head_size
+    pad_hd = triton.next_power_of_2(head_size)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+    # ensure tensors passed into the kernel are contiguous.
+    # It will be no-op if they are already contiguous
+    q = q.contiguous()
+    k = k.contiguous()
+    cos = cos.contiguous()
+    sin = sin.contiguous()
+
+    _triton_mrope_forward[(n_row,)](
+        q,
+        k,
+        cos,
+        sin,
+        n_row,
+        n_q_head,
+        n_kv_head,
+        head_size,
+        rotary_dim,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        mrope_section[0],
+        mrope_section[1],
+        mrope_section[2],
+        mrope_interleaved,
+    )
+    return q, k
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
 
@@ -990,12 +1234,14 @@ class MRotaryEmbedding(RotaryEmbedding):
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[List[int]] = None,
+        mrope_interleaved: bool = False,
     ) -> None:
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
 
         self.mrope_section = mrope_section
+        self.mrope_interleaved = mrope_interleaved
         if self.mrope_section:
             expected_sum = rotary_dim // 2
             actual_sum = sum(self.mrope_section)
@@ -1029,12 +1275,22 @@ class MRotaryEmbedding(RotaryEmbedding):
                     f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
                 )
 
-    @torch.compile(dynamic=True)
-    def forward(
+    def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
+        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
+        # is expensive, so avoid calling it if possible
+        if (
+            self.cos_sin_cache.device != query.device
+            or self.cos_sin_cache.dtype != query.dtype
+        ):
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def _forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward().
 
@@ -1045,6 +1301,9 @@ class MRotaryEmbedding(RotaryEmbedding):
             query: [num_tokens, num_heads * head_size]
             key: [num_tokens, num_kv_heads * head_size]
         """
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "save kv cache is not supported for MRotaryEmbedding."
         assert positions.ndim == 1 or positions.ndim == 2
 
         num_tokens = positions.shape[-1]
@@ -1052,15 +1311,18 @@ class MRotaryEmbedding(RotaryEmbedding):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if positions.ndim == 2:
             assert self.mrope_section
-
-            cos = torch.cat(
-                [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
-                dim=-1,
-            )
-            sin = torch.cat(
-                [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
-                dim=-1,
-            )
+            if self.mrope_interleaved:
+                cos = apply_interleaved_rope(cos, self.mrope_section)
+                sin = apply_interleaved_rope(sin, self.mrope_section)
+            else:
+                cos = torch.cat(
+                    [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
+                    dim=-1,
+                )
+                sin = torch.cat(
+                    [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
+                    dim=-1,
+                )
 
         query_shape = query.shape
         query = query.view(num_tokens, -1, self.head_size)
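Aside (not part of the package diff): a minimal sketch of the layout the interleaved branch above produces, using a toy mrope_section = [2, 1, 1] so half the rotary dim is four slots; the numbers are labels, not real frequencies.

import torch

def apply_interleaved_rope(x, mrope_section):
    # Same body as the function added in this diff.
    x_t = x[0].clone()
    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
    return x_t

# One token, four frequency slots per axis: 1x = temporal, 2x = height, 3x = width.
x = torch.tensor(
    [
        [[10.0, 11.0, 12.0, 13.0]],  # T channel
        [[20.0, 21.0, 22.0, 23.0]],  # H channel
        [[30.0, 31.0, 32.0, 33.0]],  # W channel
    ]
)
print(apply_interleaved_rope(x, [2, 1, 1]))  # tensor([[10., 21., 32., 13.]])
# i.e. T, H, W, T interleaved; the chunked torch.cat path instead yields 10., 11., 22., 33.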
@@ -1077,6 +1339,72 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass with optional Triton kernel acceleration.
+        Args:
+            positions:
+                [num_tokens,] (text only) or
+                [3, num_tokens] (T/H/W positions with multimodal inputs)
+            query: [num_tokens, num_heads * head_size]
+            key: [num_tokens, num_kv_heads * head_size]
+        """
+        assert positions.ndim == 1 or positions.ndim == 2
+
+        if positions.ndim == 2 and self.mrope_section and _is_cuda:
+            return self._forward_triton(positions, query, key)
+        else:
+            return self._forward_native(positions, query, key)
+
+    def _forward_triton(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert positions.ndim == 1 or positions.ndim == 2
+        assert key is not None
+
+        self._match_cos_sin_cache_dtype(query)
+        num_tokens = positions.shape[-1]
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        query_shape = query.shape
+        key_shape = key.shape
+        if positions.ndim == 2:
+            assert self.mrope_section
+
+            q, k = triton_mrope(
+                query,
+                key,
+                cos,
+                sin,
+                self.mrope_section,
+                self.head_size,
+                self.rotary_dim,
+                self.mrope_interleaved,
+            )
+
+            return q.reshape(query_shape), k.reshape(key_shape)
+
+        query = query.view(num_tokens, -1, self.head_size)
+        query_rot = query[..., : self.rotary_dim]
+        query_pass = query[..., self.rotary_dim :]
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        key = key.view(num_tokens, -1, self.head_size)
+        key_rot = key[..., : self.rotary_dim]
+        key_pass = key[..., self.rotary_dim :]
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
+
     # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
     @staticmethod
     def get_rope_index(
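Aside (illustrative only, not from the diff): one way the new forward dispatch above can be exercised. The constructor arguments mirror the signature shown in this diff; the concrete sizes are made up.

import torch
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

head_size, num_heads, num_kv_heads, num_tokens = 128, 8, 2, 16
rope = MRotaryEmbedding(
    head_size=head_size,
    rotary_dim=head_size,
    max_position_embeddings=4096,
    base=10000,
    is_neox_style=True,
    dtype=torch.bfloat16,
    mrope_section=[16, 24, 24],  # must sum to rotary_dim // 2
    mrope_interleaved=False,
)
positions = torch.zeros(3, num_tokens, dtype=torch.long)  # T/H/W position rows
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.bfloat16)
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.bfloat16)
# With 2-D positions on CUDA this takes the _forward_triton path; otherwise it
# falls back to the torch.compile'd _forward_native path.
query, key = rope.forward(positions, query, key)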
@@ -1092,6 +1420,28 @@ class MRotaryEmbedding(RotaryEmbedding):
         second_per_grid_ts: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if model_type == "qwen3_omni_moe":
+            # For qwen3-omni
+            return MRotaryEmbedding.get_rope_index_qwen3_omni(
+                spatial_merge_size,
+                image_token_id,
+                video_token_id,
+                vision_start_token_id,
+                tokens_per_second,
+                input_ids,
+                image_grid_thw,
+                video_grid_thw,
+                second_per_grid_ts,
+                **kwargs,
+            )
+        if (
+            model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
+        ) and video_grid_thw is not None:
+            video_grid_thw = torch.repeat_interleave(
+                video_grid_thw, video_grid_thw[:, 0], dim=0
+            )
+            video_grid_thw[:, 0] = 1
+
         mrope_position_deltas = []
         if input_ids is not None and (
             image_grid_thw is not None or video_grid_thw is not None
@@ -1177,7 +1527,11 @@ class MRotaryEmbedding(RotaryEmbedding):
 
                         time_tensor_long = time_tensor.long()
                         t_index = time_tensor_long.flatten()
-                    elif model_type == "qwen2_vl":
+                    elif model_type in (
+                        "qwen2_vl",
+                        "qwen3_vl",
+                        "qwen3_vl_moe",
+                    ):
                         t_index = (
                             torch.arange(llm_grid_t)
                             .view(-1, 1)
@@ -1185,7 +1539,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                             .flatten()
                         )
                     else:
-                        raise RuntimeError("Unimplemented")
+                        raise RuntimeError(f"Unimplemented model type: {model_type}")
                     h_index = (
                         torch.arange(llm_grid_h)
                         .view(1, -1, 1)
@@ -1235,6 +1589,304 @@ class MRotaryEmbedding(RotaryEmbedding):
         mrope_position_deltas = max_position_ids + 1 - s
         return position_ids, mrope_position_deltas
 
+    @staticmethod
+    def get_rope_index_qwen3_omni(
+        spatial_merge_size: int,
+        image_token_id: int,
+        video_token_id: int,
+        vision_start_token_id: int,
+        tokens_per_second: Optional[int] = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # For qwen3-omni
+        audio_token_id = kwargs["audio_token_id"]
+        audio_start_token_id = kwargs["audio_start_token_id"]
+        position_id_per_seconds = kwargs["position_id_per_seconds"]
+        use_audio_in_video = kwargs.get("use_audio_in_video", False)
+        audio_seqlens = kwargs.get("audio_seqlens", None)
+        second_per_grids = second_per_grid_ts
+
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.zeros(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=torch.float,
+                device=input_ids.device,
+            )
+            image_idx, video_idx, audio_idx = 0, 0, 0
+            for i, current_input_ids in enumerate(total_input_ids):
+                image_nums, video_nums, audio_nums = 0, 0, 0
+                vision_start_indices = torch.argwhere(
+                    current_input_ids == vision_start_token_id
+                ).squeeze(1)
+                if vision_start_indices.numel() > 0:
+                    vision_tokens = current_input_ids[vision_start_indices + 1]
+                    image_nums = (vision_tokens == image_token_id).sum()
+                    video_nums = (
+                        (vision_tokens == audio_start_token_id).sum()
+                        if use_audio_in_video
+                        else (vision_tokens == video_token_id).sum()
+                    )
+                audio_nums = torch.sum(current_input_ids == audio_start_token_id)
+                input_tokens = current_input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos, remain_audios = (
+                    image_nums,
+                    video_nums,
+                    audio_nums,
+                )
+                multimodal_nums = (
+                    image_nums + audio_nums
+                    if use_audio_in_video
+                    else image_nums + video_nums + audio_nums
+                )
+                for _ in range(multimodal_nums):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    ed_vision_start = (
+                        input_tokens.index(vision_start_token_id, st)
+                        if (
+                            (
+                                image_token_id in input_tokens
+                                or video_token_id in input_tokens
+                            )
+                            and (remain_videos > 0 or remain_images > 0)
+                        )
+                        else len(input_tokens) + 1
+                    )
+                    ed_audio_start = (
+                        input_tokens.index(audio_start_token_id, st)
+                        if (audio_token_id in input_tokens and remain_audios > 0)
+                        else len(input_tokens) + 1
+                    )
+                    min_ed = min(ed_vision_start, ed_audio_start)
+
+                    text_len = min_ed - st
+                    if text_len != 0:
+                        llm_pos_ids_list.append(
+                            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                        )
+                        st_idx += text_len
+                    # Audio in Video
+                    if (
+                        min_ed == ed_vision_start
+                        and ed_vision_start + 1 == ed_audio_start
+                    ):
+                        bos_len, eos_len = 2, 2
+                    else:
+                        bos_len, eos_len = 1, 1
+                    llm_pos_ids_list.append(
+                        torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+                    st_idx += bos_len
+                    # Audio Only
+                    if min_ed == ed_audio_start:
+                        audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
+                            audio_seqlens[audio_idx]
+                        )
+                        llm_pos_ids = (
+                            torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
+                        )
+                        llm_pos_ids_list.append(llm_pos_ids)
+
+                        st += int(text_len + bos_len + audio_len + eos_len)
+                        audio_idx += 1
+                        remain_audios -= 1
+
+                    # Image Only
+                    elif (
+                        min_ed == ed_vision_start
+                        and current_input_ids[ed_vision_start + 1] == image_token_id
+                    ):
+                        grid_t = image_grid_thw[image_idx][0]
+                        grid_hs = image_grid_thw[:, 1]
+                        grid_ws = image_grid_thw[:, 2]
+                        t_index = (
+                            torch.arange(grid_t) * 1 * position_id_per_seconds
+                        ).float()
+                        llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
+                            st_idx,
+                            image_idx,
+                            spatial_merge_size,
+                            t_index,
+                            grid_hs,
+                            grid_ws,
+                            input_ids.device,
+                        )
+                        image_len = image_grid_thw[image_idx].prod() // (
+                            spatial_merge_size**2
+                        )
+                        llm_pos_ids_list.append(llm_pos_ids)
+
+                        st += int(text_len + bos_len + image_len + eos_len)
+                        image_idx += 1
+                        remain_images -= 1
+
+                    # Video Only
+                    elif (
+                        min_ed == ed_vision_start
+                        and current_input_ids[ed_vision_start + 1] == video_token_id
+                    ):
+                        grid_t = video_grid_thw[video_idx][0]
+                        grid_hs = video_grid_thw[:, 1]
+                        grid_ws = video_grid_thw[:, 2]
+                        t_index = (
+                            torch.arange(grid_t)
+                            * second_per_grids[video_idx].cpu().float()
+                            * position_id_per_seconds
+                        ).float()
+                        llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
+                            st_idx,
+                            video_idx,
+                            spatial_merge_size,
+                            t_index,
+                            grid_hs,
+                            grid_ws,
+                            input_ids.device,
+                        )
+                        video_len = video_grid_thw[video_idx].prod() // (
+                            spatial_merge_size**2
+                        )
+                        llm_pos_ids_list.append(llm_pos_ids)
+
+                        st += int(text_len + bos_len + video_len + eos_len)
+                        video_idx += 1
+                        remain_videos -= 1
+
+                    # Audio in Video
+                    elif (
+                        min_ed == ed_vision_start
+                        and ed_vision_start + 1 == ed_audio_start
+                    ):
+                        audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
+                            audio_seqlens[audio_idx]
+                        )
+                        audio_llm_pos_ids = (
+                            torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
+                        )
+                        grid_t = video_grid_thw[video_idx][0]
+                        grid_hs = video_grid_thw[:, 1]
+                        grid_ws = video_grid_thw[:, 2]
+
+                        t_index = (
+                            torch.arange(grid_t)
+                            * second_per_grids[video_idx].cpu().float()
+                            * position_id_per_seconds
+                        ).float()
+                        video_llm_pos_ids = (
+                            MRotaryEmbedding._get_llm_pos_ids_for_vision(
+                                st_idx,
+                                video_idx,
+                                spatial_merge_size,
+                                t_index,
+                                grid_hs,
+                                grid_ws,
+                                input_ids.device,
+                            )
+                        )
+                        video_data_index, audio_data_index = 0, 0
+                        while (
+                            video_data_index < video_llm_pos_ids.shape[-1]
+                            and audio_data_index < audio_llm_pos_ids.shape[-1]
+                        ):
+                            if (
+                                video_llm_pos_ids[0][video_data_index]
+                                <= audio_llm_pos_ids[0][audio_data_index]
+                            ):
+                                llm_pos_ids_list.append(
+                                    video_llm_pos_ids[
+                                        :, video_data_index : video_data_index + 1
+                                    ]
+                                )
+                                video_data_index += 1
+                            else:
+                                llm_pos_ids_list.append(
+                                    audio_llm_pos_ids[
+                                        :, audio_data_index : audio_data_index + 1
+                                    ]
+                                )
+                                audio_data_index += 1
+                        if video_data_index < video_llm_pos_ids.shape[-1]:
+                            llm_pos_ids_list.append(
+                                video_llm_pos_ids[
+                                    :, video_data_index : video_llm_pos_ids.shape[-1]
+                                ]
+                            )
+                        if audio_data_index < audio_llm_pos_ids.shape[-1]:
+                            llm_pos_ids_list.append(
+                                audio_llm_pos_ids[
+                                    :, audio_data_index : audio_llm_pos_ids.shape[-1]
+                                ]
+                            )
+                        video_len = video_grid_thw[video_idx].prod() // (
+                            spatial_merge_size**2
+                        )
+
+                        st += int(text_len + bos_len + audio_len + video_len + eos_len)
+
+                        audio_idx += 1
+                        video_idx += 1
+                        remain_videos -= 1
+                        remain_audios -= 1
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(
+                    [item.float() for item in llm_pos_ids_list], dim=1
+                ).reshape(3, -1)
+
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(current_input_ids)
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
+            )
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+
+            return position_ids, mrope_position_deltas
+
     # Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
     @staticmethod
     def get_rope_index_glm4v(
@@ -1433,6 +2085,44 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         return position_ids, mrope_position_deltas
 
+    # For qwen3-omni
+    @staticmethod
+    def _get_feat_extract_output_lengths(input_lengths):
+        """
+        Computes the output length of the convolutional layers and the output length of the audio encoder
+        """
+        input_lengths_leave = input_lengths % 100
+        feat_lengths = (input_lengths_leave - 1) // 2 + 1
+        output_lengths = (
+            ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
+        )
+        return output_lengths
+
+    # For qwen3-omni
+    @staticmethod
+    def _get_llm_pos_ids_for_vision(
+        st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
+    ):
+        grid_h = grid_hs[vision_idx] // spatial_merge_size
+        grid_w = grid_ws[vision_idx] // spatial_merge_size
+
+        h_index = (
+            torch.arange(grid_h, device=device)
+            .view(1, -1, 1)
+            .expand(len(t_index), -1, grid_w)
+            .flatten()
+        )
+        w_index = (
+            torch.arange(grid_w, device=device)
+            .view(1, 1, -1)
+            .expand(len(t_index), grid_h, -1)
+            .flatten()
+        )
+        t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
+
+        llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx
+        return llm_pos_ids
+
 
 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
@@ -1734,6 +2424,7 @@ def get_rope(
                 is_neox_style,
                 dtype,
                 mrope_section=rope_scaling["mrope_section"],
+                mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
             )
         else:
             rotary_emb = RotaryEmbedding(
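Aside (an assumed config fragment, not from the diff): the new flag is read from the model's rope_scaling dict, so enabling the interleaved layout looks roughly like this; any other keys a real config carries are omitted here.

rope_scaling = {
    "mrope_section": [16, 24, 24],  # read via rope_scaling["mrope_section"]
    "mrope_interleaved": True,      # optional; get_rope() defaults it to False
}
assert rope_scaling.get("mrope_interleaved", False) is True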
@@ -1888,17 +2579,30 @@
     sin: torch.Tensor,
     unsqueeze_dim=1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if q.shape[1] != 128:
+    """Ascend implementation equivalent to apply_rotary_pos_emb_native.
+
+    Args:
+        q: [num_tokens, num_heads, head_size]
+        k: [num_tokens, num_kv_heads, head_size]
+        cos: [num_tokens, head_size]
+        sin: [num_tokens, head_size]
+    """
+    if (
+        cos.dim() != 2
+        or q.dim() != 3
+        or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
+        or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
+    ):
+        # Note: num_heads and head_size of q must be less than 1000 and 896, respectively
         return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
-    cos = cos.unsqueeze(unsqueeze_dim)
-    cos = torch.transpose(cos, 1, 2)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    sin = torch.transpose(sin, 1, 2)
-    q = torch.transpose(q, 1, 2)
-    k = torch.transpose(k, 1, 2)
-    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
-    q_embed = torch.transpose(q_embed, 1, 2)
-    k_embed = torch.transpose(k_embed, 1, 2)
+    cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    q = q.unsqueeze(0)
+    k = k.unsqueeze(0)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    q_embed = q_embed.squeeze(0)
+    k_embed = k_embed.squeeze(0)
     return q_embed, k_embed
 