sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -7,8 +7,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union
7
7
 
8
8
  import torch
9
9
  import torch.nn as nn
10
+ import triton
11
+ import triton.language as tl
10
12
 
11
13
  from sglang.srt.custom_op import CustomOp
14
+ from sglang.srt.server_args import get_global_server_args
12
15
  from sglang.srt.utils import (
13
16
  cpu_has_amx_support,
14
17
  get_bool_env_var,
@@ -17,6 +20,7 @@ from sglang.srt.utils import (
17
20
  is_cuda,
18
21
  is_hip,
19
22
  is_npu,
23
+ is_xpu,
20
24
  )
21
25
 
22
26
  _is_cuda = is_cuda()
@@ -25,6 +29,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
25
29
  _is_npu = is_npu()
26
30
  _is_cpu_amx_available = cpu_has_amx_support()
27
31
  _is_cpu = is_cpu()
32
+ _is_xpu = is_xpu()
28
33
 
29
34
  if _is_cuda:
30
35
  from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
@@ -108,9 +113,11 @@ class RotaryEmbedding(CustomOp):
108
113
  if not _is_cuda:
109
114
  cache = cache.to(dtype)
110
115
 
111
- if (
112
- not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
113
- ) and not (_is_cpu and _is_cpu_amx_available):
116
+ if dtype == torch.float32 or (
117
+ (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
118
+ and not (_is_cpu and _is_cpu_amx_available)
119
+ and not (_is_xpu)
120
+ ):
114
121
  from vllm._custom_ops import rotary_embedding
115
122
 
116
123
  self.vllm_rotary_embedding = rotary_embedding
@@ -118,18 +125,29 @@ class RotaryEmbedding(CustomOp):
118
125
  self.cos_sin_cache: torch.Tensor
119
126
  self.register_buffer("cos_sin_cache", cache, persistent=False)
120
127
 
128
+ if get_global_server_args().rl_on_policy_target == "fsdp":
129
+ self._forward_method = self.forward_native
130
+
121
131
  def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
122
132
  """Compute the inverse frequency."""
123
133
  # NOTE(woosuk): To exactly match the HF implementation, we need to
124
134
  # use CPU to compute the cache and then move it to GPU. However, we
125
135
  # create the cache on GPU for faster initialization. This may cause
126
136
  # a slight numerical difference between the HF implementation and ours.
137
+ init_device = (
138
+ "cpu" if get_global_server_args().rl_on_policy_target == "fsdp" else None
139
+ )
127
140
  inv_freq = 1.0 / (
128
141
  base
129
142
  ** (
130
- torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
143
+ torch.arange(
144
+ 0, self.rotary_dim, 2, dtype=torch.float, device=init_device
145
+ )
146
+ / self.rotary_dim
131
147
  )
132
148
  )
149
+ if get_global_server_args().rl_on_policy_target == "fsdp":
150
+ inv_freq = inv_freq.cuda()
133
151
  return inv_freq
134
152
 
135
153
  def _compute_cos_sin_cache(self) -> torch.Tensor:
@@ -248,7 +266,11 @@ class RotaryEmbedding(CustomOp):
248
266
  offsets: Optional[torch.Tensor] = None,
249
267
  fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
250
268
  ) -> Tuple[torch.Tensor, torch.Tensor]:
251
- if _is_cuda and (self.head_size in [64, 128, 256, 512]):
269
+ if (
270
+ _is_cuda
271
+ and (self.head_size in [64, 128, 256, 512])
272
+ and self.dtype != torch.float32
273
+ ):
252
274
  apply_rope_with_cos_sin_cache_inplace(
253
275
  positions=positions,
254
276
  query=query,
@@ -284,6 +306,17 @@ class RotaryEmbedding(CustomOp):
284
306
  s += f", base={self.base}, is_neox_style={self.is_neox_style}"
285
307
  return s
286
308
 
309
+ def forward_xpu(
310
+ self,
311
+ positions: torch.Tensor,
312
+ query: torch.Tensor,
313
+ key: torch.Tensor,
314
+ offsets: Optional[torch.Tensor] = None,
315
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
316
+ # TODO: make a wrapper, and XPU will implement this kernel later.
317
+ self.cos_sin_cache = self.cos_sin_cache.to(query.device)
318
+ return self.forward_native(positions, query, key, offsets)
319
+
287
320
 
288
321
  class LinearScalingRotaryEmbedding(RotaryEmbedding):
289
322
  """RotaryEmbedding extended with linear scaling.
@@ -1008,6 +1041,199 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
1008
1041
  return cache
1009
1042
 
1010
1043
 
1044
def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
    """Apply interleaved MRoPE to 3D rotary embeddings.

    Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
    interleaved [THWTHW...TT]: the T/H/W tables take turns channel by
    channel, and the trailing channels (beyond the H/W runs) come from T.

    Args:
        x: tensor whose leading dim indexes the T (0), H (1) and W (2)
            tables; the channel dim is the last one.
        mrope_section: [t, h, w] channel counts. ``mrope_section[1]`` H
            channels are taken at channel indices 1, 4, 7, ... and
            ``mrope_section[2]`` W channels at indices 2, 5, 8, ...; every
            other channel comes from the T table. ``mrope_section[0]`` is
            not read.

    Returns:
        A new tensor shaped like ``x[0]``. ``x`` itself is not modified.
    """
    # Start from a copy of the T table so the input is never written to.
    x_t = x[0].clone()
    x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
    x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
    return x_t
1053
+
1054
+
1055
@triton.jit
def _triton_mrope_forward(
    q_ptr,
    k_ptr,
    cos,
    sin,
    num_tokens,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    rd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    mrope_section_w: tl.constexpr,
    is_interleaved: tl.constexpr,
):
    # Rotate all q heads and all k heads of ONE token in place with 3-D
    # (T/H/W) multimodal RoPE. ``tl.program_id(0)`` selects the token row of
    # q/k and of the cos/sin tables.
    #
    # Parameters (as used by the code below):
    #   q_ptr / k_ptr: flattened q / k holding ``n_qh * hd`` / ``n_kh * hd``
    #       elements per token; rotated values are stored back through them.
    #   cos / sin: three consecutive ``num_tokens x (rd // 2)`` planes
    #       (T plane, then H plane, then W plane).
    #   hd / rd: head size and number of rotated channels per head; channels
    #       at index >= rd are never loaded or stored.
    #   pad_n_qh / pad_n_kh / pad_hd: block sizes for tl.arange (the launcher
    #       passes next-power-of-two values); masks clip work to the real
    #       n_qh / n_kh / rd // 2 extents.
    #   mrope_section_{t,h,w}: how many of the rd // 2 frequency channels are
    #       taken from the T / H / W plane (mrope_section_t is only read in
    #       the non-interleaved branch).
    #   is_interleaved: channel c comes from H if c % 3 == 1, from W if
    #       c % 3 == 2 (within their section limits), otherwise from T;
    #       when False, the T, H and W channels form contiguous chunks.
    #
    # Adapted from
    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
    # This version supports flattened input tensors from vllm
    # and supports cos and sin caches with shape (3, num_tokens, rd // 2)
    # instead of (3, bsz, seq_len, head_dim); it also supports interleaved rotary.
    pid = tl.program_id(0)
    # locate start address of this token's q and k rows
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################
    # Note: cos and sin have shape (3, num_tokens, rd // 2); each of the
    # T/H/W planes is num_tokens * (rd // 2) elements long.

    # Per-plane row pointers for this token (row stride is rd // 2).
    half_rd = rd // 2
    t_cos = cos + pid * half_rd
    h_cos = t_cos + num_tokens * half_rd
    w_cos = h_cos + num_tokens * half_rd
    t_sin = sin + pid * half_rd
    h_sin = t_sin + num_tokens * half_rd
    w_sin = h_sin + num_tokens * half_rd

    # Channel offsets within the half-rotary range (padded to pad_hd // 2).
    cos_offsets = tl.arange(0, pad_hd // 2)
    if is_interleaved:
        # ``<= 3 * section`` behaves like ``<`` here: 3 * section is a
        # multiple of 3, so it can never satisfy the % 3 == 1 / 2 test.
        h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
        w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
        # NOTE: t_mask is not clipped to half_rd. Values loaded past half_rd
        # only multiply q/k elements that the masked loads below set to 0,
        # and those positions are never stored.
        t_mask = ~(h_mask | w_mask)
    else:
        # Contiguous chunks: [0, t) -> T, [t, t + h) -> H, [t + h, rd // 2) -> W.
        t_end = mrope_section_t
        h_end = t_end + mrope_section_h
        t_mask = cos_offsets < mrope_section_t
        h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
        w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)

    # The three masks are disjoint, so summing the masked rows (other=0)
    # selects, per channel, the value from the plane that owns it.
    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)

    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head: channels [0, rd // 2) of every head
    first_half_q_offsets = (
        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_half_k_offsets = (
        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
    )

    # Tiles are cast to the cos/sin dtype, so the rotation math runs in it.
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
        sin_row.dtype
    )

    # right half of the head: channels [rd // 2, rd) of every head
    second_half_q_offsets = first_half_q_offsets + (rd // 2)
    second_half_k_offsets = first_half_k_offsets + (rd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask

    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
        sin_row.dtype
    )

    # Rotate-half (NeoX-style) rotation:
    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    # Since cos and sin are half-size,
    # we use the same cos_row and sin_row for both halves
    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
1172
+
1173
+
1174
def triton_mrope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section: list[int],
    head_size: int,
    rotary_dim: int,
    mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply multimodal (T/H/W) rotary embedding to ``q`` and ``k`` with Triton.

    The kernel uses the rotate-half layout. Within each head, channel ``j``
    and channel ``j + rotary_dim // 2`` (for ``j < rotary_dim // 2``) form a
    rotated pair, and channels ``>= rotary_dim`` are left untouched.

    Args:
        q: [num_tokens, num_heads * head_size]
        k: [num_tokens, num_kv_heads * head_size]
        cos: [3, num_tokens, rotary_dim // 2]
            (T/H/W positions with multimodal inputs). The last dim is
            ``rotary_dim // 2``, not ``head_size // 2``: the kernel strides
            through each T/H/W plane in steps of ``rotary_dim // 2``.
        sin: same layout as ``cos``.
        mrope_section: [t, h, w], the number of the ``rotary_dim // 2``
            frequency channels taken from the T/H/W tables respectively
            (``t`` is only used when ``mrope_interleaved`` is False).
        head_size: size of one attention head.
        rotary_dim: number of leading channels of each head that are rotated.
        mrope_interleaved: if True, assign channels to T/H/W in the
            interleaved THWTHW... pattern instead of contiguous chunks.

    Returns:
        ``(q, k)`` after rotation. These are ``q.contiguous()`` and
        ``k.contiguous()``. When the inputs are already contiguous they are
        rotated in place and returned as-is; otherwise rotated copies are
        returned and the inputs are left unchanged.

    Raises:
        AssertionError: if the second dim of ``q`` or ``k`` is not a multiple
            of ``head_size``.
    """
    n_row, n_q_head_head_dim = q.shape
    assert (
        n_q_head_head_dim % head_size == 0
    ), f"q shape {n_q_head_head_dim} must be divisible by head_size {head_size}"
    n_q_head = n_q_head_head_dim // head_size
    assert (
        k.shape[1] % head_size == 0
    ), f"k shape {k.shape[1]} must be divisible by head_size {head_size}"
    n_kv_head = k.shape[1] // head_size
    # tl.arange needs power-of-two extents; the kernel masks off the padding.
    pad_hd = triton.next_power_of_2(head_size)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    # ensure tensors passed into the kernel are contiguous.
    # It will be no-op if they are already contiguous
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    # One program instance per token row.
    _triton_mrope_forward[(n_row,)](
        q,
        k,
        cos,
        sin,
        n_row,
        n_q_head,
        n_kv_head,
        head_size,
        rotary_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
        mrope_section[2],
        mrope_interleaved,
    )
    return q, k
1235
+
1236
+
1011
1237
  class MRotaryEmbedding(RotaryEmbedding):
1012
1238
  """Rotary Embedding with Multimodal Sections."""
1013
1239
 
@@ -1020,12 +1246,14 @@ class MRotaryEmbedding(RotaryEmbedding):
1020
1246
  is_neox_style: bool,
1021
1247
  dtype: torch.dtype,
1022
1248
  mrope_section: Optional[List[int]] = None,
1249
+ mrope_interleaved: bool = False,
1023
1250
  ) -> None:
1024
1251
  super().__init__(
1025
1252
  head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
1026
1253
  )
1027
1254
 
1028
1255
  self.mrope_section = mrope_section
1256
+ self.mrope_interleaved = mrope_interleaved
1029
1257
  if self.mrope_section:
1030
1258
  expected_sum = rotary_dim // 2
1031
1259
  actual_sum = sum(self.mrope_section)
@@ -1059,8 +1287,17 @@ class MRotaryEmbedding(RotaryEmbedding):
1059
1287
  f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
1060
1288
  )
1061
1289
 
1290
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
    """Bring ``self.cos_sin_cache`` onto ``query``'s device and dtype.

    Rebinding the attribute goes through ``nn.Module.__setattr__``, which is
    expensive, so nothing is assigned when the cache already matches.

    Args:
        query: tensor whose ``device`` and ``dtype`` the cache should follow.
    """
    cache = self.cos_sin_cache
    already_matching = cache.device == query.device and cache.dtype == query.dtype
    if already_matching:
        return
    self.cos_sin_cache = cache.to(query.device, dtype=query.dtype)
1298
+
1062
1299
  @torch.compile(dynamic=True, backend=get_compiler_backend())
1063
- def forward(
1300
+ def _forward_native(
1064
1301
  self,
1065
1302
  positions: torch.Tensor,
1066
1303
  query: torch.Tensor,
@@ -1086,15 +1323,18 @@ class MRotaryEmbedding(RotaryEmbedding):
1086
1323
  cos, sin = cos_sin.chunk(2, dim=-1)
1087
1324
  if positions.ndim == 2:
1088
1325
  assert self.mrope_section
1089
-
1090
- cos = torch.cat(
1091
- [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
1092
- dim=-1,
1093
- )
1094
- sin = torch.cat(
1095
- [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
1096
- dim=-1,
1097
- )
1326
+ if self.mrope_interleaved:
1327
+ cos = apply_interleaved_rope(cos, self.mrope_section)
1328
+ sin = apply_interleaved_rope(sin, self.mrope_section)
1329
+ else:
1330
+ cos = torch.cat(
1331
+ [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
1332
+ dim=-1,
1333
+ )
1334
+ sin = torch.cat(
1335
+ [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
1336
+ dim=-1,
1337
+ )
1098
1338
 
1099
1339
  query_shape = query.shape
1100
1340
  query = query.view(num_tokens, -1, self.head_size)
@@ -1111,6 +1351,72 @@ class MRotaryEmbedding(RotaryEmbedding):
1111
1351
  key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
1112
1352
  return query, key
1113
1353
 
1354
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    fused_set_kv_buffer_arg: Optional["FusedSetKVBufferArg"] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward pass with optional Triton kernel acceleration.

    The Triton path (``_forward_triton``) is taken only when all of these
    hold:

    * the positions are multimodal (``positions.ndim == 2``);
    * ``self.mrope_section`` is set;
    * ``self.is_neox_style`` is true;
    * the platform is CUDA.

    Every other case goes to ``_forward_native``.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
        fused_set_kv_buffer_arg: not consumed by this method; it is neither
            forwarded nor applied. NOTE(review): confirm no caller relies on
            it being honoured for M-RoPE.

    Returns:
        The rotated ``(query, key)``.
    """
    assert positions.ndim == 1 or positions.ndim == 2

    # ``triton_mrope`` only implements the rotate-half (NeoX) layout: it pairs
    # channel j with channel j + rotary_dim // 2. GPT-J style rotary
    # (is_neox_style=False) must therefore not take the Triton path.
    use_triton = (
        positions.ndim == 2
        and self.mrope_section
        and self.is_neox_style
        and _is_cuda
    )
    if use_triton:
        return self._forward_triton(positions, query, key)
    return self._forward_native(positions, query, key)
1375
+
1376
def _forward_triton(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``query``/``key`` using the Triton M-RoPE kernel when possible.

    The cos/sin cache is first aligned with ``query``'s device and dtype.

    * 2-D ``positions`` (T/H/W rows): the rotation is delegated to
      ``triton_mrope``.
    * 1-D ``positions``: the leading ``rotary_dim`` channels of every head
      are rotated with ``_apply_rotary_emb``, and the remaining channels
      pass through unchanged.

    Returns:
        ``(query, key)`` reshaped back to their input shapes.
    """
    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None

    self._match_cos_sin_cache_dtype(query)
    token_count = positions.shape[-1]
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
    q_shape, k_shape = query.shape, key.shape

    if positions.ndim != 2:
        rot = self.rotary_dim
        rotated = []
        # Query first, then key, exactly like the unrolled version.
        for tensor, shape in ((query, q_shape), (key, k_shape)):
            heads = tensor.view(token_count, -1, self.head_size)
            head_rot = _apply_rotary_emb(
                heads[..., :rot], cos, sin, self.is_neox_style
            )
            rotated.append(
                torch.cat((head_rot, heads[..., rot:]), dim=-1).reshape(shape)
            )
        return rotated[0], rotated[1]

    assert self.mrope_section
    q_out, k_out = triton_mrope(
        query,
        key,
        cos,
        sin,
        self.mrope_section,
        self.head_size,
        self.rotary_dim,
        self.mrope_interleaved,
    )
    return q_out.reshape(q_shape), k_out.reshape(k_shape)
1419
+
1114
1420
  # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
1115
1421
  @staticmethod
1116
1422
  def get_rope_index(
@@ -1126,6 +1432,28 @@ class MRotaryEmbedding(RotaryEmbedding):
1126
1432
  second_per_grid_ts: Optional[torch.Tensor] = None,
1127
1433
  **kwargs,
1128
1434
  ) -> Tuple[torch.Tensor, torch.Tensor]:
1435
+ if model_type == "qwen3_omni_moe":
1436
+ # For qwen3-omni
1437
+ return MRotaryEmbedding.get_rope_index_qwen3_omni(
1438
+ spatial_merge_size,
1439
+ image_token_id,
1440
+ video_token_id,
1441
+ vision_start_token_id,
1442
+ tokens_per_second,
1443
+ input_ids,
1444
+ image_grid_thw,
1445
+ video_grid_thw,
1446
+ second_per_grid_ts,
1447
+ **kwargs,
1448
+ )
1449
+ if (
1450
+ model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
1451
+ ) and video_grid_thw is not None:
1452
+ video_grid_thw = torch.repeat_interleave(
1453
+ video_grid_thw, video_grid_thw[:, 0], dim=0
1454
+ )
1455
+ video_grid_thw[:, 0] = 1
1456
+
1129
1457
  mrope_position_deltas = []
1130
1458
  if input_ids is not None and (
1131
1459
  image_grid_thw is not None or video_grid_thw is not None
@@ -1211,7 +1539,11 @@ class MRotaryEmbedding(RotaryEmbedding):
1211
1539
 
1212
1540
  time_tensor_long = time_tensor.long()
1213
1541
  t_index = time_tensor_long.flatten()
1214
- elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
1542
+ elif model_type in (
1543
+ "qwen2_vl",
1544
+ "qwen3_vl",
1545
+ "qwen3_vl_moe",
1546
+ ):
1215
1547
  t_index = (
1216
1548
  torch.arange(llm_grid_t)
1217
1549
  .view(-1, 1)
@@ -1219,7 +1551,7 @@ class MRotaryEmbedding(RotaryEmbedding):
1219
1551
  .flatten()
1220
1552
  )
1221
1553
  else:
1222
- raise RuntimeError("Unimplemented")
1554
+ raise RuntimeError(f"Unimplemented model type: {model_type}")
1223
1555
  h_index = (
1224
1556
  torch.arange(llm_grid_h)
1225
1557
  .view(1, -1, 1)
@@ -1269,6 +1601,304 @@ class MRotaryEmbedding(RotaryEmbedding):
1269
1601
  mrope_position_deltas = max_position_ids + 1 - s
1270
1602
  return position_ids, mrope_position_deltas
1271
1603
 
1604
@staticmethod
def get_rope_index_qwen3_omni(
    spatial_merge_size: int,
    image_token_id: int,
    video_token_id: int,
    vision_start_token_id: int,
    tokens_per_second: Optional[int] = None,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute 3-D (T/H/W) M-RoPE position ids for Qwen3-Omni prompts.

    Each row of ``input_ids`` is walked left to right, and every span gets
    position ids:

    * text: consecutive ids, identical on all three axes;
    * image / video: ids built by ``_get_llm_pos_ids_for_vision``. The T
      index is scaled by ``position_id_per_seconds`` and, for video, also by
      ``second_per_grid_ts[video_idx]``;
    * audio: consecutive ids; their count is
      ``_get_feat_extract_output_lengths(audio_seqlens[audio_idx])``;
    * audio in video (a vision-start token immediately followed by an
      audio-start token): the video and audio ids are merged column by
      column in ascending order of their axis-0 value.

    Each multimodal span is surrounded by ``bos_len`` / ``eos_len``
    positions: 2 each for audio-in-video, otherwise 1 each.

    Args:
        spatial_merge_size: divisor applied to the H and W grid sizes.
        image_token_id / video_token_id / vision_start_token_id: special
            token ids used to locate vision spans.
        tokens_per_second: not read by this function.
        input_ids: 2-D; row ``i`` is one sequence. Must not be None, even in
            the no-grid fallback branch (``input_ids.shape[1]`` is read).
        image_grid_thw / video_grid_thw: per-item (t, h, w) grids. The
            running indices are not reset per row, so items are consumed in
            order of appearance across the whole batch.
        second_per_grid_ts: indexed per video; multiplies the video T index
            (presumably seconds per temporal grid step; verify with callers).
        **kwargs: must contain ``audio_token_id``, ``audio_start_token_id``
            and ``position_id_per_seconds``. May contain
            ``use_audio_in_video`` (default False) and ``audio_seqlens``
            (default None, but indexed whenever the prompt contains audio).

    Returns:
        ``(position_ids, mrope_position_deltas)``.

        With ``input_ids`` and at least one grid: ``position_ids`` is a float
        tensor of shape ``(3, batch, seq_len)``, and the deltas have shape
        ``(batch, 1)`` with value ``max position + 1 - row length``.

        Otherwise: ``position_ids`` is ``arange(seq_len)`` broadcast to
        ``(3, 1, seq_len)`` (integer dtype), and the deltas are zeros of
        shape ``(1, 1)``.

    Raises:
        KeyError: if a required ``kwargs`` entry is missing.
    """
    # For qwen3-omni
    audio_token_id = kwargs["audio_token_id"]
    audio_start_token_id = kwargs["audio_start_token_id"]
    position_id_per_seconds = kwargs["position_id_per_seconds"]
    use_audio_in_video = kwargs.get("use_audio_in_video", False)
    audio_seqlens = kwargs.get("audio_seqlens", None)
    second_per_grids = second_per_grid_ts

    mrope_position_deltas = []
    if input_ids is not None and (
        image_grid_thw is not None or video_grid_thw is not None
    ):
        total_input_ids = input_ids
        # Float because video T positions are scaled by fractional factors.
        position_ids = torch.zeros(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=torch.float,
            device=input_ids.device,
        )
        # Running indices into the grid / audio-length lists; shared by all rows.
        image_idx, video_idx, audio_idx = 0, 0, 0
        for i, current_input_ids in enumerate(total_input_ids):
            image_nums, video_nums, audio_nums = 0, 0, 0
            vision_start_indices = torch.argwhere(
                current_input_ids == vision_start_token_id
            ).squeeze(1)
            if vision_start_indices.numel() > 0:
                # The token right after each vision-start tells the span type.
                vision_tokens = current_input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (
                    (vision_tokens == audio_start_token_id).sum()
                    if use_audio_in_video
                    else (vision_tokens == video_token_id).sum()
                )
            audio_nums = torch.sum(current_input_ids == audio_start_token_id)
            input_tokens = current_input_ids.tolist()
            llm_pos_ids_list: list = []
            # st: token cursor into input_tokens; st_idx: next position id value.
            st = 0
            remain_images, remain_videos, remain_audios = (
                image_nums,
                video_nums,
                audio_nums,
            )
            # With audio-in-video, each video is paired with an audio span that
            # audio_nums already counts, so videos are not counted separately.
            multimodal_nums = (
                image_nums + audio_nums
                if use_audio_in_video
                else image_nums + video_nums + audio_nums
            )
            for _ in range(multimodal_nums):
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if len(llm_pos_ids_list) > 0
                    else 0
                )
                # Locate the next vision / audio start; len(input_tokens) + 1
                # is a "not found" sentinel larger than any real index.
                ed_vision_start = (
                    input_tokens.index(vision_start_token_id, st)
                    if (
                        (
                            image_token_id in input_tokens
                            or video_token_id in input_tokens
                        )
                        and (remain_videos > 0 or remain_images > 0)
                    )
                    else len(input_tokens) + 1
                )
                ed_audio_start = (
                    input_tokens.index(audio_start_token_id, st)
                    if (audio_token_id in input_tokens and remain_audios > 0)
                    else len(input_tokens) + 1
                )
                min_ed = min(ed_vision_start, ed_audio_start)

                # Text preceding the next multimodal span.
                text_len = min_ed - st
                if text_len != 0:
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                    )
                    st_idx += text_len
                # Audio in Video: vision-start directly followed by audio-start
                # gets two leading/trailing special positions instead of one.
                if (
                    min_ed == ed_vision_start
                    and ed_vision_start + 1 == ed_audio_start
                ):
                    bos_len, eos_len = 2, 2
                else:
                    bos_len, eos_len = 1, 1
                llm_pos_ids_list.append(
                    torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx
                )
                st_idx += bos_len
                # Audio Only
                if min_ed == ed_audio_start:
                    audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
                        audio_seqlens[audio_idx]
                    )
                    llm_pos_ids = (
                        torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                    )
                    llm_pos_ids_list.append(llm_pos_ids)

                    st += int(text_len + bos_len + audio_len + eos_len)
                    audio_idx += 1
                    remain_audios -= 1

                # Image Only
                elif (
                    min_ed == ed_vision_start
                    and current_input_ids[ed_vision_start + 1] == image_token_id
                ):
                    grid_t = image_grid_thw[image_idx][0]
                    grid_hs = image_grid_thw[:, 1]
                    grid_ws = image_grid_thw[:, 2]
                    # Image T index: 0..grid_t-1 scaled by position_id_per_seconds.
                    t_index = (
                        torch.arange(grid_t) * 1 * position_id_per_seconds
                    ).float()
                    llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
                        st_idx,
                        image_idx,
                        spatial_merge_size,
                        t_index,
                        grid_hs,
                        grid_ws,
                        input_ids.device,
                    )
                    # Number of placeholder tokens after spatial merging.
                    image_len = image_grid_thw[image_idx].prod() // (
                        spatial_merge_size**2
                    )
                    llm_pos_ids_list.append(llm_pos_ids)

                    st += int(text_len + bos_len + image_len + eos_len)
                    image_idx += 1
                    remain_images -= 1

                # Video Only
                elif (
                    min_ed == ed_vision_start
                    and current_input_ids[ed_vision_start + 1] == video_token_id
                ):
                    grid_t = video_grid_thw[video_idx][0]
                    grid_hs = video_grid_thw[:, 1]
                    grid_ws = video_grid_thw[:, 2]
                    t_index = (
                        torch.arange(grid_t)
                        * second_per_grids[video_idx].cpu().float()
                        * position_id_per_seconds
                    ).float()
                    llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
                        st_idx,
                        video_idx,
                        spatial_merge_size,
                        t_index,
                        grid_hs,
                        grid_ws,
                        input_ids.device,
                    )
                    video_len = video_grid_thw[video_idx].prod() // (
                        spatial_merge_size**2
                    )
                    llm_pos_ids_list.append(llm_pos_ids)

                    st += int(text_len + bos_len + video_len + eos_len)
                    video_idx += 1
                    remain_videos -= 1

                # Audio in Video
                elif (
                    min_ed == ed_vision_start
                    and ed_vision_start + 1 == ed_audio_start
                ):
                    audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
                        audio_seqlens[audio_idx]
                    )
                    audio_llm_pos_ids = (
                        torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                    )
                    grid_t = video_grid_thw[video_idx][0]
                    grid_hs = video_grid_thw[:, 1]
                    grid_ws = video_grid_thw[:, 2]

                    t_index = (
                        torch.arange(grid_t)
                        * second_per_grids[video_idx].cpu().float()
                        * position_id_per_seconds
                    ).float()
                    video_llm_pos_ids = (
                        MRotaryEmbedding._get_llm_pos_ids_for_vision(
                            st_idx,
                            video_idx,
                            spatial_merge_size,
                            t_index,
                            grid_hs,
                            grid_ws,
                            input_ids.device,
                        )
                    )
                    # Two-pointer merge of video and audio columns, ordered by
                    # axis-0 (T) value; ties go to the video column.
                    video_data_index, audio_data_index = 0, 0
                    while (
                        video_data_index < video_llm_pos_ids.shape[-1]
                        and audio_data_index < audio_llm_pos_ids.shape[-1]
                    ):
                        if (
                            video_llm_pos_ids[0][video_data_index]
                            <= audio_llm_pos_ids[0][audio_data_index]
                        ):
                            llm_pos_ids_list.append(
                                video_llm_pos_ids[
                                    :, video_data_index : video_data_index + 1
                                ]
                            )
                            video_data_index += 1
                        else:
                            llm_pos_ids_list.append(
                                audio_llm_pos_ids[
                                    :, audio_data_index : audio_data_index + 1
                                ]
                            )
                            audio_data_index += 1
                    # Append whichever side still has columns left.
                    if video_data_index < video_llm_pos_ids.shape[-1]:
                        llm_pos_ids_list.append(
                            video_llm_pos_ids[
                                :, video_data_index : video_llm_pos_ids.shape[-1]
                            ]
                        )
                    if audio_data_index < audio_llm_pos_ids.shape[-1]:
                        llm_pos_ids_list.append(
                            audio_llm_pos_ids[
                                :, audio_data_index : audio_llm_pos_ids.shape[-1]
                            ]
                        )
                    video_len = video_grid_thw[video_idx].prod() // (
                        spatial_merge_size**2
                    )

                    st += int(text_len + bos_len + audio_len + video_len + eos_len)

                    audio_idx += 1
                    video_idx += 1
                    remain_videos -= 1
                    remain_audios -= 1
                # Trailing special positions: they continue after the largest id
                # emitted so far.
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if len(llm_pos_ids_list) > 0
                    else 0
                )
                llm_pos_ids_list.append(
                    torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx
                )

            # Any text after the last multimodal span.
            if st < len(input_tokens):
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if len(llm_pos_ids_list) > 0
                    else 0
                )
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                )

            llm_positions = torch.cat(
                [item.float() for item in llm_pos_ids_list], dim=1
            ).reshape(3, -1)

            position_ids[..., i, :] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(
                llm_positions.max() + 1 - len(current_input_ids)
            )
        mrope_position_deltas = torch.tensor(
            mrope_position_deltas, device=input_ids.device
        ).unsqueeze(1)

        return position_ids, mrope_position_deltas
    else:
        # No grids: plain 1-D positions replicated on all three axes.
        s = input_ids.shape[1]
        position_ids = torch.arange(s)
        position_ids = (
            position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
        )
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(
            -1, keepdim=True
        )[0]
        mrope_position_deltas = max_position_ids + 1 - s

        return position_ids, mrope_position_deltas
1901
+
1272
1902
  # Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
1273
1903
  @staticmethod
1274
1904
  def get_rope_index_glm4v(
@@ -1467,6 +2097,44 @@ class MRotaryEmbedding(RotaryEmbedding):
1467
2097
 
1468
2098
  return position_ids, mrope_position_deltas
1469
2099
 
2100
+ # For qwen3-omni
2101
@staticmethod
def _get_feat_extract_output_lengths(input_lengths):
    """Map audio input lengths to audio-encoder output lengths.

    Computes the output length of the convolutional layers and the output
    length of the audio encoder. Every full block of 100 input frames
    contributes 13 outputs. The leftover ``input_lengths % 100`` frames are
    ceil-halved three times in a row, using the identity
    ``(x - 1) // 2 + 1 == ceil(x / 2)``.

    Works element-wise on ints or integer tensors.
    """
    full_chunks = input_lengths // 100
    remainder = input_lengths % 100
    after_first = (remainder - 1) // 2 + 1
    after_second = (after_first - 1) // 2 + 1
    after_third = (after_second - 1) // 2 + 1
    return after_third + full_chunks * 13
2112
+
2113
+ # For qwen3-omni
2114
@staticmethod
def _get_llm_pos_ids_for_vision(
    st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
):
    """Build (T, H, W) position ids for one image/video item.

    The item's H and W grid sizes (``grid_hs[vision_idx]``,
    ``grid_ws[vision_idx]``) are divided by ``spatial_merge_size``. The
    frames are then enumerated in (frame, row, col) order: each frame's T
    value comes from ``t_index``, H is the row index and W is the column
    index.

    Returns:
        Tensor of shape ``(3, len(t_index) * merged_h * merged_w)``, offset
        by ``st_idx``.
    """
    merged_h = grid_hs[vision_idx] // spatial_merge_size
    merged_w = grid_ws[vision_idx] // spatial_merge_size
    num_frames = len(t_index)
    cells_per_frame = merged_h * merged_w

    row_ids = torch.arange(merged_h, device=device).view(1, -1, 1)
    row_ids = row_ids.expand(num_frames, -1, merged_w).flatten()

    col_ids = torch.arange(merged_w, device=device).view(1, 1, -1)
    col_ids = col_ids.expand(num_frames, merged_h, -1).flatten()

    frame_ids = t_index.view(-1, 1).expand(-1, cells_per_frame).flatten()

    return torch.stack([frame_ids, row_ids, col_ids], dim=0) + st_idx
2137
+
1470
2138
 
1471
2139
  class DualChunkRotaryEmbedding(CustomOp):
1472
2140
  """Rotary positional embedding for Dual Chunk Attention."""
@@ -1768,6 +2436,7 @@ def get_rope(
1768
2436
  is_neox_style,
1769
2437
  dtype,
1770
2438
  mrope_section=rope_scaling["mrope_section"],
2439
+ mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
1771
2440
  )
1772
2441
  else:
1773
2442
  rotary_emb = RotaryEmbedding(