sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
7
7
 
8
8
  import torch
9
9
  import torch.nn as nn
10
+ import triton
11
+ import triton.language as tl
10
12
 
11
13
  from sglang.srt.custom_op import CustomOp
12
14
  from sglang.srt.utils import (
@@ -17,6 +19,7 @@ from sglang.srt.utils import (
17
19
  is_cuda,
18
20
  is_hip,
19
21
  is_npu,
22
+ is_xpu,
20
23
  )
21
24
 
22
25
  _is_cuda = is_cuda()
@@ -25,6 +28,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
25
28
  _is_npu = is_npu()
26
29
  _is_cpu_amx_available = cpu_has_amx_support()
27
30
  _is_cpu = is_cpu()
31
+ _is_xpu = is_xpu()
28
32
 
29
33
  if _is_cuda:
30
34
  from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
@@ -108,9 +112,11 @@ class RotaryEmbedding(CustomOp):
108
112
  if not _is_cuda:
109
113
  cache = cache.to(dtype)
110
114
 
111
- if (
112
- not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
113
- ) and not (_is_cpu and _is_cpu_amx_available):
115
+ if dtype == torch.float32 or (
116
+ (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512])
117
+ and not (_is_cpu and _is_cpu_amx_available)
118
+ and not (_is_xpu)
119
+ ):
114
120
  from vllm._custom_ops import rotary_embedding
115
121
 
116
122
  self.vllm_rotary_embedding = rotary_embedding
@@ -248,7 +254,11 @@ class RotaryEmbedding(CustomOp):
248
254
  offsets: Optional[torch.Tensor] = None,
249
255
  fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
250
256
  ) -> Tuple[torch.Tensor, torch.Tensor]:
