sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import importlib.util
4
3
  from typing import TYPE_CHECKING, List, Optional
5
4
 
6
5
  import torch
@@ -31,8 +30,6 @@ if TYPE_CHECKING:
31
30
  StandardDispatchOutput,
32
31
  )
33
32
 
34
- has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
35
-
36
33
 
37
34
  _is_cpu_amx_available = cpu_has_amx_support()
38
35
  _is_hip = is_hip()
@@ -118,13 +115,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
118
115
  x: torch.Tensor,
119
116
  bias: Optional[torch.Tensor] = None,
120
117
  ) -> torch.Tensor:
121
-
122
118
  if use_intel_amx_backend(layer):
123
119
  x_shapes = x.shape
124
120
  if len(x_shapes) == 3:
125
121
  x = x.view(-1, x.shape[-1])
126
122
  output = torch.ops.sgl_kernel.weight_packed_linear(
127
- x, layer.weight, bias, True # is_vnni
123
+ x,
124
+ layer.weight,
125
+ bias,
126
+ True, # is_vnni
128
127
  )
129
128
  if len(x_shapes) == 3:
130
129
  output = output.view(x_shapes[0], x_shapes[1], -1)
@@ -141,19 +140,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
141
140
  self.use_triton_kernels = use_triton_kernels
142
141
  self.with_bias = False
143
142
 
144
- self.triton_kernel_moe_forward = None
145
- self.triton_kernel_moe_with_bias_forward = None
146
- if torch.cuda.is_available() and has_triton_kernels:
147
- from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
148
- triton_kernel_moe_forward as _tk_forward,
149
- )
150
- from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
151
- triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
152
- )
153
-
154
- self.triton_kernel_moe_forward = _tk_forward
155
- self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
156
-
157
143
  def create_weights(
158
144
  self,
159
145
  layer: torch.nn.Module,
@@ -234,14 +220,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
234
220
  self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
235
221
  ):
236
222
  self.moe_runner_config = moe_runner_config
237
- self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
223
+ backend = (
224
+ MoeRunnerBackend.TRITON_KERNELS
225
+ if self.use_triton_kernels
226
+ else MoeRunnerBackend.TRITON
227
+ )
228
+ self.runner = MoeRunner(backend, moe_runner_config)
238
229
 
239
230
  def apply(
240
231
  self,
241
232
  layer: torch.nn.Module,
242
233
  dispatch_output: StandardDispatchOutput,
243
234
  ) -> CombineInput:
244
-
245
235
  return self.forward(
246
236
  layer=layer,
247
237
  dispatch_output=dispatch_output,
@@ -252,7 +242,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
252
242
  layer: torch.nn.Module,
253
243
  dispatch_output: StandardDispatchOutput,
254
244
  ) -> CombineInput:
255
-
256
245
  from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
257
246
 
258
247
  x = dispatch_output.hidden_states
@@ -260,30 +249,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
260
249
 
261
250
  moe_runner_config = self.moe_runner_config
262
251
 
263
- if self.use_triton_kernels:
264
- if self.with_bias:
265
- assert self.triton_kernel_moe_with_bias_forward is not None
266
- output = self.triton_kernel_moe_with_bias_forward(
267
- hidden_states=x,
268
- w1=layer.w13_weight,
269
- w2=layer.w2_weight,
270
- b1=layer.w13_weight_bias,
271
- b2=layer.w2_weight_bias,
272
- topk_output=topk_output,
273
- moe_runner_config=moe_runner_config,
274
- w1_pcg=None,
275
- w2_pcg=None,
276
- )
277
- else:
278
- assert self.triton_kernel_moe_forward is not None
279
- output = self.triton_kernel_moe_forward(
280
- hidden_states=x,
281
- w1=layer.w13_weight,
282
- w2=layer.w2_weight,
283
- topk_output=topk_output,
284
- moe_runner_config=moe_runner_config,
285
- )
286
- return StandardCombineInput(hidden_states=output)
252
+ backend = self.runner.runner_backend
253
+ if backend.is_triton_kernels():
254
+ from sglang.srt.layers.moe.moe_runner.triton_kernels import (
255
+ TritonKernelsQuantInfo,
256
+ )
257
+
258
+ quant_info = TritonKernelsQuantInfo(
259
+ w13_weight=layer.w13_weight,
260
+ w2_weight=layer.w2_weight,
261
+ w13_bias=getattr(layer, "w13_weight_bias", None),
262
+ w2_bias=getattr(layer, "w2_weight_bias", None),
263
+ )
264
+ return self.runner.run(dispatch_output, quant_info)
287
265
  else:
