sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -0,0 +1,173 @@
+ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+ # SPDX-License-Identifier: Apache-2.0
+
+ from typing import Callable, Optional
+
+ import torch
+ from compressed_tensors.quantization import QuantizationStrategy
+ from torch.nn import Parameter
+
+ from sglang.srt.layers.parameter import (
+     ChannelQuantScaleParameter,
+     ModelWeightParameter,
+     PerTensorScaleParameter,
+ )
+ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+     CompressedTensorsScheme,
+ )
+ from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+ from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+ from sglang.srt.utils import is_cuda
+
+ _is_cuda = is_cuda()
+ if _is_cuda:
+     from sgl_kernel import int8_scaled_mm
+
+
+ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
+
+     def __init__(
+         self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
+     ):
+         self.strategy = strategy
+         self.is_static_input_scheme = is_static_input_scheme
+         self.input_symmetric = input_symmetric
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         # lovelace and up
+         return 89
+
+     def process_weights_after_loading(self, layer) -> None:
+         # If per tensor, when we have a fused module (e.g. QKV) with per
+         # tensor scales (thus N scales being passed to the kernel),
+         # requantize so we can always run per channel
+         if self.strategy == QuantizationStrategy.TENSOR:
+             max_w_scale, weight = requantize_with_max_scale(
+                 weight=layer.weight,
+                 weight_scale=layer.weight_scale,
+                 logical_widths=layer.logical_widths,
+             )
+
+             layer.weight = Parameter(weight.t(), requires_grad=False)
+             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+
+         # If channelwise, scales are already lined up, so just transpose.
+         elif self.strategy == QuantizationStrategy.CHANNEL:
+             weight = layer.weight
+             weight_scale = layer.weight_scale.data
+
+             layer.weight = Parameter(weight.t(), requires_grad=False)
+             # required by torch.compile to be torch.nn.Parameter
+             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+         else:
+             raise ValueError(f"Unknown quantization strategy {self.strategy}")
+
+         # INPUT SCALE
+         if self.is_static_input_scheme and hasattr(layer, "input_scale"):
+             if self.input_symmetric:
+                 layer.input_scale = Parameter(
+                     layer.input_scale.max(), requires_grad=False
+                 )
+             else:
+                 input_scale = layer.input_scale
+                 input_zero_point = layer.input_zero_point
+
+                 # reconstruct the ranges
+                 int8_traits = torch.iinfo(torch.int8)
+                 azps = input_zero_point.to(dtype=torch.int32)
+                 range_max = (input_scale * (int8_traits.max - azps)).max()
+                 range_min = (input_scale * (int8_traits.min - azps)).min()
+
+                 scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
+
+                 # AZP loaded as int8 but used as int32
+                 azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
+
+                 layer.input_scale = Parameter(scale, requires_grad=False)
+                 layer.input_zero_point = Parameter(azp, requires_grad=False)
+         else:
+             layer.input_scale = None
+             layer.input_zero_point = None
+
+         # azp_adj is the AZP adjustment term, used to account for weights.
+         # It does not depend on scales or azp, so it is the same for
+         # static and dynamic quantization.
+         # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
+         # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+         if not self.input_symmetric:
+             weight = layer.weight
+             azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+             if self.is_static_input_scheme:
+                 # cutlass_w8a8 requires azp to be folded into azp_adj
+                 # in the per-tensor case
+                 azp_adj = layer.input_zero_point * azp_adj
+             layer.azp_adj = Parameter(azp_adj, requires_grad=False)
+         else:
+             layer.azp_adj = None
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         output_partition_sizes: list[int],
+         input_size_per_partition: int,
+         params_dtype: torch.dtype,
+         weight_loader: Callable,
+         **kwargs,
+     ):
+         output_size_per_partition = sum(output_partition_sizes)
+         layer.logical_widths = output_partition_sizes
+
+         # WEIGHT
+         weight = ModelWeightParameter(
+             data=torch.empty(
+                 output_size_per_partition, input_size_per_partition, dtype=torch.int8
+             ),
+             input_dim=1,
+             output_dim=0,
+             weight_loader=weight_loader,
+         )
+
+         layer.register_parameter("weight", weight)
+
+         # WEIGHT SCALE
+         if self.strategy == QuantizationStrategy.CHANNEL:
+             weight_scale = ChannelQuantScaleParameter(
+                 data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                 output_dim=0,
+                 weight_loader=weight_loader,
+             )
+         else:
+             assert self.strategy == QuantizationStrategy.TENSOR
+             weight_scale = PerTensorScaleParameter(
+                 data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                 weight_loader=weight_loader,
+             )
+         layer.register_parameter("weight_scale", weight_scale)
+
+         # INPUT SCALE
+         if self.is_static_input_scheme:
+             input_scale = PerTensorScaleParameter(
+                 data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
+             )
+             layer.register_parameter("input_scale", input_scale)
+
+             if not self.input_symmetric:
+                 # Note: compressed-tensors stores the zp using the same dtype
+                 # as the weights
+                 # AZP loaded as int8 but used as int32
+                 input_zero_point = PerTensorScaleParameter(
+                     data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
+                 )
+                 layer.register_parameter("input_zero_point", input_zero_point)
+
+     def apply_weights(
+         self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
+     ) -> torch.Tensor:
+         # TODO: add cutlass_scaled_mm_azp support
+         x_q, x_scale = per_token_quant_int8(x)
+
+         return int8_scaled_mm(
+             x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+         )
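
Note on apply_weights above: the dynamic path is per-token symmetric quantization followed by an int8 GEMM with a dequantizing epilogue. A minimal pure-PyTorch sketch of that arithmetic (assuming per_token_quant_int8 uses a max-abs/127 scale per row and int8_scaled_mm dequantizes with the row and column scales; the *_ref helpers are hypothetical stand-ins, not the sgl_kernel API):

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One symmetric scale per row: s = max(|row|) / 127 (assumed formula).
    s = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-7) / 127.0
    x_q = torch.round(x / s).clamp_(-128, 127).to(torch.int8)
    return x_q, s.float()

def int8_scaled_mm_ref(x_q, w_q, x_scale, w_scale, out_dtype, bias=None):
    # int32-accumulated matmul, then dequantize by per-token (row) and
    # per-output-channel (column) scales -- the epilogue the kernel fuses.
    acc = (x_q.to(torch.int32) @ w_q.to(torch.int32)).float()
    y = acc * x_scale * w_scale.reshape(1, -1)
    if bias is not None:
        y = y + bias.float()
    return y.to(out_dtype)

x = torch.randn(4, 64)
w = torch.randn(16, 64)                        # checkpoint layout [out, in]
w_q, w_scale = per_token_quant_int8_ref(w)     # per-output-channel scales
x_q, x_scale = per_token_quant_int8_ref(x)     # per-token scales
y = int8_scaled_mm_ref(x_q, w_q.t(), x_scale, w_scale.flatten(), torch.float32)
print(y.shape)                                 # torch.Size([4, 16])
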
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -0,0 +1,339 @@
+ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ from typing import Callable, Optional
+
+ import torch
+ from compressed_tensors.quantization import ActivationOrdering
+
+ # yapf conflicts with isort for this block
+ # yapf: disable
+ from sglang.srt.layers.parameter import (
+     BasevLLMParameter,
+     ChannelQuantScaleParameter,
+     GroupQuantScaleParameter,
+     PackedColumnParameter,
+     PackedvLLMParameter,
+     RowvLLMParameter,
+     permute_param_layout_,
+ )
+ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+     CompressedTensorsScheme,
+ )
+ from sglang.srt.layers.quantization.marlin_utils import (
+     MarlinLinearLayerConfig,
+     apply_gptq_marlin_linear,
+     check_marlin_supports_shape,
+     marlin_is_k_full,
+     marlin_make_empty_g_idx,
+     marlin_make_workspace,
+     marlin_permute_scales,
+     marlin_repeat_scales_on_all_ranks,
+     marlin_sort_g_idx,
+     marlin_zero_points,
+ )
+ from sglang.srt.layers.quantization.utils import (
+     get_scalar_types,
+     replace_parameter,
+     unpack_cols,
+ )
+ from sglang.srt.utils import is_cuda
+
+ _is_cuda = is_cuda()
+
+ if _is_cuda:
+     from sgl_kernel import gptq_marlin_repack
+
+
+ ScalarType, scalar_types = get_scalar_types()
+
+ logger = logging.getLogger(__name__)
+
+ __all__ = ["CompressedTensorsWNA16"]
+ WNA16_SUPPORTED_TYPES_MAP = {
+     4: scalar_types.uint4b8,
+     8: scalar_types.uint8b128
+ }
+ WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
+ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
+
+
+ class CompressedTensorsWNA16(CompressedTensorsScheme):
+     _kernel_backends_being_used: set[str] = set()
+
+     def __init__(self,
+                  strategy: str,
+                  num_bits: int,
+                  group_size: Optional[int] = None,
+                  symmetric: Optional[bool] = True,
+                  actorder: Optional[ActivationOrdering] = None):
+
+         self.pack_factor = 32 // num_bits
+         self.strategy = strategy
+         self.symmetric = symmetric
+         self.group_size = -1 if group_size is None else group_size
+         self.has_g_idx = actorder == ActivationOrdering.GROUP
+
+         if self.group_size == -1 and self.strategy != "channel":
+             raise ValueError("Marlin kernels require group quantization or "
+                              "channelwise quantization, but found no group "
+                              "size and strategy is not channelwise.")
+
+         if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
+             raise ValueError(
+                 f"Unsupported num_bits = {num_bits}. "
+                 f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
+
+         self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
+                            if not self.symmetric else
+                            WNA16_SUPPORTED_TYPES_MAP[num_bits])
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         # ampere and up
+         return 80
+
+     def create_weights(self, layer: torch.nn.Module, output_size: int,
+                        input_size: int, output_partition_sizes: list[int],
+                        input_size_per_partition: int,
+                        params_dtype: torch.dtype, weight_loader: Callable,
+                        **kwargs):
+
+         output_size_per_partition = sum(output_partition_sizes)
+
+         self.kernel_config = MarlinLinearLayerConfig(
+             full_weight_shape=(input_size, output_size),
+             partition_weight_shape=(
+                 input_size_per_partition,
+                 output_size_per_partition,
+             ),
+             weight_type=self.quant_type,
+             act_type=params_dtype,
+             group_size=self.group_size,
+             zero_points=not self.symmetric,
+             has_g_idx=self.has_g_idx
+         )
+
+         # If group_size is -1, we are in channelwise case.
+         group_size = self.group_size if self.group_size != -1 else input_size
+         row_parallel = (input_size != input_size_per_partition)
+         partition_scales = not marlin_repeat_scales_on_all_ranks(
+             self.has_g_idx, self.group_size, row_parallel)
+
+         scales_and_zp_size = input_size // group_size
+
+         if partition_scales:
+             assert input_size_per_partition % group_size == 0
+             scales_and_zp_size = input_size_per_partition // group_size
+
+         weight = PackedvLLMParameter(input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader,
+                                      packed_factor=self.pack_factor,
+                                      packed_dim=1,
+                                      data=torch.empty(
+                                          output_size_per_partition,
+                                          input_size_per_partition //
+                                          self.pack_factor,
+                                          dtype=torch.int32,
+                                      ))
+
+         weight_scale_args = {
+             "weight_loader":
+             weight_loader,
+             "data":
+             torch.empty(
+                 output_size_per_partition,
+                 scales_and_zp_size,
+                 dtype=params_dtype,
+             )
+         }
+
+         zeros_args = {
+             "weight_loader":
+             weight_loader,
+             "data":
+             torch.zeros(
+                 output_size_per_partition // self.pack_factor,
+                 scales_and_zp_size,
+                 dtype=torch.int32,
+             )
+         }
+
+         if not partition_scales:
+             weight_scale = ChannelQuantScaleParameter(output_dim=0,
+                                                       **weight_scale_args)
+
+             if not self.symmetric:
+                 qzeros = PackedColumnParameter(output_dim=0,
+                                                packed_dim=0,
+                                                packed_factor=self.pack_factor,
+                                                **zeros_args)
+         else:
+             weight_scale = GroupQuantScaleParameter(output_dim=0,
+                                                     input_dim=1,
+                                                     **weight_scale_args)
+             if not self.symmetric:
+                 qzeros = PackedvLLMParameter(input_dim=1,
+                                              output_dim=0,
+                                              packed_dim=0,
+                                              packed_factor=self.pack_factor,
+                                              **zeros_args)
+
+         # A 2D array defining the original shape of the weights
+         # before packing
+         weight_shape = BasevLLMParameter(data=torch.empty(2,
+                                                           dtype=torch.int64),
+                                          weight_loader=weight_loader)
+
+         layer.register_parameter("weight_packed", weight)
+         layer.register_parameter("weight_scale", weight_scale)
+         layer.register_parameter("weight_shape", weight_shape)
+
+         if not self.symmetric:
+             layer.register_parameter("weight_zero_point", qzeros)
+
+         # group index (for activation reordering)
+         if self.has_g_idx:
+             weight_g_idx = RowvLLMParameter(data=torch.empty(
+                 input_size_per_partition,
+                 dtype=torch.int32,
+             ),
+                                             input_dim=0,
+                                             weight_loader=weight_loader)
+             layer.register_parameter("weight_g_idx", weight_g_idx)
+
+     # Checkpoints are serialized in compressed-tensors format, which is
+     # different from the format the kernel may want. Handle repacking here.
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         # Default names since marlin requires empty parameters for these,
+         # TODO: remove this requirement from marlin (allow optional tensors)
+         self.w_q_name = "weight_packed"
+         self.w_s_name = "weight_scale"
+         self.w_zp_name = "weight_zero_point"
+         self.w_gidx_name = "weight_g_idx"
+
+         device = getattr(layer, self.w_q_name).device
+         c = self.kernel_config
+
+         check_marlin_supports_shape(
+             c.partition_weight_shape[1],  # out_features
+             c.partition_weight_shape[0],  # in_features
+             c.full_weight_shape[0],  # in_features
+             c.group_size,
+         )
+
+         row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
+         self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
+
+         # Allocate marlin workspace.
+         self.workspace = marlin_make_workspace(device)
+
+         def _transform_param(
+             layer: torch.nn.Module, name: Optional[str], fn: Callable
+         ) -> None:
+             if name is not None and getattr(layer, name, None) is not None:
+
+                 old_param = getattr(layer, name)
+                 new_param = fn(old_param)
+                 # replace the parameter with torch.nn.Parameter for TorchDynamo
+                 # compatibility
+                 replace_parameter(
+                     layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
+                 )
+
+         def transform_w_q(x):
+             assert isinstance(x, BasevLLMParameter)
+             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+             x.data = gptq_marlin_repack(
+                 x.data.contiguous(),
+                 perm=layer.g_idx_sort_indices,
+                 size_k=c.partition_weight_shape[0],
+                 size_n=c.partition_weight_shape[1],
+                 num_bits=c.weight_type.size_bits,
+             )
+             return x
+
+         def transform_w_s(x):
+             assert isinstance(x, BasevLLMParameter)
+             permute_param_layout_(x, input_dim=0, output_dim=1)
+             x.data = marlin_permute_scales(
+                 x.data.contiguous(),
+                 size_k=c.partition_weight_shape[0],
+                 size_n=c.partition_weight_shape[1],
+                 group_size=c.group_size,
+             )
+             return x
+
+         if c.has_g_idx:
+             g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                 getattr(layer, self.w_gidx_name)
+             )
+             _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+             layer.g_idx_sort_indices = g_idx_sort_indices
+         else:
+             setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+             layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+         if c.zero_points:
+             grouped_k = (
+                 c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
+             )
+             _transform_param(
+                 layer,
+                 self.w_zp_name,
+                 lambda x: marlin_zero_points(
+                     unpack_cols(
+                         x.t(),
+                         c.weight_type.size_bits,
+                         grouped_k,
+                         c.partition_weight_shape[1],
+                     ),
+                     size_k=grouped_k,
+                     size_n=c.partition_weight_shape[1],
+                     num_bits=c.weight_type.size_bits,
+                 ),
+             )
+         else:
+             setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+         _transform_param(layer, self.w_q_name, transform_w_q)
+         _transform_param(layer, self.w_s_name, transform_w_s)
+
+     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                       bias: Optional[torch.Tensor]) -> torch.Tensor:
+         c = self.kernel_config
+
+         def _get_weight_params(
+             layer: torch.nn.Module,
+         ) -> tuple[
+             torch.Tensor,  # w_q
+             torch.Tensor,  # w_s
+             Optional[torch.Tensor],  # w_zp,
+             Optional[torch.Tensor],  # w_gidx
+         ]:
+             return (
+                 getattr(layer, self.w_q_name),
+                 getattr(layer, self.w_s_name),
+                 getattr(layer, self.w_zp_name or "", None),
+                 getattr(layer, self.w_gidx_name or "", None),
+             )
+
+         w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)
+
+         # `process_weights_after_loading` will ensure w_zp and w_gidx are not
+         # None for marlin
+         return apply_gptq_marlin_linear(
+             input=x,
+             weight=w_q,
+             weight_scale=w_s,
+             weight_zp=w_zp,  # type: ignore
+             g_idx=w_gidx,  # type: ignore
+             g_idx_sort_indices=layer.g_idx_sort_indices,
+             workspace=self.workspace,
+             wtype=c.weight_type,
+             input_size_per_partition=c.partition_weight_shape[0],
+             output_size_per_partition=c.partition_weight_shape[1],
+             is_k_full=self.is_k_full,
+             bias=bias,
+         )
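
On the packing above: pack_factor = 32 // num_bits is how many quantized values fit in one int32 word (8 for 4-bit, 4 for 8-bit), which is why weight_packed is allocated as (out, in // pack_factor). A self-contained sketch of that packing arithmetic (illustrative only; the exact bit order and layout expected by gptq_marlin_repack / unpack_cols may differ):

import torch

NUM_BITS = 4
PACK_FACTOR = 32 // NUM_BITS  # 8 x 4-bit values per int32 word

def pack_cols_ref(q: torch.Tensor) -> torch.Tensor:
    # Pack unsigned NUM_BITS values along the last dim into int32 words,
    # low bits first (one of several possible orderings).
    assert q.shape[-1] % PACK_FACTOR == 0
    q = q.to(torch.int32).reshape(*q.shape[:-1], -1, PACK_FACTOR)
    shifts = torch.arange(PACK_FACTOR, dtype=torch.int32) * NUM_BITS
    return (q << shifts).sum(dim=-1, dtype=torch.int32)

def unpack_cols_ref(packed: torch.Tensor) -> torch.Tensor:
    # Recover the NUM_BITS values by shifting and masking each word.
    shifts = torch.arange(PACK_FACTOR, dtype=torch.int32) * NUM_BITS
    q = (packed.unsqueeze(-1) >> shifts) & ((1 << NUM_BITS) - 1)
    return q.reshape(*packed.shape[:-1], -1)

q = torch.randint(0, 16, (4, 16))   # uint4 values
packed = pack_cols_ref(q)           # shape (4, 2): 16 // 8 words per row
assert torch.equal(unpack_cols_ref(packed), q.to(torch.int32))
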
sglang/srt/layers/quantization/fp8.py
@@ -31,8 +31,8 @@ except ImportError:
  from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
  from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+ from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
  from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
- from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
  from sglang.srt.layers.parameter import (
      BlockQuantScaleParameter,
      ModelWeightParameter,
@@ -358,8 +358,8 @@ class Fp8LinearMethod(LinearMethodBase):
                  return
          else:
              weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-             layer.weight = Parameter(weight, requires_grad=False)
-             layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
+             layer.weight.data = weight.data
+             layer.weight_scale_inv.data = weight_scale.data
      else:
          layer.weight = Parameter(layer.weight.data, requires_grad=False)

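The @@ -358 change swaps rebinding layer.weight to a fresh Parameter for an in-place .data assignment. The difference is aliasing: rebinding creates a new Parameter object, so any reference captured elsewhere (weight-update paths, compiled or captured graphs) goes stale, while .data assignment mutates the object every holder already sees. A small demonstration of the PyTorch semantics:

import torch
from torch.nn import Parameter

# Rebinding: the module attribute now points at a brand-new Parameter,
# but `ref` still points at the old storage.
lin = torch.nn.Linear(4, 4, bias=False)
ref = lin.weight
lin.weight = Parameter(torch.zeros(4, 4), requires_grad=False)
print(ref is lin.weight)   # False -- external reference is stale

# In-place .data assignment: same Parameter object, new values,
# visible through every existing reference.
lin2 = torch.nn.Linear(4, 4, bias=False)
ref2 = lin2.weight
lin2.weight.data = torch.zeros(4, 4)
print(ref2 is lin2.weight) # True
print(ref2.sum().item())   # 0.0
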
@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
      def create_moe_runner(
          self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
      ):
+
+         from sglang.srt.layers import deep_gemm_wrapper
+         from sglang.srt.layers.moe.utils import (
+             get_moe_a2a_backend,
+             get_moe_runner_backend,
+         )
+
          self.moe_runner_config = moe_runner_config
-         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+         moe_runner_backend = get_moe_runner_backend()
+
+         if moe_runner_backend.is_auto():
+             if (
+                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                 and get_moe_a2a_backend().is_deepep()
+             ):
+                 moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
+             else:
+                 moe_runner_backend = MoeRunnerBackend.TRITON
+         if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
+             self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
+         else:
+             # TODO(cwan): refactor other backends
+             pass

      def apply(
          self,
@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
              )
              return StandardCombineInput(hidden_states=output)

-         quant_info = TritonMoeQuantInfo(
-             w13_weight=layer.w13_weight,
-             w2_weight=layer.w2_weight,
-             use_fp8_w8a8=True,
-             w13_scale=(
-                 layer.w13_weight_scale_inv
-                 if self.block_quant
-                 else layer.w13_weight_scale
-             ),
-             w2_scale=(
-                 layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-             ),
-             a13_scale=layer.w13_input_scale,
-             a2_scale=layer.w2_input_scale,
-             block_shape=self.quant_config.weight_block_size,
-         )
+         if self.runner.runner_backend.is_deep_gemm():
+
+             w13_weight = layer.w13_weight
+             w2_weight = layer.w2_weight
+
+             if self.block_quant:
+                 block_shape = self.quant_config.weight_block_size
+                 w13_scale = layer.w13_weight_scale_inv
+                 w2_scale = layer.w2_weight_scale_inv
+             else:
+                 # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+                 scale_block_size = 128
+                 block_shape = [scale_block_size, scale_block_size]
+                 w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
+                 w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
+                 w13_scale = (
+                     layer.w13_weight_scale.unsqueeze(1)
+                     .repeat_interleave(w13_scale_n, dim=1)
+                     .unsqueeze(2)
+                     .repeat_interleave(w13_scale_k, dim=2)
+                 )
+                 w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
+                 w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
+                 w2_scale = (
+                     layer.w2_weight_scale.unsqueeze(1)
+                     .repeat_interleave(w2_scale_n, dim=1)
+                     .unsqueeze(2)
+                     .repeat_interleave(w2_scale_k, dim=2)
+                 )
+             quant_info = DeepGemmMoeQuantInfo(
+                 w13_weight=w13_weight,
+                 w2_weight=w2_weight,
+                 use_fp8=True,
+                 w13_scale=w13_scale,
+                 w2_scale=w2_scale,
+                 block_shape=block_shape,
+             )
+         elif self.runner.runner_backend.is_triton():
+             quant_info = TritonMoeQuantInfo(
+                 w13_weight=layer.w13_weight,
+                 w2_weight=layer.w2_weight,
+                 use_fp8_w8a8=True,
+                 w13_scale=(
+                     layer.w13_weight_scale_inv
+                     if self.block_quant
+                     else layer.w13_weight_scale
+                 ),
+                 w2_scale=(
+                     layer.w2_weight_scale_inv
+                     if self.block_quant
+                     else layer.w2_weight_scale
+                 ),
+                 a13_scale=layer.w13_input_scale,
+                 a2_scale=layer.w2_input_scale,
+                 block_shape=self.quant_config.weight_block_size,
+             )
+         else:
+             raise NotImplementedError(
+                 "Unsupported runner backend: %s" % self.runner.runner_backend
+             )
+
          return self.runner.run(dispatch_output, quant_info)

      def apply_with_router_logits(
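
The per-tensor-to-per-block conversion in the DeepGEMM branch above is just a broadcast: every 128x128 weight block receives a copy of the single tensor-wide scale, with block counts computed by the ceil-division idiom (dim - 1) // block + 1. A 2D sketch of the same expansion (the real code applies it per expert to the 3D w13/w2 tensors):

import torch

scale_block_size = 128
w = torch.randn(384, 512)            # [N, K] weight
w_scale = torch.tensor(0.02)         # one scale for the whole tensor

n_blocks = (w.shape[0] - 1) // scale_block_size + 1   # ceil(384/128) = 3
k_blocks = (w.shape[1] - 1) // scale_block_size + 1   # ceil(512/128) = 4

# Expand the scalar into an [n_blocks, k_blocks] grid so a block-quant
# kernel can consume per-tensor-quantized weights unchanged.
w_scale_block = (
    w_scale.reshape(1, 1)
    .repeat_interleave(n_blocks, dim=0)
    .repeat_interleave(k_blocks, dim=1)
)
print(w_scale_block.shape)           # torch.Size([3, 4]), every entry 0.02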