sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/w8a8_int8.py

@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
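
These two hunks are the visible edge of a broader 0.5.3 refactor: quantized MoE methods no longer take `(x, topk_output, moe_runner_config)` per call; they consume a dispatch output and return a combine input. A minimal sketch of the new contract, with `CombineInputSketch` as a stand-in for sglang's `StandardCombineInput` and the body purely illustrative:

    from dataclasses import dataclass

    import torch


    @dataclass
    class CombineInputSketch:  # stand-in for sglang's StandardCombineInput
        hidden_states: torch.Tensor


    class SomeQuantMoEMethod:  # sketch of the refactored quant-method contract
        def create_moe_runner(self, layer, moe_runner_config):
            # New hook: the config is bound once at setup instead of being
            # passed to apply() on every call.
            self.moe_runner_config = moe_runner_config

        def apply(self, layer, dispatch_output):
            x = dispatch_output.hidden_states          # replaces the old `x` argument
            topk_output = dispatch_output.topk_output  # replaces `topk_output`
            output = x  # placeholder: the real method runs the fused experts here
            return CombineInputSketch(hidden_states=output)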
@@ -339,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-            return
-
-        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

     def create_weights(
@@ -390,13 +393,23 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 x.dtype,
                 True,  # is_vnni
             )
-
         x_q, x_scale = per_token_quant_int8(x)

-        return int8_scaled_mm(
-            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+        x_q_2d = x_q.view(-1, x_q.shape[-1])
+        x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
+        output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
+
+        output = int8_scaled_mm(
+            x_q_2d,
+            layer.weight,
+            x_scale_2d,
+            layer.weight_scale,
+            out_dtype=x.dtype,
+            bias=bias,
         )

+        return output.view(output_shape)
+

 class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
@@ -417,7 +430,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -428,7 +441,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
@@ -436,14 +452,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)

         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -472,10 +495,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-            return
-
-        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
-        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
             layer.w13_weight_scale.data, requires_grad=False
         )
@@ -483,23 +505,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-            return torch.ops.sgl_kernel.fused_experts_cpu(
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -515,20 +544,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)

-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
             per_channel_quant=True,
-            w1_scale=(layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)


 class NPU_W8A8LinearMethodImpl:
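
Taken together, the two hunks above replace the direct `fused_experts(...)` call with the runner pipeline, and the parameter renames (`w1_scale` to `w13_scale`, `a1_scale` to `a13_scale`) follow the new quant-info dataclass. A condensed sketch of the wiring, using names exactly as they appear in this diff; it assumes `layer`, `moe_runner_config`, and `dispatch_output` are in scope and is not runnable outside sglang:

    # Sketch assembled from the hunks above.
    from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend
    from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo

    # Done once, in create_moe_runner():
    runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    # Done per forward, in apply():
    quant_info = TritonMoeQuantInfo(
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        use_int8_w8a8=True,
        per_channel_quant=True,
        w13_scale=layer.w13_weight_scale,  # was w1_scale
        w2_scale=layer.w2_weight_scale,
        a13_scale=layer.w13_input_scale,   # was a1_scale
        a2_scale=layer.w2_input_scale,
    )
    combine_input = runner.run(dispatch_output, quant_info)  # replaces fused_experts(...)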
@@ -620,6 +648,7 @@ class NPU_W8A8LinearMethodImpl:
         layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8LinearMethodMTImpl:
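
This hunk and the matching one in `NPU_W8A8DynamicLinearMethodImpl` below cast the transposed int8 weight with `torch_npu.npu_format_cast(..., 29)` after loading; to the best of our knowledge, 29 is torch_npu's ACL format id for the FRACTAL_NZ blocked layout that Ascend matmul kernels prefer. A named constant would carry that intent (hypothetical refactor, not in the diff):

    ACL_FORMAT_FRACTAL_NZ = 29  # assumption: torch_npu ACL format id for FRACTAL_NZ
    layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ)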
@@ -812,6 +841,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
@@ -900,7 +930,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -914,21 +944,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         # weight
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
         # scale
         w13_weight_scale = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
@@ -941,7 +981,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # offset
         w13_weight_offset = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_offset", w13_weight_offset)
@@ -973,18 +1015,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer,
-        x,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
-        return npu_fused_experts(
+        output = npu_fused_experts(
             hidden_states=x,
             w13=layer.w13_weight,
             w13_scale=layer.w13_weight_scale,
@@ -994,3 +1043,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             top_k=topk_ids.shape[1],
         )
+        return StandardCombineInput(hidden_states=output)

sglang/srt/layers/rotary_embedding.py

@@ -12,6 +12,7 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
+    get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,
@@ -26,13 +27,19 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

 if _is_cuda:
-    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+    from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+else:
+    FusedSetKVBufferArg = None
+
 if _use_aiter:
     from aiter.rotary_embedding import get_rope as aiter_get_rope

 if is_npu():
     import torch_npu

+    NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
+    NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
+

 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -142,8 +149,13 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
@@ -172,12 +184,17 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-npu implementation of forward()."""
-        import os
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for npu implementation"

         if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
         else:
             rotary_mode = "half"
             if self.is_neox_style:
@@ -202,7 +219,12 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
+
         positions = torch.add(positions, offsets) if offsets is not None else positions
         if _is_cpu_amx_available:
             return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -214,7 +236,9 @@ class RotaryEmbedding(CustomOp):
                 self.is_neox_style,
             )
         else:
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )

     def forward_cuda(
         self,
@@ -222,7 +246,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
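
The hunks above give `fused_set_kv_buffer_arg` a real type annotation across all `RotaryEmbedding` forwards. Only `forward_cuda` can act on it (the sgl_kernel rope can fuse the KV-cache write), while the native, NPU, and CPU paths assert it is `None`; the torch-compile and non-AMX fallbacks now forward the argument so that assertion still fires. A caller-side sketch, where `build_fused_kv_arg` is a hypothetical helper and `rotary_emb` is a `RotaryEmbedding` instance:

    # Only the CUDA backend accepts a real FusedSetKVBufferArg; on every other
    # platform the argument must stay None or the backend asserts.
    fused_arg = build_fused_kv_arg(...) if _is_cuda else None  # hypothetical helper
    q, k = rotary_emb(positions, query, key, fused_set_kv_buffer_arg=fused_arg)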
@@ -782,27 +806,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
-        # and generalization to more scenarios will be supported in the future.
-        if query.shape[1] * query.shape[2] > 4096:
-            return self.forward_native(positions, query, key, offsets)
-        num_tokens = query.shape[0]
-        rotary_mode = "half" if self.is_neox_style else "interleave"
+        num_tokens, num_q_heads, _ = query.shape
+        num_k_heads = key.shape[1]
+
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # Reshape to [batchsize, head_dim, seq, rotary_dim]
+        cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim :]
             key_pass = key[..., self.rotary_dim :]

-        query_rot, key_rot = torch_npu.npu_mrope(
-            torch.add(positions, offsets) if offsets is not None else positions,
-            query_rot.reshape(num_tokens, -1),
-            key_rot.reshape(num_tokens, -1),
-            self.cos_sin_cache,
-            self.rotary_dim,
-            mrope_section=[0, 0, 0],
-            rotary_mode=rotary_mode,
+        query_rot = torch_npu.npu_interleave_rope(
+            query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
+        key_rot = torch_npu.npu_interleave_rope(
+            key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
+            cos,
+            sin,
         )
         query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
         key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
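
For context on the cos/sin preparation above: each `cos_sin_cache` row stores the half-width cos and sin tables side by side, so `chunk(2, dim=-1)` splits them and `repeat(1, 2)` re-expands each to the full `rotary_dim` before broadcasting against the reshaped heads. A shape walk-through in plain torch (sizes arbitrary):

    import torch

    num_tokens, rotary_dim = 8, 64
    cos_sin = torch.randn(num_tokens, rotary_dim)       # row = [cos_half | sin_half]
    cos, sin = cos_sin.chunk(2, dim=-1)                 # each [8, 32]
    cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)  # [8, 1, 1, 64]
    sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
    # npu_interleave_rope then broadcasts these against
    # query_rot.reshape(num_tokens, num_q_heads, 1, rotary_dim).
    assert cos.shape == (num_tokens, 1, 1, rotary_dim)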
@@ -1029,12 +1059,13 @@ class MRotaryEmbedding(RotaryEmbedding):
             f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
         )

-    @torch.compile(dynamic=True)
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
     def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward().

@@ -1045,6 +1076,9 @@ class MRotaryEmbedding(RotaryEmbedding):
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "save kv cache is not supported for MRotaryEmbedding."
        assert positions.ndim == 1 or positions.ndim == 2

        num_tokens = positions.shape[-1]
@@ -1177,7 +1211,7 @@ class MRotaryEmbedding(RotaryEmbedding):

                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()
-                elif model_type == "qwen2_vl":
+                elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
                    t_index = (
                        torch.arange(llm_grid_t)
                        .view(-1, 1)
@@ -1888,17 +1922,30 @@ def apply_rotary_pos_emb_npu(
     sin: torch.Tensor,
     unsqueeze_dim=1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if q.shape[1] != 128:
+    """Ascend implementation equivalent to apply_rotary_pos_emb_native.
+
+    Args:
+        q: [num_tokens, num_heads, head_size]
+        k: [num_tokens, num_kv_heads, head_size]
+        cos: [num_tokens, head_size]
+        sin: [num_tokens, head_size]
+    """
+    if (
+        cos.dim() != 2
+        or q.dim() != 3
+        or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
+        or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
+    ):
+        # NOTE: num_heads and head_size of q must be less than 1000 and 896, respectively
         return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
-    cos = cos.unsqueeze(unsqueeze_dim)
-    cos = torch.transpose(cos, 1, 2)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    sin = torch.transpose(sin, 1, 2)
-    q = torch.transpose(q, 1, 2)
-    k = torch.transpose(k, 1, 2)
-    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
-    q_embed = torch.transpose(q_embed, 1, 2)
-    k_embed = torch.transpose(k_embed, 1, 2)
+    cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    q = q.unsqueeze(0)
+    k = k.unsqueeze(0)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    q_embed = q_embed.squeeze(0)
+    k_embed = k_embed.squeeze(0)
     return q_embed, k_embed

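
The rewritten guard falls back to the native rope whenever the inputs do not fit `npu_rotary_mul`'s constraints (2-D cos/sin, 3-D q, fewer than 1000 heads, head size under 896), where the old code took the fused path only when `q.shape[1] == 128`. Because `npu_rotary_mul` consumes 4-D tensors, the 3-D `[num_tokens, heads, head_size]` inputs gain a leading batch dim and drop it afterwards; a shape-only sketch (the NPU call left as a comment):

    import torch

    num_tokens, num_heads, head_size = 16, 32, 128
    q = torch.randn(num_tokens, num_heads, head_size)
    cos = torch.randn(num_tokens, head_size)

    q4 = q.unsqueeze(0)                   # [1, 16, 32, 128]
    cos4 = cos.unsqueeze(1).unsqueeze(0)  # [1, 16, 1, 128] with unsqueeze_dim=1
    # q_embed = torch_npu.npu_rotary_mul(q4, cos4, sin4).squeeze(0)  # back to 3-D
    assert q4.shape == (1, num_tokens, num_heads, head_size)
    assert cos4.shape == (1, num_tokens, 1, head_size)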