288
266
  if _use_aiter:
289
267
  assert not moe_runner_config.no_combine, "unsupported"
@@ -314,7 +292,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
314
292
  )
315
293
  return StandardCombineInput(hidden_states=output)
316
294
  else:
317
-
318
295
  quant_info = TritonMoeQuantInfo(
319
296
  w13_weight=layer.w13_weight,
320
297
  w2_weight=layer.w2_weight,
@@ -328,7 +305,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
328
305
  layer: torch.nn.Module,
329
306
  dispatch_output: StandardDispatchOutput,
330
307
  ) -> CombineInput:
331
-
332
308
  from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
333
309
 
334
310
  x = dispatch_output.hidden_states
@@ -383,7 +359,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
383
359
  layer: torch.nn.Module,
384
360
  dispatch_output: StandardDispatchOutput,
385
361
  ) -> CombineInput:
386
-
387
362
  import torch_npu
388
363
 
389
364
  from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
@@ -11,7 +11,6 @@ import numpy
11
11
  import torch
12
12
 
13
13
  from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
14
- from sglang.srt.utils import is_cuda
15
14
 
16
15
  if TYPE_CHECKING:
17
16
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -1,14 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
5
5
 
6
6
  import torch
7
7
  from torch.nn import Module
8
8
  from torch.nn.parameter import Parameter
9
9
 