251
- if _is_cuda and (self.head_size in [64, 128, 256, 512]):
257
+ if (
258
+ _is_cuda
259
+ and (self.head_size in [64, 128, 256, 512])
260
+ and self.dtype != torch.float32
261
+ ):
252
262
  apply_rope_with_cos_sin_cache_inplace(
253
263
  positions=positions,
254
264
  query=query,
@@ -284,6 +294,17 @@ class RotaryEmbedding(CustomOp):
284
294
  s += f", base={self.base}, is_neox_style={self.is_neox_style}"
285
295
  return s
286
296
 
297
+ def forward_xpu(
298
+ self,
299
+ positions: torch.Tensor,
300
+ query: torch.Tensor,
301
+ key: torch.Tensor,
302
+ offsets: Optional[torch.Tensor] = None,
303
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
304
+ # TODO: make a wrapper, and XPU will implement this kernel later.
305
+ self.cos_sin_cache = self.cos_sin_cache.to(query.device)
306
+ return self.forward_native(positions, query, key, offsets)
307
+
287
308
 
288
309
  class LinearScalingRotaryEmbedding(RotaryEmbedding):
289
310
  """RotaryEmbedding extended with linear scaling.
@@ -1008,6 +1029,199 @@ class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
1008
1029
  return cache
1009
1030
 
1010
1031
 
1032
+ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
1033
+ """Apply interleaved MRoPE to 3D rotary embeddings.
1034
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
1035
+ interleaved [THTHWHTHW...TT], preserving frequency continuity.
1036
+ """
1037
+ x_t = x[0].clone()
1038
+ x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3]
1039
+ x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3]
1040
+ return x_t
1041
+
1042
+
1043
+ @triton.jit
1044
+ def _triton_mrope_forward(
1045
+ q_ptr,
1046
+ k_ptr,
1047
+ cos,
1048
+ sin,
1049
+ num_tokens,
1050
+ n_qh: tl.constexpr,
1051
+ n_kh: tl.constexpr,
1052
+ hd: tl.constexpr,
1053
+ rd: tl.constexpr,
1054
+ pad_n_qh: tl.constexpr,
1055
+ pad_n_kh: tl.constexpr,
1056
+ pad_hd: tl.constexpr,
1057
+ mrope_section_t: tl.constexpr,
1058
+ mrope_section_h: tl.constexpr,
1059
+ mrope_section_w: tl.constexpr,
1060
+ is_interleaved: tl.constexpr,
1061
+ ):
1062
+ # Adapted from
1063
+ # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
1064
+ # This version supports flatten input tensors from vllm
1065
+ # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
1066
+ # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary
1067
+ pid = tl.program_id(0)
1068
+ # locate start address
1069
+ q_ptr = q_ptr + pid * (n_qh * hd)
1070
+ k_ptr = k_ptr + pid * (n_kh * hd)
1071
+
1072
+ # ####################################################################
1073
+ # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
1074
+ # m of this program instance
1075
+ # ####################################################################
1076
+ # Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
1077
+
1078
+ # Updated stride calculation for half head_dim
1079
+ half_rd = rd // 2
1080
+ t_cos = cos + pid * half_rd
1081
+ h_cos = t_cos + num_tokens * half_rd
1082
+ w_cos = h_cos + num_tokens * half_rd
1083
+ t_sin = sin + pid * half_rd
1084
+ h_sin = t_sin + num_tokens * half_rd
1085
+ w_sin = h_sin + num_tokens * half_rd
1086
+
1087
+ # Updated offsets for half head_dim
1088
+ cos_offsets = tl.arange(0, pad_hd // 2)
1089
+ if is_interleaved:
1090
+ h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
1091
+ w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
1092
+ t_mask = ~(h_mask | w_mask)
1093
+ else:
1094
+ t_end = mrope_section_t
1095
+ h_end = t_end + mrope_section_h
1096
+ t_mask = cos_offsets < mrope_section_t
1097
+ h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
1098
+ w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
1099
+
1100
+ t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
1101
+ h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
1102
+ w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
1103
+ t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
1104
+ h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
1105
+ w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
1106
+
1107
+ cos_row = t_cos_row + h_cos_row + w_cos_row
1108
+ sin_row = t_sin_row + h_sin_row + w_sin_row
1109
+
1110
+ # ####################################################################
1111
+ # Load the left and right half of q and k for the current
1112
+ # program instance (i.e. for the current token) separately
1113
+ # ####################################################################
1114
+ # left half of the head
1115
+ first_half_q_offsets = (
1116
+ tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
1117
+ )
1118
+ first_half_k_offsets = (
1119
+ tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
1120
+ )
1121
+ first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
1122
+ tl.arange(0, pad_hd // 2)[None, :] < rd // 2
1123
+ )
1124
+ first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
1125
+ tl.arange(0, pad_hd // 2)[None, :] < rd // 2
1126
+ )
1127
+
1128
+ q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
1129
+ sin_row.dtype
1130
+ )
1131
+ k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
1132
+ sin_row.dtype
1133
+ )
1134
+
1135
+ # right half of the head
1136
+ second_half_q_offsets = first_half_q_offsets + (rd // 2)
1137
+ second_half_k_offsets = first_half_k_offsets + (rd // 2)
1138
+ second_q_mask = first_q_mask
1139
+ second_k_mask = first_k_mask
1140
+
1141
+ q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
1142
+ sin_row.dtype
1143
+ )
1144
+ k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
1145
+ sin_row.dtype
1146
+ )
1147
+
1148
+ # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
1149
+ # Since cos and sin are now half-size,
1150
+ # we use the same cos_row and sin_row for both halves
1151
+ new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
1152
+ tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
1153
+ new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
1154
+ tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
1155
+
1156
+ new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
1157
+ tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
1158
+ new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
1159
+ tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
1160
+
1161
+
1162
+ def triton_mrope(
1163
+ q: torch.Tensor,
1164
+ k: torch.Tensor,
1165
+ cos: torch.Tensor,
1166
+ sin: torch.Tensor,
1167
+ mrope_section: list[int],
1168
+ head_size: int,
1169
+ rotary_dim: int,
1170
+ mrope_interleaved: bool,
1171
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1172
+ """The mrope triton kernel.
1173
+
1174
+ Args:
1175
+ q: [num_tokens, num_heads * head_size]
1176
+ k: [num_tokens, num_kv_heads * head_size]
1177
+ cos: [3, num_tokens, head_size //2 ]
1178
+ (T/H/W positions with multimodal inputs)
1179
+ sin: [3, num_tokens, head_size //2 ]
1180
+ (T/H/W positions with multimodal inputs)
1181
+ mrope_section: [t, h, w]
1182
+ head_size: int
1183
+ """
1184
+ n_row, n_q_head_head_dim = q.shape
1185
+ assert (
1186
+ n_q_head_head_dim % head_size == 0
1187
+ ), f"q shape {n_q_head_head_dim} must be divisible by head_size {head_size}"
1188
+ n_q_head = n_q_head_head_dim // head_size
1189
+ assert (
1190
+ k.shape[1] % head_size == 0
1191
+ ), f"k shape {k.shape[1]} must be divisible by head_size {head_size}"
1192
+ n_kv_head = k.shape[1] // head_size
1193
+ pad_hd = triton.next_power_of_2(head_size)
1194
+ pad_n_q_head = triton.next_power_of_2(n_q_head)
1195
+ pad_n_kv_head = triton.next_power_of_2(n_kv_head)
1196
+
1197
+ # ensure tensors passed into the kernel are contiguous.
1198
+ # It will be no-op if they are already contiguous
1199
+ q = q.contiguous()
1200
+ k = k.contiguous()
1201
+ cos = cos.contiguous()
1202
+ sin = sin.contiguous()
1203
+
1204
+ _triton_mrope_forward[(n_row,)](
1205
+ q,
1206
+ k,
1207
+ cos,
1208
+ sin,
1209
+ n_row,
1210
+ n_q_head,
1211
+ n_kv_head,
1212
+ head_size,
1213
+ rotary_dim,
1214
+ pad_n_q_head,
1215
+ pad_n_kv_head,
1216
+ pad_hd,
1217
+ mrope_section[0],
1218
+ mrope_section[1],
1219
+ mrope_section[2],
1220
+ mrope_interleaved,
1221
+ )
1222
+ return q, k
1223
+
1224
+
1011
1225
  class MRotaryEmbedding(RotaryEmbedding):
1012
1226
  """Rotary Embedding with Multimodal Sections."""
1013
1227
 
@@ -1020,12 +1234,14 @@ class MRotaryEmbedding(RotaryEmbedding):
1020
1234
  is_neox_style: bool,
1021
1235
  dtype: torch.dtype,
1022
1236
  mrope_section: Optional[List[int]] = None,
1237
+ mrope_interleaved: bool = False,
1023
1238
  ) -> None:
1024
1239
  super().__init__(
1025
1240
  head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
1026
1241
  )
1027
1242
 
1028
1243
  self.mrope_section = mrope_section
1244
+ self.mrope_interleaved = mrope_interleaved
1029
1245
  if self.mrope_section:
1030
1246
  expected_sum = rotary_dim // 2
1031
1247
  actual_sum = sum(self.mrope_section)
@@ -1059,8 +1275,17 @@ class MRotaryEmbedding(RotaryEmbedding):
1059
1275
  f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
1060
1276
  )
1061
1277
 
1278
+ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
1279
+ # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
1280
+ # is expensive, so avoid calling it if possible
1281
+ if (
1282
+ self.cos_sin_cache.device != query.device
1283
+ or self.cos_sin_cache.dtype != query.dtype
1284
+ ):
1285
+ self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
1286
+
1062
1287
  @torch.compile(dynamic=True, backend=get_compiler_backend())
1063
- def forward(
1288
+ def _forward_native(
1064
1289
  self,
1065
1290
  positions: torch.Tensor,
1066
1291
  query: torch.Tensor,
@@ -1086,15 +1311,18 @@ class MRotaryEmbedding(RotaryEmbedding):
1086
1311
  cos, sin = cos_sin.chunk(2, dim=-1)
1087
1312
  if positions.ndim == 2:
1088
1313
  assert self.mrope_section
1089
-
1090
- cos = torch.cat(
1091
- [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
1092
- dim=-1,
1093
- )
1094
- sin = torch.cat(
1095
- [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
1096
- dim=-1,
1097
- )
1314
+ if self.mrope_interleaved:
1315
+ cos = apply_interleaved_rope(cos, self.mrope_section)
1316
+ sin = apply_interleaved_rope(sin, self.mrope_section)
1317
+ else:
1318
+ cos = torch.cat(
1319
+ [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
1320
+ dim=-1,
1321
+ )
1322
+ sin = torch.cat(
1323
+ [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
1324
+ dim=-1,
1325
+ )
1098
1326
 
1099
1327
  query_shape = query.shape
1100
1328
  query = query.view(num_tokens, -1, self.head_size)
@@ -1111,6 +1339,72 @@ class MRotaryEmbedding(RotaryEmbedding):
1111
1339
  key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
1112
1340
  return query, key
1113
1341
 
1342
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward pass with optional Triton kernel acceleration.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
        fused_set_kv_buffer_arg: must be ``None``. Neither implementation
            dispatched to below accepts it, so it cannot be honoured.

    Returns:
        The ``(query, key)`` pair produced by ``_forward_triton`` or
        ``_forward_native``.

    Raises:
        NotImplementedError: if ``fused_set_kv_buffer_arg`` is provided.
    """
    # Fail loudly instead of silently dropping the request: a caller that
    # passes this argument expects the KV-cache write to be fused here.
    if fused_set_kv_buffer_arg is not None:
        raise NotImplementedError(
            "fused_set_kv_buffer_arg is not supported by MRotaryEmbedding"
        )

    assert positions.ndim == 1 or positions.ndim == 2

    # The Triton path is taken only for multimodal (2-D) positions with a
    # configured mrope_section, and only when the module-level `_is_cuda`
    # flag is truthy.
    if positions.ndim == 2 and self.mrope_section and _is_cuda:
        return self._forward_triton(positions, query, key)
    return self._forward_native(positions, query, key)
1363
+
1364
def _forward_triton(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embedding, using ``triton_mrope`` for 2-D positions.

    Args:
        positions: 1-D or 2-D index tensor into ``self.cos_sin_cache``.
        query: query states; returned in their original shape.
        key: key states (must not be ``None``); returned in their original
            shape.

    Returns:
        The rotated ``(query, key)`` pair.
    """
    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None

    self._match_cos_sin_cache_dtype(query)
    token_count = positions.shape[-1]
    # Gather the cache rows for these positions; the last dim holds the cos
    # half followed by the sin half.
    cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
    q_shape, k_shape = query.shape, key.shape

    if positions.ndim == 2:
        assert self.mrope_section

        rotated_q, rotated_k = triton_mrope(
            query,
            key,
            cos,
            sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
            self.mrope_interleaved,
        )

        return rotated_q.reshape(q_shape), rotated_k.reshape(k_shape)

    def _rotate(states: torch.Tensor, original_shape) -> torch.Tensor:
        # Rotate the leading `rotary_dim` features of every head and pass the
        # remaining features through untouched.
        per_head = states.view(token_count, -1, self.head_size)
        rotary_part = per_head[..., : self.rotary_dim]
        passthrough = per_head[..., self.rotary_dim :]
        rotary_part = _apply_rotary_emb(rotary_part, cos, sin, self.is_neox_style)
        return torch.cat((rotary_part, passthrough), dim=-1).reshape(original_shape)

    query = _rotate(query, q_shape)
    key = _rotate(key, k_shape)
    return query, key
1407
+
1114
1408
  # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
1115
1409
  @staticmethod
1116
1410
  def get_rope_index(
@@ -1126,6 +1420,28 @@ class MRotaryEmbedding(RotaryEmbedding):
1126
1420
  second_per_grid_ts: Optional[torch.Tensor] = None,
1127
1421
  **kwargs,
1128
1422
  ) -> Tuple[torch.Tensor, torch.Tensor]:
1423
+ if model_type == "qwen3_omni_moe":
1424
+ # For qwen3-omni
1425
+ return MRotaryEmbedding.get_rope_index_qwen3_omni(
1426
+ spatial_merge_size,
1427
+ image_token_id,
1428
+ video_token_id,
1429
+ vision_start_token_id,
1430
+ tokens_per_second,
1431
+ input_ids,
1432
+ image_grid_thw,
1433
+ video_grid_thw,
1434
+ second_per_grid_ts,
1435
+ **kwargs,
1436
+ )
1437
+ if (
1438
+ model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
1439
+ ) and video_grid_thw is not None:
1440
+ video_grid_thw = torch.repeat_interleave(
1441
+ video_grid_thw, video_grid_thw[:, 0], dim=0
1442
+ )
1443
+ video_grid_thw[:, 0] = 1
1444
+
1129
1445
  mrope_position_deltas = []
1130
1446
  if input_ids is not None and (
1131
1447
  image_grid_thw is not None or video_grid_thw is not None
@@ -1211,7 +1527,11 @@ class MRotaryEmbedding(RotaryEmbedding):
1211
1527
 
1212
1528
  time_tensor_long = time_tensor.long()
1213
1529
  t_index = time_tensor_long.flatten()
1214
- elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
1530
+ elif model_type in (
1531
+ "qwen2_vl",
1532
+ "qwen3_vl",
1533
+ "qwen3_vl_moe",
1534
+ ):
1215
1535
  t_index = (
1216
1536
  torch.arange(llm_grid_t)
1217
1537
  .view(-1, 1)
@@ -1219,7 +1539,7 @@ class MRotaryEmbedding(RotaryEmbedding):
1219
1539
  .flatten()
1220
1540
  )
1221
1541
  else:
1222
- raise RuntimeError("Unimplemented")
1542
+ raise RuntimeError(f"Unimplemented model type: {model_type}")
1223
1543
  h_index = (
1224
1544
  torch.arange(llm_grid_h)
1225
1545
  .view(1, -1, 1)
@@ -1269,6 +1589,304 @@ class MRotaryEmbedding(RotaryEmbedding):
1269
1589
  mrope_position_deltas = max_position_ids + 1 - s
1270
1590
  return position_ids, mrope_position_deltas
1271
1591
 
1592
+ @staticmethod
1593
def get_rope_index_qwen3_omni(
    spatial_merge_size: int,
    image_token_id: int,
    video_token_id: int,
    vision_start_token_id: int,
    tokens_per_second: Optional[int] = None,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute 3-row (T/H/W) mrope position ids for qwen3-omni inputs.

    Each sequence is walked left to right. Plain text, begin/end marker
    tokens and audio frames get the same running index on all three rows;
    image/video patches get separate temporal/height/width indices from
    ``MRotaryEmbedding._get_llm_pos_ids_for_vision``. For "audio in video"
    (a vision-start token immediately followed by an audio-start token) the
    audio and video position columns are merged in order of their temporal
    index.

    Args:
        spatial_merge_size: divisor applied to the H and W grid sizes.
        image_token_id: id of the image placeholder token.
        video_token_id: id of the video placeholder token.
        vision_start_token_id: id of the token that opens a vision segment.
        tokens_per_second: accepted for signature compatibility; not read.
        input_ids: [batch, seq_len] token ids.
        image_grid_thw: per-image (t, h, w) grid sizes, or None.
        video_grid_thw: per-video (t, h, w) grid sizes, or None.
        second_per_grid_ts: per-video factor multiplied into the temporal
            index of video patches.
        **kwargs: must contain ``audio_token_id``, ``audio_start_token_id``
            and ``position_id_per_seconds``; may contain
            ``use_audio_in_video`` (default False) and ``audio_seqlens``
            (indexed per audio segment; needed whenever audio is present).

    Returns:
        ``(position_ids, mrope_position_deltas)``. With image or video grids
        given, ``position_ids`` is a float tensor of shape
        [3, batch, seq_len] and the deltas have shape [batch, 1]. Otherwise
        ``position_ids`` is ``arange(seq_len)`` expanded to [3, 1, seq_len].

    Raises:
        KeyError: if a required entry is missing from ``kwargs``.
    """
    # For qwen3-omni
    audio_token_id = kwargs["audio_token_id"]
    audio_start_token_id = kwargs["audio_start_token_id"]
    position_id_per_seconds = kwargs["position_id_per_seconds"]
    use_audio_in_video = kwargs.get("use_audio_in_video", False)
    audio_seqlens = kwargs.get("audio_seqlens", None)
    second_per_grids = second_per_grid_ts

    def _span(length, offset):
        # `length` consecutive positions starting at `offset`, same on all rows.
        return torch.arange(length).view(1, -1).expand(3, -1) + offset

    def _next_start(chunks):
        # First free position after everything emitted so far.
        return chunks[-1].max() + 1 if len(chunks) > 0 else 0

    if input_ids is None or (image_grid_thw is None and video_grid_thw is None):
        # No vision grids: plain sequential positions on all three rows.
        s = input_ids.shape[1]
        position_ids = torch.arange(s)
        position_ids = (
            position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
        )
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(
            -1, keepdim=True
        )[0]
        mrope_position_deltas = max_position_ids + 1 - s

        return position_ids, mrope_position_deltas

    mrope_position_deltas = []
    position_ids = torch.zeros(
        3,
        input_ids.shape[0],
        input_ids.shape[1],
        dtype=torch.float,
        device=input_ids.device,
    )
    # Cursors into the per-modality side inputs; they run across the batch.
    image_idx, video_idx, audio_idx = 0, 0, 0
    for i, current_input_ids in enumerate(input_ids):
        # Count the segments of each modality once, as plain ints.
        image_nums, video_nums = 0, 0
        vision_start_indices = torch.argwhere(
            current_input_ids == vision_start_token_id
        ).squeeze(1)
        if vision_start_indices.numel() > 0:
            vision_tokens = current_input_ids[vision_start_indices + 1]
            image_nums = int((vision_tokens == image_token_id).sum())
            # With audio-in-video, a video segment is recognised by the
            # audio-start token that follows its vision-start token.
            video_marker = (
                audio_start_token_id if use_audio_in_video else video_token_id
            )
            video_nums = int((vision_tokens == video_marker).sum())
        audio_nums = int(torch.sum(current_input_ids == audio_start_token_id))

        input_tokens = current_input_ids.tolist()
        # These membership scans do not change inside the segment loop, so
        # they are evaluated once per sequence.
        has_vision_token = (
            image_token_id in input_tokens or video_token_id in input_tokens
        )
        has_audio_token = audio_token_id in input_tokens
        not_found = len(input_tokens) + 1

        llm_pos_ids_list: list = []
        st = 0  # index of the next unprocessed token
        remain_images, remain_videos, remain_audios = (
            image_nums,
            video_nums,
            audio_nums,
        )
        multimodal_nums = (
            image_nums + audio_nums
            if use_audio_in_video
            else image_nums + video_nums + audio_nums
        )
        for _ in range(multimodal_nums):
            st_idx = _next_start(llm_pos_ids_list)
            ed_vision_start = (
                input_tokens.index(vision_start_token_id, st)
                if has_vision_token and (remain_videos > 0 or remain_images > 0)
                else not_found
            )
            ed_audio_start = (
                input_tokens.index(audio_start_token_id, st)
                if has_audio_token and remain_audios > 0
                else not_found
            )
            min_ed = min(ed_vision_start, ed_audio_start)

            # Text between the previous segment and this one.
            text_len = min_ed - st
            if text_len != 0:
                llm_pos_ids_list.append(_span(text_len, st_idx))
                st_idx += text_len

            # Audio in Video: two begin and two end marker tokens.
            audio_in_video = (
                min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start
            )
            bos_len, eos_len = (2, 2) if audio_in_video else (1, 1)
            llm_pos_ids_list.append(_span(bos_len, st_idx))
            st_idx += bos_len

            # Audio Only
            if min_ed == ed_audio_start:
                audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
                    audio_seqlens[audio_idx]
                )
                llm_pos_ids_list.append(_span(audio_len, st_idx))

                st += int(text_len + bos_len + audio_len + eos_len)
                audio_idx += 1
                remain_audios -= 1

            # Image Only
            elif (
                min_ed == ed_vision_start
                and current_input_ids[ed_vision_start + 1] == image_token_id
            ):
                grid_t = image_grid_thw[image_idx][0]
                grid_hs = image_grid_thw[:, 1]
                grid_ws = image_grid_thw[:, 2]
                t_index = (
                    torch.arange(grid_t) * 1 * position_id_per_seconds
                ).float()
                llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
                    st_idx,
                    image_idx,
                    spatial_merge_size,
                    t_index,
                    grid_hs,
                    grid_ws,
                    input_ids.device,
                )
                image_len = image_grid_thw[image_idx].prod() // (
                    spatial_merge_size**2
                )
                llm_pos_ids_list.append(llm_pos_ids)

                st += int(text_len + bos_len + image_len + eos_len)
                image_idx += 1
                remain_images -= 1

            # Video Only
            elif (
                min_ed == ed_vision_start
                and current_input_ids[ed_vision_start + 1] == video_token_id
            ):
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_index = (
                    torch.arange(grid_t)
                    * second_per_grids[video_idx].cpu().float()
                    * position_id_per_seconds
                ).float()
                llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
                    st_idx,
                    video_idx,
                    spatial_merge_size,
                    t_index,
                    grid_hs,
                    grid_ws,
                    input_ids.device,
                )
                video_len = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2
                )
                llm_pos_ids_list.append(llm_pos_ids)

                st += int(text_len + bos_len + video_len + eos_len)
                video_idx += 1
                remain_videos -= 1

            # Audio in Video
            elif audio_in_video:
                audio_len = MRotaryEmbedding._get_feat_extract_output_lengths(
                    audio_seqlens[audio_idx]
                )
                audio_llm_pos_ids = _span(audio_len, st_idx)
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]

                t_index = (
                    torch.arange(grid_t)
                    * second_per_grids[video_idx].cpu().float()
                    * position_id_per_seconds
                ).float()
                video_llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision(
                    st_idx,
                    video_idx,
                    spatial_merge_size,
                    t_index,
                    grid_hs,
                    grid_ws,
                    input_ids.device,
                )

                # Two-way merge of the video and audio columns by temporal
                # index (row 0); ties go to video.
                video_total = video_llm_pos_ids.shape[-1]
                audio_total = audio_llm_pos_ids.shape[-1]
                video_data_index, audio_data_index = 0, 0
                while video_data_index < video_total and audio_data_index < audio_total:
                    if (
                        video_llm_pos_ids[0][video_data_index]
                        <= audio_llm_pos_ids[0][audio_data_index]
                    ):
                        llm_pos_ids_list.append(
                            video_llm_pos_ids[
                                :, video_data_index : video_data_index + 1
                            ]
                        )
                        video_data_index += 1
                    else:
                        llm_pos_ids_list.append(
                            audio_llm_pos_ids[
                                :, audio_data_index : audio_data_index + 1
                            ]
                        )
                        audio_data_index += 1
                # Append whatever is left of either stream.
                if video_data_index < video_total:
                    llm_pos_ids_list.append(
                        video_llm_pos_ids[:, video_data_index:video_total]
                    )
                if audio_data_index < audio_total:
                    llm_pos_ids_list.append(
                        audio_llm_pos_ids[:, audio_data_index:audio_total]
                    )
                video_len = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2
                )

                st += int(text_len + bos_len + audio_len + video_len + eos_len)

                audio_idx += 1
                video_idx += 1
                remain_videos -= 1
                remain_audios -= 1

            # End marker token(s) of the segment.
            st_idx = _next_start(llm_pos_ids_list)
            llm_pos_ids_list.append(_span(eos_len, st_idx))

        # Trailing text after the last multimodal segment.
        if st < len(input_tokens):
            st_idx = _next_start(llm_pos_ids_list)
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(_span(text_len, st_idx))

        llm_positions = torch.cat(
            [item.float() for item in llm_pos_ids_list], dim=1
        ).reshape(3, -1)

        position_ids[..., i, :] = llm_positions.to(position_ids.device)
        mrope_position_deltas.append(
            llm_positions.max() + 1 - len(current_input_ids)
        )
    mrope_position_deltas = torch.tensor(
        mrope_position_deltas, device=input_ids.device
    ).unsqueeze(1)

    return position_ids, mrope_position_deltas
