sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py

@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
  )
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
  from sglang.srt.layers.parameter import (
      ChannelQuantScaleParameter,
      ModelWeightParameter,
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
  )

  if TYPE_CHECKING:
-     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-     from sglang.srt.layers.moe.topk import TopKOutput
+     from sglang.srt.layers.moe.token_dispatcher import (
+         CombineInput,
+         StandardDispatchOutput,
+     )

  _is_cuda = is_cuda()
  _is_cpu_amx_available = cpu_has_amx_support()
@@ -339,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                  _is_cpu_amx_available
              ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
              _amx_process_weight_after_loading(layer, ["weight"])
-             return
-
-         layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+         else:
+             layer.weight = Parameter(layer.weight.t(), requires_grad=False)
          layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

      def create_weights(
@@ -390,13 +393,23 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                  x.dtype,
                  True,  # is_vnni
              )
-
          x_q, x_scale = per_token_quant_int8(x)

-         return int8_scaled_mm(
-             x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+         x_q_2d = x_q.view(-1, x_q.shape[-1])
+         x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
+         output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
+
+         output = int8_scaled_mm(
+             x_q_2d,
+             layer.weight,
+             x_scale_2d,
+             layer.weight_scale,
+             out_dtype=x.dtype,
+             bias=bias,
          )

+         return output.view(output_shape)
+

  class W8A8Int8MoEMethod(FusedMoEMethodBase):
      """MoE method for INT8.
@@ -417,7 +430,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
          layer: torch.nn.Module,
          num_experts: int,
          hidden_size: int,
-         intermediate_size: int,
+         intermediate_size_per_partition: int,
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ):
@@ -428,7 +441,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
          # WEIGHTS
          w13_weight = torch.nn.Parameter(
              torch.empty(
-                 num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                 num_experts,
+                 2 * intermediate_size_per_partition,
+                 hidden_size,
+                 dtype=torch.int8,
              ),
              requires_grad=False,
          )
@@ -436,14 +452,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
          set_weight_attrs(w13_weight, extra_weight_attrs)

          w2_weight = torch.nn.Parameter(
-             torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+             torch.empty(
+                 num_experts,
+                 hidden_size,
+                 intermediate_size_per_partition,
+                 dtype=torch.int8,
+             ),
              requires_grad=False,
          )
          layer.register_parameter("w2_weight", w2_weight)
          set_weight_attrs(w2_weight, extra_weight_attrs)

          w13_weight_scale = torch.nn.Parameter(
-             torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+             torch.ones(
+                 num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+             ),
              requires_grad=False,
          )
          w2_weight_scale = torch.nn.Parameter(
@@ -472,10 +495,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                  _is_cpu_amx_available
              ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
              _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-             return
-
-         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
-         layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+         else:
+             layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+             layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
          layer.w13_weight_scale = Parameter(
              layer.w13_weight_scale.data, requires_grad=False
          )
@@ -483,23 +505,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
              layer.w2_weight_scale.data, requires_grad=False
          )

+     def create_moe_runner(
+         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+     ):
+         self.moe_runner_config = moe_runner_config
+         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
      def apply(
          self,
          layer: torch.nn.Module,
-         x: torch.Tensor,
-         topk_output: TopKOutput,
-         moe_runner_config: MoeRunnerConfig,
+         dispatch_output: StandardDispatchOutput,
      ) -> torch.Tensor:
-         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+         x = dispatch_output.hidden_states
+         topk_output = dispatch_output.topk_output

          if use_intel_amx_backend(layer):
              from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

              topk_weights, topk_ids, _ = topk_output
              x, topk_weights = apply_topk_weights_cpu(
-                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                 self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
              )
-             return torch.ops.sgl_kernel.fused_experts_cpu(
+             output = torch.ops.sgl_kernel.fused_experts_cpu(
                  x,
                  layer.w13_weight,
                  layer.w2_weight,
@@ -515,20 +544,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                  layer.w2_input_scale,  # a2_scale
                  True,  # is_vnni
              )
+             return StandardCombineInput(hidden_states=output)

-         return fused_experts(
-             x,
-             layer.w13_weight,
-             layer.w2_weight,
-             topk_output=topk_output,
-             moe_runner_config=moe_runner_config,
+         quant_info = TritonMoeQuantInfo(
+             w13_weight=layer.w13_weight,
+             w2_weight=layer.w2_weight,
              use_int8_w8a8=True,
              per_channel_quant=True,
-             w1_scale=(layer.w13_weight_scale),
-             w2_scale=(layer.w2_weight_scale),
-             a1_scale=layer.w13_input_scale,
+             w13_scale=layer.w13_weight_scale,
+             w2_scale=layer.w2_weight_scale,
+             a13_scale=layer.w13_input_scale,
              a2_scale=layer.w2_input_scale,
          )
+         return self.runner.run(dispatch_output, quant_info)


  class NPU_W8A8LinearMethodImpl:
@@ -620,6 +648,7 @@ class NPU_W8A8LinearMethodImpl:
          layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
          layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
          layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+         layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


  class NPU_W8A8LinearMethodMTImpl:
@@ -812,6 +841,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
          layer.weight_scale.data = layer.weight_scale.data.flatten()
          layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
          layer.weight_offset.data = layer.weight_offset.data.flatten()
+         layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


  class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
@@ -900,7 +930,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
          layer: torch.nn.Module,
          num_experts: int,
          hidden_size: int,
-         intermediate_size: int,
+         intermediate_size_per_partition: int,
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ) -> None:
@@ -914,21 +944,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
          # weight
          w13_weight = torch.nn.Parameter(
              torch.empty(
-                 num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                 num_experts,
+                 2 * intermediate_size_per_partition,
+                 hidden_size,
+                 dtype=torch.int8,
              ),
              requires_grad=False,
          )
          layer.register_parameter("w13_weight", w13_weight)
          set_weight_attrs(w13_weight, extra_weight_attrs)
          w2_weight = torch.nn.Parameter(
-             torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+             torch.empty(
+                 num_experts,
+                 hidden_size,
+                 intermediate_size_per_partition,
+                 dtype=torch.int8,
+             ),
              requires_grad=False,
          )
          layer.register_parameter("w2_weight", w2_weight)
          set_weight_attrs(w2_weight, extra_weight_attrs)
          # scale
          w13_weight_scale = torch.nn.Parameter(
-             torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+             torch.empty(
+                 num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+             ),
              requires_grad=False,
          )
          layer.register_parameter("w13_weight_scale", w13_weight_scale)
@@ -941,7 +981,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
          set_weight_attrs(w2_weight_scale, extra_weight_attrs)
          # offset
          w13_weight_offset = torch.nn.Parameter(
-             torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+             torch.empty(
+                 num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+             ),
              requires_grad=False,
          )
          layer.register_parameter("w13_weight_offset", w13_weight_offset)
@@ -973,18 +1015,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
              layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
          )

+     def create_moe_runner(
+         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+     ):
+         self.moe_runner_config = moe_runner_config
+
      def apply(
          self,
          layer,
-         x,
-         topk_output: TopKOutput,
-         moe_runner_config: MoeRunnerConfig,
-     ) -> torch.Tensor:
+         dispatch_output: StandardDispatchOutput,
+     ) -> CombineInput:
+         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+         x = dispatch_output.hidden_states
+         topk_output = dispatch_output.topk_output

          topk_weights, topk_ids, _ = topk_output
          topk_ids = topk_ids.to(torch.int32)
          topk_weights = topk_weights.to(x.dtype)
-         return npu_fused_experts(
+         output = npu_fused_experts(
              hidden_states=x,
              w13=layer.w13_weight,
              w13_scale=layer.w13_weight_scale,
@@ -994,3 +1043,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
              topk_ids=topk_ids,
              top_k=topk_ids.shape[1],
          )
+         return StandardCombineInput(hidden_states=output)
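
The W8A8Int8LinearMethod.apply change above quantizes per token, flattens the activation to 2-D for int8_scaled_mm, and restores the original leading dimensions afterwards, so inputs shaped [batch, seq, hidden] work the same as [tokens, hidden]. A minimal sketch of that shape handling in plain PyTorch; the *_ref helpers below are hypothetical stand-ins for per_token_quant_int8 and int8_scaled_mm, not the sgl_kernel implementations:

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # Symmetric per-token int8 quantization (stand-in for per_token_quant_int8).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale

def int8_scaled_mm_ref(x_q, weight, x_scale, w_scale, out_dtype, bias=None):
    # Scaled int8 matmul on strictly 2-D inputs: (M, K) @ (K, N) -> (M, N).
    out = (x_q.float() @ weight.float()) * x_scale * w_scale.t()
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

def w8a8_int8_linear(x, weight, w_scale, bias=None):
    # Mirrors the new apply(): quantize, flatten leading dims, GEMM, reshape back.
    x_q, x_scale = per_token_quant_int8_ref(x)
    x_q_2d = x_q.view(-1, x_q.shape[-1])
    x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
    output_shape = [*x_q.shape[:-1], weight.shape[1]]
    out = int8_scaled_mm_ref(x_q_2d, weight, x_scale_2d, w_scale, x.dtype, bias)
    return out.view(output_shape)

x = torch.randn(2, 4, 64)                                      # [batch, seq, hidden]
weight = torch.randint(-128, 128, (64, 32), dtype=torch.int8)  # [K, N], as stored after .t()
w_scale = torch.rand(32, 1) * 0.01                             # per-output-channel scale
print(w8a8_int8_linear(x, weight, w_scale).shape)              # torch.Size([2, 4, 32])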
sglang/srt/layers/rocm_linear_utils.py (new file)

@@ -0,0 +1,44 @@
+ import torch
+ from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
+ from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+ from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+
+ from sglang.srt.utils import BumpAllocator
+
+ __all__ = ["fused_qk_rope_cat"]
+
+
+ def aiter_dsv3_router_gemm(
+     hidden_states: torch.Tensor,
+     weight: torch.Tensor,
+     gemm_output_zero_allocator: BumpAllocator = None,
+ ):
+     M = hidden_states.shape[0]
+     N = weight.shape[0]
+     y = None
+
+     if M <= 256:
+         # TODO (cagri): convert to bfloat16 as part of another kernel to save time
+         # for now it is also coupled with zero allocator.
+         if gemm_output_zero_allocator != None:
+             y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
+         else:
+             y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
+
+     if y is not None:
+         logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
+     else:
+         logits = gemm_a16w16(hidden_states, weight)
+
+     return logits
+
+
+ def get_dsv3_gemm_output_zero_allocator_size(
+     n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
+ ):
+     if embedding_dim != 7168 or n_routed_experts != 256:
+         return 0
+
+     per_layer_size = 256 * (allocate_size + n_routed_experts)
+
+     return num_moe_layers * per_layer_size
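
In this new ROCm helper, router GEMMs with M <= 256 tokens run through gemm_a16w16_atomic, which accumulates into a pre-zeroed float32 buffer, and get_dsv3_gemm_output_zero_allocator_size reserves that scratch space only for the DeepSeek-V3 shape (embedding_dim 7168, 256 routed experts). A small worked example of the sizing arithmetic; num_moe_layers and allocate_size below are illustrative values, not taken from the diff:

# Only the arithmetic of get_dsv3_gemm_output_zero_allocator_size is reproduced here.
n_routed_experts = 256
embedding_dim = 7168
num_moe_layers = 58    # assumed layer count, for illustration
allocate_size = 256    # assumed per-layer token budget (matches the M <= 256 fast path)

if embedding_dim != 7168 or n_routed_experts != 256:
    total_size = 0  # other model shapes get no zero-filled scratch buffer
else:
    per_layer_size = 256 * (allocate_size + n_routed_experts)   # 256 * 512 = 131072 elements
    total_size = num_moe_layers * per_layer_size                 # 58 * 131072 = 7602176 elements

print(total_size)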
sglang/srt/layers/rotary_embedding.py

@@ -12,6 +12,7 @@ from sglang.srt.custom_op import CustomOp
  from sglang.srt.utils import (
      cpu_has_amx_support,
      get_bool_env_var,
+     get_compiler_backend,
      is_cpu,
      is_cuda,
      is_hip,
@@ -26,13 +27,19 @@ _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()

  if _is_cuda:
-     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+     from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+ else:
+     FusedSetKVBufferArg = None
+
  if _use_aiter:
      from aiter.rotary_embedding import get_rope as aiter_get_rope

  if is_npu():
      import torch_npu

+     NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
+     NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
+

  def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
      x1 = x[..., : x.shape[-1] // 2]
@@ -142,8 +149,13 @@ class RotaryEmbedding(CustomOp):
          query: torch.Tensor,
          key: torch.Tensor,
          offsets: Optional[torch.Tensor] = None,
+         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """A PyTorch-native implementation of forward()."""
+         assert (
+             fused_set_kv_buffer_arg is None
+         ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
          if offsets is not None:
              positions = positions + offsets
          positions = positions.flatten()
@@ -172,12 +184,17 @@ class RotaryEmbedding(CustomOp):
          query: torch.Tensor,
          key: torch.Tensor,
          offsets: Optional[torch.Tensor] = None,
+         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """A PyTorch-npu implementation of forward()."""
-         import os
+         assert (
+             fused_set_kv_buffer_arg is None
+         ), "fused_set_kv_buffer_arg is not supported for npu implementation"

          if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-             return self.forward_native(positions, query, key, offsets)
+             return self.forward_native(
+                 positions, query, key, offsets, fused_set_kv_buffer_arg
+             )
          else:
              rotary_mode = "half"
              if self.is_neox_style:
@@ -202,7 +219,12 @@ class RotaryEmbedding(CustomOp):
          query: torch.Tensor,
          key: torch.Tensor,
          offsets: Optional[torch.Tensor] = None,
+         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
+         assert (
+             fused_set_kv_buffer_arg is None
+         ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
+
          positions = torch.add(positions, offsets) if offsets is not None else positions
          if _is_cpu_amx_available:
              return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -214,7 +236,9 @@ class RotaryEmbedding(CustomOp):
                  self.is_neox_style,
              )
          else:
-             return self.forward_native(positions, query, key, offsets)
+             return self.forward_native(
+                 positions, query, key, offsets, fused_set_kv_buffer_arg
+             )

      def forward_cuda(
          self,
@@ -222,7 +246,7 @@ class RotaryEmbedding(CustomOp):
          query: torch.Tensor,
          key: torch.Tensor,
          offsets: Optional[torch.Tensor] = None,
-         fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
+         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          if _is_cuda and (self.head_size in [64, 128, 256, 512]):
              apply_rope_with_cos_sin_cache_inplace(
@@ -782,27 +806,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
          key: torch.Tensor,
          offsets: Optional[torch.Tensor] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
-         # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
-         # and generalization to more scenarios will be supported in the future.
-         if query.shape[1] * query.shape[2] > 4096:
-             return self.forward_native(positions, query, key, offsets)
-         num_tokens = query.shape[0]
-         rotary_mode = "half" if self.is_neox_style else "interleave"
+         num_tokens, num_q_heads, _ = query.shape
+         num_k_heads = key.shape[1]
+
          self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+         cos_sin = self.cos_sin_cache[
+             torch.add(positions, offsets) if offsets is not None else positions
+         ]
+         cos, sin = cos_sin.chunk(2, dim=-1)
+         # Reshape to [batchsize, head_dim, seq, rotary_dim]
+         cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+         sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+
          query_rot = query[..., : self.rotary_dim]
          key_rot = key[..., : self.rotary_dim]
          if self.rotary_dim < self.head_size:
              query_pass = query[..., self.rotary_dim :]
              key_pass = key[..., self.rotary_dim :]

-         query_rot, key_rot = torch_npu.npu_mrope(
-             torch.add(positions, offsets) if offsets is not None else positions,
-             query_rot.reshape(num_tokens, -1),
-             key_rot.reshape(num_tokens, -1),
-             self.cos_sin_cache,
-             self.rotary_dim,
-             mrope_section=[0, 0, 0],
-             rotary_mode=rotary_mode,
+         query_rot = torch_npu.npu_interleave_rope(
+             query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
+             cos,
+             sin,
+         )
+         key_rot = torch_npu.npu_interleave_rope(
+             key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
+             cos,
+             sin,
          )
          query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
          key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
@@ -1029,12 +1059,13 @@ class MRotaryEmbedding(RotaryEmbedding):
                  f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
              )

-     @torch.compile(dynamic=True)
+     @torch.compile(dynamic=True, backend=get_compiler_backend())
      def forward(
          self,
          positions: torch.Tensor,
          query: torch.Tensor,
          key: torch.Tensor,
+         fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """PyTorch-native implementation equivalent to forward().
@@ -1045,6 +1076,9 @@ class MRotaryEmbedding(RotaryEmbedding):
              query: [num_tokens, num_heads * head_size]
              key: [num_tokens, num_kv_heads * head_size]
          """
+         assert (
+             fused_set_kv_buffer_arg is None
+         ), "save kv cache is not supported for MRotaryEmbedding."
          assert positions.ndim == 1 or positions.ndim == 2

          num_tokens = positions.shape[-1]
@@ -1177,7 +1211,7 @@ class MRotaryEmbedding(RotaryEmbedding):

                          time_tensor_long = time_tensor.long()
                          t_index = time_tensor_long.flatten()
-                     elif model_type == "qwen2_vl":
+                     elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
                          t_index = (
                              torch.arange(llm_grid_t)
                              .view(-1, 1)
@@ -1433,24 +1467,6 @@ class MRotaryEmbedding(RotaryEmbedding):

          return position_ids, mrope_position_deltas

-     @staticmethod
-     def get_next_input_positions(
-         mrope_position_delta: int,
-         context_len: int,
-         seq_len: int,
-     ) -> torch.Tensor:
-         return torch.tensor(
-             [
-                 list(
-                     range(
-                         context_len + mrope_position_delta,
-                         seq_len + mrope_position_delta,
-                     )
-                 )
-                 for _ in range(3)
-             ]
-         )
-

  class DualChunkRotaryEmbedding(CustomOp):
      """Rotary positional embedding for Dual Chunk Attention."""
@@ -1906,17 +1922,30 @@ def apply_rotary_pos_emb_npu(
      sin: torch.Tensor,
      unsqueeze_dim=1,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
-     if q.shape[1] != 128:
+     """Ascend implementation equivalent to apply_rotary_pos_emb_native.
+
+     Args:
+         q: [num_tokens, num_heads, head_size]
+         k: [num_tokens, num_kv_heads, head_size]
+         cos: [num_tokens, head_size]
+         sin: [num_tokens, head_size]
+     """
+     if (
+         cos.dim() != 2
+         or q.dim() != 3
+         or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
+         or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
+     ):
+         # Note: num_heads and head_size of q must be less than 1000 and 896, respectively
          return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
-     cos = cos.unsqueeze(unsqueeze_dim)
-     cos = torch.transpose(cos, 1, 2)
-     sin = sin.unsqueeze(unsqueeze_dim)
-     sin = torch.transpose(sin, 1, 2)
-     q = torch.transpose(q, 1, 2)
-     k = torch.transpose(k, 1, 2)
-     q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
-     q_embed = torch.transpose(q_embed, 1, 2)
-     k_embed = torch.transpose(k_embed, 1, 2)
+     cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
+     sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
+     q = q.unsqueeze(0)
+     k = k.unsqueeze(0)
+     q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+     k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+     q_embed = q_embed.squeeze(0)
+     k_embed = k_embed.squeeze(0)
      return q_embed, k_embed

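
For orientation, the rewritten apply_rotary_pos_emb_npu falls back to apply_rotary_pos_emb_native whenever the inputs exceed the torch_npu.npu_rotary_mul limits, and otherwise applies the same rotation through the NPU kernel on broadcast cos/sin. A rough plain-PyTorch reference for the shapes listed in the new docstring; rotate_half and apply_rotary_ref are hand-rolled stand-ins assuming neox-style half rotation, not the sglang functions:

from typing import Tuple

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style rotation: swap the two halves of the last dim and negate the second.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_ref(
    q: torch.Tensor,    # [num_tokens, num_heads, head_size]
    k: torch.Tensor,    # [num_tokens, num_kv_heads, head_size]
    cos: torch.Tensor,  # [num_tokens, head_size]
    sin: torch.Tensor,  # [num_tokens, head_size]
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Broadcast cos/sin over the head dimension, then rotate q and k in place of
    # the fused torch_npu.npu_rotary_mul call used on Ascend.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed

q = torch.randn(3, 8, 128)
k = torch.randn(3, 2, 128)
cos, sin = torch.randn(3, 128), torch.randn(3, 128)
q_out, k_out = apply_rotary_ref(q, k, cos, sin)
print(q_out.shape, k_out.shape)  # torch.Size([3, 8, 128]) torch.Size([3, 2, 128])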