10
- from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
11
- from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
10
+ from sglang.srt.layers.linear import UnquantizedLinearMethod
12
11
  from sglang.srt.layers.quantization.base_config import (
13
12
  FusedMoEMethodBase,
14
13
  QuantizationConfig,
@@ -17,13 +16,15 @@ from sglang.srt.layers.quantization.base_config import (
17
16
  from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
18
17
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
19
18
  from sglang.srt.layers.quantization.utils import is_layer_skipped
20
- from sglang.srt.utils import is_npu, set_weight_attrs
19
+ from sglang.srt.utils import set_weight_attrs
21
20
 
22
21
  if TYPE_CHECKING:
23
22
  from sglang.srt.layers.moe import MoeRunnerConfig
24
- from sglang.srt.layers.moe.ep_moe.layer import EPMoE
23
+ from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
25
24
  from sglang.srt.layers.moe.token_dispatcher import (
26
25
  CombineInput,
26
+ DeepEPLLDispatchOutput,
27
+ DeepEPNormalDispatchOutput,
27
28
  StandardDispatchOutput,
28
29
  )
29
30
 
@@ -94,9 +95,7 @@ class W4AFp8Config(QuantizationConfig):
94
95
  self, layer: torch.nn.Module, prefix: str
95
96
  ) -> Optional[QuantizeMethodBase]:
96
97
  from sglang.srt.layers.linear import LinearBase
97
- from sglang.srt.layers.moe.ep_moe.layer import EPMoE
98
98
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
99
- from sglang.srt.managers.schedule_batch import global_server_args_dict
100
99
 
101
100
  if isinstance(layer, LinearBase):
102
101
  if is_layer_skipped(prefix, self.ignored_layers):
@@ -133,7 +132,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
133
132
 
134
133
  def create_weights(
135
134
  self,
136
- layer: EPMoE,
135
+ layer: Module,
137
136
  num_experts: int,
138
137
  hidden_size: int,
139
138
  intermediate_size_per_partition: int,
@@ -292,7 +291,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
292
291
 
293
292
  def apply(
294
293
  self,
295
- layer: EPMoE,
294
+ layer: Module,
296
295
  dispatch_output: StandardDispatchOutput,
297
296
  ) -> CombineInput:
298
297
 
@@ -303,18 +302,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
303
302
  topk_output = dispatch_output.topk_output
304
303
 
305
304
  topk_weights, topk_ids, _ = topk_output
306
- local_topk_ids = topk_ids
307
- if get_moe_expert_parallel_world_size() > 1:
308
- local_topk_ids = torch.where(
309
- topk_ids == -1,
310
- layer.num_experts,
311
- topk_ids,
312
- )
313
305
 
314
306
  output = cutlass_w4a8_moe(
315
- layer.start_expert_id,
316
- layer.end_expert_id,
317
- layer.num_experts,
318
307
  x,
319
308
  layer.w13_weight,
320
309
  layer.w2_weight,
@@ -322,7 +311,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
322
311
  layer.w2_weight_scale_inv,
323
312
  topk_weights,
324
313
  topk_ids,
325
- local_topk_ids,
326
314
  self.a_strides1,
327
315
  self.b_strides1,
328
316
  self.c_strides1,
@@ -340,3 +328,82 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
340
328
  if self.moe_runner_config.routed_scaling_factor is not None:
341
329
  output *= self.moe_runner_config.routed_scaling_factor
342
330
  return StandardCombineInput(hidden_states=output)
331
+
332
+ def apply_deepep_ll(
333
+ self,
334
+ layer: DeepEPMoE,
335
+ dispatch_output: DeepEPLLDispatchOutput,
336
+ ) -> torch.Tensor:
337
+
338
+ from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe_deepep_ll
339
+
340
+ hidden_states, _, topk_ids, _, masked_m, _ = dispatch_output
341
+
342
+ output = cutlass_w4a8_moe_deepep_ll(
343
+ hidden_states,
344
+ layer.w13_weight,
345
+ layer.w2_weight,
346
+ layer.w13_weight_scale_inv,
347
+ layer.w2_weight_scale_inv,
348
+ topk_ids,
349
+ masked_m,
350
+ layer.quant_method.a_strides1,
351
+ layer.quant_method.b_strides1,
352
+ layer.quant_method.c_strides1,
353
+ layer.quant_method.a_strides2,
354
+ layer.quant_method.b_strides2,
355
+ layer.quant_method.c_strides2,
356
+ layer.quant_method.s_strides13,
357
+ layer.quant_method.s_strides2,
358
+ layer.quant_method.expert_offsets,
359
+ layer.quant_method.problem_sizes1,
360
+ layer.quant_method.problem_sizes2,
361
+ layer.w13_input_scale,
362
+ layer.w2_input_scale,
363
+ )
364
+
365
+ return output
366
+
367
+ def apply_deepep_normal(
368
+ self,
369
+ layer: DeepEPMoE,
370
+ dispatch_output: DeepEPNormalDispatchOutput,
371
+ ) -> torch.Tensor:
372
+ from sglang.srt.layers.moe.cutlass_w4a8_moe import (
373
+ cutlass_w4a8_moe_deepep_normal,
374
+ )
375
+
376
+ hidden_states, topk_idx, topk_weights = (
377
+ dispatch_output.hidden_states,
378
+ dispatch_output.topk_ids,
379
+ dispatch_output.topk_weights,
380
+ )
381
+ if isinstance(hidden_states, tuple):
382
+ hidden_states = hidden_states[0]
383
+
384
+ num_tokens = hidden_states.shape[0]
385
+ if num_tokens > 0:
386
+ return cutlass_w4a8_moe_deepep_normal(
387
+ hidden_states,
388
+ layer.w13_weight,
389
+ layer.w2_weight,
390
+ layer.w13_weight_scale_inv,
391
+ layer.w2_weight_scale_inv,
392
+ topk_weights,
393
+ topk_idx,
394
+ self.a_strides1,
395
+ self.b_strides1,
396
+ self.c_strides1,
397
+ self.a_strides2,
398
+ self.b_strides2,
399
+ self.c_strides2,
400
+ self.s_strides13,
401
+ self.s_strides2,
402
+ self.expert_offsets,
403
+ self.problem_sizes1,
404
+ self.problem_sizes2,
405
+ layer.w13_input_scale,
406
+ layer.w2_input_scale,
407
+ )
408
+ else:
409
+ return hidden_states
@@ -1,28 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
- import importlib
4
- import sys
5
3
  from types import MappingProxyType
6
- from typing import (
7
- TYPE_CHECKING,
8
- Any,
9
- Callable,
10
- Dict,
11
- List,
12
- Mapping,
13
- Optional,
14
- Tuple,
15
- Union,
16
- cast,
17
- )
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
18
5
 
19
6
  import torch
20
7
  from torch.nn.parameter import Parameter
21
8
 
22
- from sglang.srt.distributed import (
23
- get_tensor_model_parallel_rank,
24
- get_tensor_model_parallel_world_size,
25
- )
9
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
26
10
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
27
11
  from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
28
12
  from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
@@ -118,7 +102,12 @@ def npu_fused_experts(
118
102
  topk_weights: torch.Tensor,
119
103
  topk_ids: torch.Tensor,
120
104
  top_k: int,
105
+ **kwargs,
121
106
  ):
107
+ w13_offset = kwargs.get("w13_offset", None)
108
+ w2_offset = kwargs.get("w2_offset", None)
109
+ use_wna16 = kwargs.get("use_wna16", False)
110
+
122
111
  original_shape = hidden_states.shape
123
112
  original_dtype = hidden_states.dtype
124
113
  scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
@@ -143,12 +132,22 @@ def npu_fused_experts(
143
132
  )
144
133
  expert_tokens = expert_tokens.to(torch.int64)
145
134
  # gmm1: gate_up_proj
146
- hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
135
+ if not use_wna16:
136
+ hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
137
+ scale_args13 = {
138
+ "scale": [w13_scale.to(scale_dtype)],
139
+ "per_token_scale": [pertoken_scale],
140
+ }
141
+ else:
142
+ scale_args13 = {
143
+ "antiquant_scale": [w13_scale],
144
+ "antiquant_offset": [w13_offset],
145
+ }
146
+
147
147
  hidden_states = torch_npu.npu_grouped_matmul(
148
148
  x=[hidden_states],
149
149
  weight=[w13],
150
- scale=[w13_scale.to(scale_dtype)],
151
- per_token_scale=[pertoken_scale],
150
+ **scale_args13,
152
151
  split_item=2,
153
152
  group_list_type=0,
154
153
  group_type=0,
@@ -157,13 +156,20 @@ def npu_fused_experts(
157
156
  )[0]
158
157
  # act_fn: swiglu
159
158
  hidden_states = torch_npu.npu_swiglu(hidden_states)
160
- hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
159
+ if not use_wna16:
160
+ hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
161
+
162
+ scale_args2 = {
163
+ "scale": [w2_scale.to(scale_dtype)],
164
+ "per_token_scale": [pertoken_scale],
165
+ }
166
+ else:
167
+ scale_args2 = {"antiquant_scale": [w2_scale], "antiquant_offset": [w2_offset]}
161
168
  # gmm2: down_proj
162
169
  hidden_states = torch_npu.npu_grouped_matmul(
163
170
  x=[hidden_states],
164
171
  weight=[w2],
165
- scale=[w2_scale.to(scale_dtype)],
166
- per_token_scale=[pertoken_scale],
172
+ **scale_args2,
167
173
  split_item=2,
168
174
  group_list_type=0,
169
175
  group_type=0,
@@ -17,8 +17,12 @@ from __future__ import annotations
17
17
  from enum import Enum
18
18
  from typing import TYPE_CHECKING, Optional
19
19
 
20
+ import torch
20
21
  from torch import nn
21
22
 
23
+ from sglang.srt.compilation.piecewise_context_manager import get_forward_context
24
+ from sglang.srt.utils import direct_register_custom_op
25
+
22
26
  if TYPE_CHECKING:
23
27
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
24
28
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -105,12 +109,61 @@ class RadixAttention(nn.Module):
105
109
  else:
106
110
  k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
107
111
 
108
- return forward_batch.attn_backend.forward(
109
- q,
110
- k,
111
- v,
112
- self,
113
- forward_batch,
114
- save_kv_cache,
115
- **kwargs,
116
- )
112
+ if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
113
+ output = torch.empty_like(q)
114
+ torch.ops.sglang.unified_attention_with_output(
115
+ q, k, v, output, save_kv_cache, self.layer_id
116
+ )
117
+ return output
118
+ else:
119
+ return forward_batch.attn_backend.forward(
120
+ q,
121
+ k,
122
+ v,
123
+ self,
124
+ forward_batch,
125
+ save_kv_cache,
126
+ **kwargs,
127
+ )
128
+
129
+
130
+ def unified_attention_with_output(
131
+ query: torch.Tensor,
132
+ key: torch.Tensor,
133
+ value: torch.Tensor,
134
+ output: torch.Tensor,
135
+ save_kv_cache: bool,
136
+ layer_id: int,
137
+ ) -> None:
138
+ context = get_forward_context()
139
+ forward_batch = context.forward_batch
140
+ attention_layers = context.attention_layers
141
+ attention_layer = attention_layers[layer_id]
142
+ ret = forward_batch.attn_backend.forward(
143
+ query, key, value, attention_layer, forward_batch, save_kv_cache
144
+ )
145
+ assert (
146
+ output.numel() == ret.numel()
147
+ ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"
148
+
149
+ output.view(ret.shape).copy_(ret)
150
+ return
151
+
152
+
153
+ def unified_attention_with_output_fake(
154
+ query: torch.Tensor,
155
+ key: torch.Tensor,
156
+ value: torch.Tensor,
157
+ output: torch.Tensor,
158
+ save_kv_cache: bool,
159
+ layer_id: int,
160
+ ) -> None:
161
+ return
162
+
163
+
164
+ direct_register_custom_op(
165
+ op_name="unified_attention_with_output",
166
+ op_func=unified_attention_with_output,
167
+ mutates_args=["output"],
168
+ fake_impl=unified_attention_with_output_fake,
169
+ )