1889
+
1272
1890
  # Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120
1273
1891
  @staticmethod
1274
1892
  def get_rope_index_glm4v(
@@ -1467,6 +2085,44 @@ class MRotaryEmbedding(RotaryEmbedding):
1467
2085
 
1468
2086
  return position_ids, mrope_position_deltas
1469
2087
 
2088
+ # For qwen3-omni
2089
+ @staticmethod
2090
+ def _get_feat_extract_output_lengths(input_lengths):
2091
+ """
2092
+ Computes the output length of the convolutional layers and the output length of the audio encoder
2093
+ """
2094
+ input_lengths_leave = input_lengths % 100
2095
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
2096
+ output_lengths = (
2097
+ ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
2098
+ )
2099
+ return output_lengths
2100
+
2101
+ # For qwen3-omni
2102
+ @staticmethod
2103
+ def _get_llm_pos_ids_for_vision(
2104
+ st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device
2105
+ ):
2106
+ grid_h = grid_hs[vision_idx] // spatial_merge_size
2107
+ grid_w = grid_ws[vision_idx] // spatial_merge_size
2108
+
2109
+ h_index = (
2110
+ torch.arange(grid_h, device=device)
2111
+ .view(1, -1, 1)
2112
+ .expand(len(t_index), -1, grid_w)
2113
+ .flatten()
2114
+ )
2115
+ w_index = (
2116
+ torch.arange(grid_w, device=device)
2117
+ .view(1, 1, -1)
2118
+ .expand(len(t_index), grid_h, -1)
2119
+ .flatten()
2120
+ )
2121
+ t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten()
2122
+
2123
+ llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx
2124
+ return llm_pos_ids
2125
+
1470
2126
 
1471
2127
  class DualChunkRotaryEmbedding(CustomOp):
1472
2128
  """Rotary positional embedding for Dual Chunk Attention."""
@@ -1768,6 +2424,7 @@ def get_rope(
1768
2424
  is_neox_style,
1769
2425
  dtype,
1770
2426
  mrope_section=rope_scaling["mrope_section"],
2427
+ mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
1771
2428
  )
1772
2429
  else:
1773
2430
  rotary_emb = RotaryEmbedding(