sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/points_v15_chat.py ADDED
@@ -0,0 +1,186 @@
+ import copy
+ from typing import Iterable, List, Optional, Set, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from sglang.srt.configs.points_v15_chat import POINTSV15ChatConfig
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.managers.mm_utils import (
+     MultiModalityDataPaddingPatternMultimodalTokens,
+     general_mm_embed_routine,
+ )
+ from sglang.srt.managers.schedule_batch import (
+     Modality,
+     MultimodalDataItem,
+     MultimodalInputs,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+ from sglang.srt.models.qwen2_vl import Qwen2VisionPatchMerger, Qwen2VisionTransformer
+ from sglang.srt.utils import add_prefix
+
+
+ class Qwen2VisionTransformerForNavitPOINTS(Qwen2VisionTransformer):
+     def __init__(
+         self,
+         vision_config: POINTSV15ChatConfig,
+         norm_eps: float = 1e-6,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(
+             vision_config,
+             norm_eps=norm_eps,
+             quant_config=quant_config,
+             prefix=prefix,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         grid_thw: torch.Tensor,
+     ) -> torch.Tensor:
+         # patchify
+         x = x.to(device=self.device, dtype=self.dtype)
+         x = self.patch_embed(x)
+
+         # compute position embedding
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+         position_embeddings = (emb.cos(), emb.sin())
+
+         # compute cu_seqlens
+         cu_seqlens = torch.repeat_interleave(
+             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+         ).cumsum(dim=0, dtype=torch.int32)
+         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+         # transformers
+         x = x.unsqueeze(1)
+         for blk in self.blocks:
+             x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+
+         return x
+
+
+ class POINTSV15ChatModel(nn.Module):
+     def __init__(
+         self,
+         config: POINTSV15ChatConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         **kwargs,
+     ) -> None:
+         super().__init__()
+         config.llm_config._attn_implementation = "flash_attention_2"
+         config._attn_implementation_autoset = False
+         self.config = config
+         self.quant_config = quant_config
+
+         llm_config = copy.deepcopy(config.llm_config)
+         llm_config.architectures = ["Qwen2ForCausalLM"]
+         self.llm = Qwen2ForCausalLM(
+             config=llm_config,
+             quant_config=quant_config,
+             prefix=add_prefix("llm", prefix),
+         )
+
+         self.vision_encoder = Qwen2VisionTransformerForNavitPOINTS(
+             config.vision_config,
+             quant_config=quant_config,
+             prefix=add_prefix("vision_encoder", prefix),
+         )
+
+         self.vision_projector = Qwen2VisionPatchMerger(
+             d_model=config.llm_config.hidden_size,
+             context_dim=1280,
+             quant_config=quant_config,
+             prefix=add_prefix("vision_projector", prefix),
+         )
+
+     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+         return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+         pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+             self.vision_encoder.dtype
+         )
+         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
+
+         assert pixel_values.dim() == 2, pixel_values.dim()
+         assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+
+         image_features = self.vision_encoder(pixel_values, grid_thw=image_grid_thw)
+         image_features = self.vision_projector(image_features)
+         return image_features
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         get_embedding: bool = False,
+     ):
+         hidden_states = general_mm_embed_routine(
+             input_ids=input_ids,
+             forward_batch=forward_batch,
+             language_model=self.llm,
+             data_embedding_funcs={
+                 Modality.IMAGE: self.get_image_feature,
+             },
+             positions=positions,
+         )
+
+         return hidden_states
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 if "vision_encoder" in name:
+                     # adapt to VisionAttention
+                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                 try:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     param = params_dict[name]
+                 except KeyError:
+                     print(params_dict.keys())
+                     raise
+
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = [POINTSV15ChatModel]
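
Note: the `cu_seqlens` computation in the vision forward above converts per-image `grid_thw` rows (temporal, height, width patch counts) into cumulative sequence boundaries for the packed attention blocks. A minimal sketch of the same arithmetic in isolation (the sample grid values are illustrative):

    import torch
    import torch.nn.functional as F

    # Two images: 1x4x6 and 1x2x2 patch grids (t, h, w).
    grid_thw = torch.tensor([[1, 4, 6], [1, 2, 2]])

    # h*w patches per temporal slice, then cumulative boundaries.
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
    print(cu_seqlens)  # tensor([ 0, 24, 28], dtype=torch.int32)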
sglang/srt/models/qwen.py CHANGED
@@ -15,7 +15,6 @@
  # Adapted from
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1

- import time
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
sglang/srt/models/qwen2.py CHANGED
@@ -454,9 +454,6 @@ class Qwen2ForCausalLM(nn.Module):
          # For EAGLE3 support
          self.capture_aux_hidden_states = False

-         # For EAGLE3 support
-         self.capture_aux_hidden_states = False
-
      def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
          return self.model.get_input_embedding(input_ids)

@@ -484,10 +481,6 @@ class Qwen2ForCausalLM(nn.Module):
          if self.capture_aux_hidden_states:
              hidden_states, aux_hidden_states = hidden_states

-         aux_hidden_states = None
-         if self.capture_aux_hidden_states:
-             hidden_states, aux_hidden_states = hidden_states
-
          if self.pp_group.is_last_rank:
              if not get_embedding:
                  return self.logits_processor(
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -40,7 +40,6 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
      Qwen2_5_VisionRotaryEmbedding,
  )

- from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.attention.vision import VisionAttention
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -60,7 +59,9 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInp
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2 import Qwen2Model
+ from sglang.srt.models.utils import permute_inv
  from sglang.srt.utils import add_prefix
+ from sglang.srt.utils.hf_transformers_utils import get_processor

  logger = logging.getLogger(__name__)

@@ -265,7 +266,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
          self.fullatt_block_indexes = vision_config.fullatt_block_indexes
          self.window_size = vision_config.window_size
          self.patch_size = vision_config.patch_size
-         mlp_hidden_size: int = vision_config.intermediate_size
+         mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
          self.patch_embed = Qwen2_5_VisionPatchEmbed(
              patch_size=patch_size,
              temporal_patch_size=temporal_patch_size,
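
Note: the `mlp_hidden_size` change pads the vision MLP's intermediate size up to the next multiple of 8; aligning the dimension for even tensor-parallel sharding is our reading of the motivation, not something the diff states. The rounding idiom:

    def round_up_to_multiple_of_8(n: int) -> int:
        # ((n + 7) // 8) * 8 is the smallest multiple of 8 that is >= n.
        return ((n + 7) // 8) * 8

    assert round_up_to_multiple_of_8(3420) == 3424
    assert round_up_to_multiple_of_8(3424) == 3424  # already aligned: unchanged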
@@ -405,6 +406,7 @@

          # Move window_index to the same device as x before using it to index x
          window_index = window_index.to(device=x.device)
+         reverse_indices = permute_inv(window_index)

          # Ensure rotary_pos_emb is on the same device/dtype as x
          rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
@@ -436,7 +438,7 @@
                  .to(device=x.device, dtype=torch.int32),
              ]
          )
-         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

          # transformers
          x = x.unsqueeze(1)
@@ -451,8 +453,6 @@

          # adapter
          x = self.merger(x)
-
-         reverse_indices = torch.argsort(window_index)
          x = x[reverse_indices, :]

          return x
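
Note: `permute_inv` is imported from the new `sglang/srt/models/utils.py` (entry 378 in the file list), whose body is not shown in this diff. It presumably computes the inverse permutation that `torch.argsort(window_index)` used to produce, but with a single O(n) scatter instead of an O(n log n) sort; a sketch under that assumption:

    import torch

    def permute_inv(perm: torch.Tensor) -> torch.Tensor:
        # Inverse of a permutation: inv[perm[i]] = i for every position i.
        inv = torch.empty_like(perm)
        inv[perm] = torch.arange(perm.numel(), device=perm.device)
        return inv

    perm = torch.tensor([2, 0, 3, 1])
    assert torch.equal(permute_inv(perm), torch.argsort(perm))  # same result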
sglang/srt/models/qwen2_audio.py CHANGED
@@ -23,31 +23,18 @@
  # limitations under the License.
  """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
  import logging
- import math
- from functools import lru_cache, partial
- from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict
+ from typing import Any, Iterable, List, Optional, Tuple

  import torch
  import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange
- from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config
- from transformers.activations import ACT2FN
+ from transformers import Qwen2AudioEncoderConfig, Qwen2Config
  from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig
  from transformers.models.qwen2_audio.modeling_qwen2_audio import (
      Qwen2AudioEncoder,
      Qwen2AudioMultiModalProjector,
  )

- from sglang.srt.hf_transformers_utils import get_processor
- from sglang.srt.layers.activation import QuickGELU
- from sglang.srt.layers.attention.vision import VisionAttention
- from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
- from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.layers.utils import get_layer_id
- from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
  from sglang.srt.managers.mm_utils import (
      MultiModalityDataPaddingPatternMultimodalTokens,
      general_mm_embed_routine,
sglang/srt/models/qwen2_moe.py CHANGED
@@ -17,6 +17,7 @@
  """Inference-only Qwen2MoE model compatible with HuggingFace weights."""

  import logging
+ from contextlib import nullcontext
  from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

  import torch
@@ -25,12 +26,14 @@ from torch import nn
  from transformers import PretrainedConfig

  from sglang.srt.distributed import (
+     get_moe_expert_parallel_world_size,
      get_pp_group,
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
  from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+ from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.communicator import (
      LayerCommunicator,
@@ -50,6 +53,7 @@ from sglang.srt.layers.linear import (
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe import get_moe_a2a_backend
  from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.moe.topk import TopK
@@ -61,10 +65,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.server_args import get_global_server_args
  from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
  from sglang.srt.utils import add_prefix, is_cuda, make_layers

@@ -82,6 +86,8 @@ class Qwen2MoeMLP(nn.Module):
          quant_config: Optional[QuantizationConfig] = None,
          reduce_results: bool = True,
          prefix: str = "",
+         tp_rank: Optional[int] = None,
+         tp_size: Optional[int] = None,
      ) -> None:
          super().__init__()
          self.gate_up_proj = MergedColumnParallelLinear(
@@ -90,6 +96,8 @@
              bias=False,
              quant_config=quant_config,
              prefix=add_prefix("gate_up_proj", prefix),
+             tp_rank=tp_rank,
+             tp_size=tp_size,
          )
          self.down_proj = RowParallelLinear(
              intermediate_size,
@@ -98,6 +106,8 @@
              quant_config=quant_config,
              reduce_results=reduce_results,
              prefix=add_prefix("down_proj", prefix),
+             tp_rank=tp_rank,
+             tp_size=tp_size,
          )
          if hidden_act != "silu":
              raise ValueError(
@@ -146,7 +156,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
          self.experts = get_moe_impl_class(quant_config)(
              layer_id=self.layer_id,
              top_k=config.num_experts_per_tok,
-             num_experts=config.num_experts,
+             num_experts=config.num_experts
+             + get_global_server_args().ep_num_redundant_experts,
              hidden_size=config.hidden_size,
              intermediate_size=config.moe_intermediate_size,
              quant_config=quant_config,
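
Note: `ep_num_redundant_experts` enlarges the expert pool with spare physical slots so hot experts can be replicated for expert-parallel load balancing (EPLB); the server argument defaults to 0, in which case nothing changes. Worked sizing with illustrative numbers:

    num_logical_experts = 64      # config.num_experts
    num_redundant_experts = 8     # ep_num_redundant_experts server arg
    ep_size = 8                   # expert-parallel world size

    num_physical_experts = num_logical_experts + num_redundant_experts  # 72
    assert num_physical_experts % ep_size == 0
    experts_per_rank = num_physical_experts // ep_size                  # 9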
@@ -168,11 +179,31 @@
                  quant_config=quant_config,
                  reduce_results=False,
                  prefix=add_prefix("shared_expert", prefix),
+                 **(
+                     dict(tp_rank=0, tp_size=1)
+                     if get_moe_a2a_backend().is_deepep()
+                     else {}
+                 ),
              )
          else:
              self.shared_expert = None
          self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

+         if get_moe_a2a_backend().is_deepep():
+             # TODO: we will support tp < ep in the future
+             self.ep_size = get_moe_expert_parallel_world_size()
+             self.num_experts = (
+                 config.num_experts + get_global_server_args().ep_num_redundant_experts
+             )
+             self.top_k = config.num_experts_per_tok
+
+     def get_moe_weights(self):
+         return [
+             x.data
+             for name, x in self.experts.named_parameters()
+             if name not in ["correction_bias"]
+         ]
+
      def _forward_shared_experts(self, hidden_states: torch.Tensor):
          shared_output = None
          if self.shared_expert is not None:
@@ -183,6 +214,32 @@
          )
          return shared_output

+     def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+         shared_output = None
+         if hidden_states.shape[0] > 0:
+             # router_logits: (num_tokens, n_experts)
+             router_logits, _ = self.gate(hidden_states)
+             shared_output = self._forward_shared_experts(hidden_states)
+             topk_output = self.topk(
+                 hidden_states,
+                 router_logits,
+                 num_token_non_padded=forward_batch.num_token_non_padded,
+                 expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                     layer_id=self.layer_id,
+                 ),
+             )
+         else:
+             topk_output = self.topk.empty_topk_output(hidden_states.device)
+         final_hidden_states = self.experts(
+             hidden_states=hidden_states,
+             topk_output=topk_output,
+         )
+
+         if shared_output is not None:
+             final_hidden_states.add_(shared_output)
+
+         return final_hidden_states
+
      def _forward_router_experts(self, hidden_states: torch.Tensor):
          # router_logits: (num_tokens, n_experts)
          router_logits, _ = self.gate(hidden_states)
@@ -213,6 +270,9 @@
          num_tokens, hidden_dim = hidden_states.shape
          hidden_states = hidden_states.view(-1, hidden_dim)

+         if get_moe_a2a_backend().is_deepep():
+             return self._forward_deepep(hidden_states, forward_batch)
+
          DUAL_STREAM_TOKEN_THRESHOLD = 1024
          if (
              self.alt_stream is not None
@@ -455,6 +515,7 @@
      ) -> None:
          super().__init__()
          self.config = config
+
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
          self.pp_group = get_pp_group()
@@ -530,7 +591,12 @@
                  if residual is not None
                  else hidden_states
              )
-             with get_global_expert_distribution_recorder().with_current_layer(i):
+             ctx = (
+                 nullcontext()
+                 if get_global_server_args().enable_piecewise_cuda_graph
+                 else get_global_expert_distribution_recorder().with_current_layer(i)
+             )
+             with ctx:
                  layer = self.layers[i]
                  hidden_states, residual = layer(
                      positions, hidden_states, forward_batch, residual
@@ -580,7 +646,7 @@
              config.hidden_size,
              quant_config=quant_config,
              prefix=add_prefix("lm_head", prefix),
-             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+             use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
          )
          self.logits_processor = LogitsProcessor(config)
          # For EAGLE3 support
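
Note: alongside the DeepEP changes, this file migrates from the `global_server_args_dict` mapping to the typed `get_global_server_args()` accessor from the reworked `sglang/srt/server_args.py`. The access pattern changes from dict lookup to attribute access, as the hunk above shows:

    # 0.5.3rc0: dict-style lookup on a module-level mapping
    #   use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]

    # 0.5.4: attribute access on the global server-args object,
    # e.g. inside model code:
    from sglang.srt.server_args import get_global_server_args

    use_attn_tp_group = get_global_server_args().enable_dp_lm_head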
sglang/srt/models/qwen2_vl.py CHANGED
@@ -28,12 +28,10 @@ from typing import Iterable, List, Optional, Tuple, Type, TypedDict

  import torch
  import torch.nn as nn
- import torch.nn.functional as F
  from einops import rearrange
  from transformers import Qwen2VLConfig
  from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig

- from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.activation import QuickGELU
  from sglang.srt.layers.attention.vision import VisionAttention
  from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -50,6 +48,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2 import Qwen2Model
  from sglang.srt.utils import add_prefix
+ from sglang.srt.utils.hf_transformers_utils import get_processor

  logger = logging.getLogger(__name__)

@@ -407,7 +406,7 @@ class Qwen2VisionTransformer(nn.Module):
          cu_seqlens = torch.repeat_interleave(
              grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
          ).cumsum(dim=0, dtype=torch.int32)
-         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

          # transformers
          x = x.unsqueeze(1)
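
Note: both vision towers replace `F.pad(cu_seqlens, (1, 0))` with `torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])`, which lets the `torch.nn.functional` import be dropped here. The two forms are equivalent, as a quick check shows:

    import torch
    import torch.nn.functional as F

    cu_seqlens = torch.tensor([24, 28], dtype=torch.int32)
    padded = F.pad(cu_seqlens, (1, 0), "constant", 0)
    concatenated = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
    assert torch.equal(padded, concatenated)  # both are [0, 24, 28]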
@@ -514,6 +513,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
      def get_input_embeddings(self):
          return self.model.embed_tokens

+     def should_apply_lora(self, module_name: str) -> bool:
+         # skip visual tower
+         return not module_name.startswith("visual")
+
      def forward(
          self,
          input_ids: torch.Tensor,
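
Note: the new `should_apply_lora` hook gives the LoRA machinery a per-module predicate, so adapters are injected only into the language model and never into the vision tower. A minimal check of the predicate's behavior (the `_Probe` class is just for illustration):

    class _Probe:
        def should_apply_lora(self, module_name: str) -> bool:
            # Mirrors the hook added above: skip the visual tower.
            return not module_name.startswith("visual")

    probe = _Probe()
    assert probe.should_apply_lora("model.layers.0.self_attn.qkv_proj")
    assert not probe.should_apply_lora("visual.blocks.0.attn.qkv")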
sglang/srt/models/qwen3.py CHANGED
@@ -1,6 +1,5 @@
  # Adapted from qwen2.py
  import logging
- from functools import partial
  from typing import Any, Dict, Iterable, List, Optional, Tuple

  import torch
@@ -30,12 +29,19 @@ from sglang.srt.model_loader.weight_utils import (
  )
  from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
  from sglang.srt.models.qwen2 import Qwen2Model
- from sglang.srt.utils import add_prefix, is_cuda
+ from sglang.srt.utils import (
+     add_prefix,
+     get_cmo_stream,
+     is_cuda,
+     is_npu,
+     wait_cmo_stream,
+ )

  Qwen3Config = None

  logger = logging.getLogger(__name__)
  _is_cuda = is_cuda()
+ _is_npu = is_npu()


  class Qwen3Attention(nn.Module):
@@ -235,9 +241,18 @@

          # Fully Connected
          hidden_states, residual = self.layer_communicator.prepare_mlp(
-             hidden_states, residual, forward_batch
+             hidden_states,
+             residual,
+             forward_batch,
+             cache=(
+                 [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                 if _is_npu
+                 else None
+             ),
          )
          hidden_states = self.mlp(hidden_states)
+         if _is_npu and get_cmo_stream():
+             wait_cmo_stream()
          hidden_states, residual = self.layer_communicator.postprocess_layer(
              hidden_states, residual, forward_batch
          )
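
Note: `get_cmo_stream`/`wait_cmo_stream` (added to `sglang/srt/utils` in this release) appear to prefetch the MLP weights passed via `cache=` on a side stream while attention runs, then block until the prefetch completes before the MLP executes. The actual helpers target Ascend NPU and their bodies are not shown in this diff; a generic sketch of the same overlap pattern in CUDA-stream terms, as an analogy only:

    import torch

    side_stream = torch.cuda.Stream()

    def prefetch_weights(weights):
        # Touch the weights on a side stream so the reads overlap with
        # whatever runs on the main stream (stand-in for a real hardware
        # cache-management op).
        with torch.cuda.stream(side_stream):
            for w in weights:
                _ = w.sum()

    def wait_prefetch():
        # The main stream must not consume the weights until the
        # side-stream work is done.
        torch.cuda.current_stream().wait_stream(side_stream)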