sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
 
4
4
  import logging
5
5
  import warnings
6
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
7
7
 
8
8
  import torch
9
9
 
@@ -31,6 +31,7 @@ from sglang.srt.layers.quantization.marlin_utils import (
31
31
  )
32
32
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
33
33
  from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
34
+ from sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts
34
35
 
35
36
  if TYPE_CHECKING:
36
37
  from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
@@ -39,10 +40,16 @@ if TYPE_CHECKING:
39
40
  CombineInput,
40
41
  )
41
42
 
42
- from sglang.srt.utils import is_cuda, is_hip
43
+ from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu
43
44
 
44
45
  _is_cuda = is_cuda()
45
46
  _is_hip = is_hip()
47
+ _is_xpu = is_xpu()
48
+ _is_npu = is_npu()
49
+
50
+ if _is_npu:
51
+ import torch_npu
52
+
46
53
  if _is_cuda:
47
54
  from sgl_kernel import (
48
55
  awq_dequantize,
@@ -58,8 +65,12 @@ elif _is_hip:
58
65
  )
59
66
 
60
67
  warnings.warn(f"HIP does not support fused_marlin_moe currently.")
68
+ elif _is_xpu:
69
+ from sgl_kernel import awq_dequantize
70
+
71
+ warnings.warn(f"XPU does not support fused_marlin_moe currently.")
61
72
  else:
62
- warnings.warn(f"Only CUDA and HIP support AWQ currently.")
73
+ warnings.warn(f"Only CUDA, HIP and XPU support AWQ currently.")
63
74
 
64
75
  logger = logging.getLogger(__name__)
65
76
 
@@ -112,12 +123,17 @@ class AWQConfig(QuantizationConfig):
112
123
  return "awq"
113
124
 
114
125
  def get_supported_act_dtypes(self) -> List[torch.dtype]:
115
- return [torch.half]
126
+ return [torch.float16] if not _is_npu else [torch.float16, torch.bfloat16]
116
127
 
117
128
  @classmethod
118
129
  def get_min_capability(cls) -> int:
119
130
  # The AWQ kernel only supports Turing or newer GPUs.
120
- return 75
131
+ if _is_npu:
132
+ raise NotImplementedError(
133
+ 'NPU hardware does not support "get_min_capability" feature.'
134
+ )
135
+ else:
136
+ return 75
121
137
 
122
138
  @staticmethod
123
139
  def get_config_filenames() -> List[str]:
@@ -141,6 +157,16 @@ class AWQConfig(QuantizationConfig):
141
157
  self, layer: torch.nn.Module, prefix: str
142
158
  ) -> Optional[LinearMethodBase]:
143
159
  from sglang.srt.layers.linear import LinearBase
160
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
161
+
162
+ if _is_npu:
163
+ if isinstance(layer, LinearBase):
164
+ if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
165
+ return UnquantizedLinearMethod()
166
+ return AWQLinearAscendMethod(self)
167
+ elif isinstance(layer, FusedMoE):
168
+ return AWQMoEAscendMethod(self)
169
+ return None
144
170
 
145
171
  if isinstance(layer, LinearBase):
146
172
  if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
@@ -570,6 +596,64 @@ class AWQMarlinLinearMethod(LinearMethodBase):
570
596
  )
571
597
 
572
598
 
599
class AWQLinearAscendMethod(AWQLinearMethod):
    """AWQ linear method that runs the matmul through ``torch_npu``.

    After loading, the packed weights and zero points stored on the layer
    are rewritten once (see ``process_weights_after_loading``); ``apply``
    then hands them to ``torch_npu.npu_weight_quant_batchmatmul``.

    Args:
        quant_config: The AWQ quantization config.
    """

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Rewrite ``layer.scales``, ``layer.qweight`` and ``layer.qzeros``.

        * ``scales`` is re-wrapped as a non-trainable parameter.
        * ``qweight`` keeps its shape, but inside every packed element the
          4-bit field found at position ``nibble_order[i]`` is moved to
          position ``i``; afterwards the top bit of every 4-bit field is
          flipped (XOR with ``0x88888888``).
        * ``qzeros`` is unpacked to one column per 4-bit field, mapped to
          ``-(z - 8)`` and cast to the dtype of ``scales``.
        """
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

        # Source position (in 4-bit units) of the field that ends up in
        # slot 0, 1, 2, ... of the rewritten element.
        nibble_order = [0, 4, 1, 5, 2, 6, 3, 7]

        packed_weight = layer.qweight.data
        packed_zeros = layer.qzeros.data
        repacked_weight = torch.zeros_like(packed_weight)
        zero_columns = []

        for slot in range(self.quant_config.pack_factor):
            src_shift = nibble_order[slot] * 4
            dst_shift = 4 * slot

            # One column of unpacked zero points per slot.
            zero_columns.append((packed_zeros.reshape(-1, 1) >> src_shift) & 0xF)

            # Bring the source field down to bit 0, scale it up to the
            # destination slot, and keep only that slot's four bits.
            moved = (packed_weight >> src_shift) * (2**dst_shift)
            repacked_weight.bitwise_or_(moved & (0xF << dst_shift))

        # NOTE(review): flipping bit 3 of each field presumably converts the
        # unsigned 4-bit encoding into the signed one the NPU kernel expects
        # -- confirm against the torch_npu documentation.
        repacked_weight.bitwise_xor_(0x88888888)

        zero_points = torch.cat(zero_columns, dim=-1).reshape(
            packed_zeros.shape[0], -1
        )
        zero_points = -(zero_points - 8)
        zero_points = zero_points.to(layer.scales.data.dtype)

        layer.qzeros = torch.nn.Parameter(zero_points, requires_grad=False)
        layer.qweight = torch.nn.Parameter(repacked_weight, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Multiply ``x`` by the layer's quantized weight on the NPU.

        Args:
            layer: Module holding ``qweight``, ``scales`` and ``qzeros``.
            x: Input activations; every leading dimension is flattened for
                the kernel call and restored on the result.
            bias: Optional bias forwarded to the kernel.

        Returns:
            A tensor whose leading dimensions match ``x`` and whose last
            dimension is ``qweight.shape[-1] * pack_factor``.
        """
        packed_weight = layer.qweight
        out_features = packed_weight.shape[-1] * self.quant_config.pack_factor
        out_shape = x.shape[:-1] + (out_features,)
        flat_x = x.reshape(-1, x.shape[-1])

        # A bfloat16 bias is widened to float32 before the call.
        # NOTE(review): presumably a kernel dtype restriction -- confirm.
        if bias is not None and bias.dtype == torch.bfloat16:
            bias = bias.float()

        result = torch_npu.npu_weight_quant_batchmatmul(
            flat_x,
            packed_weight,
            antiquant_scale=layer.scales,
            antiquant_offset=layer.qzeros,
            antiquant_group_size=self.quant_config.group_size,
            bias=bias,
        )

        return result.reshape(out_shape)
655
+
656
+
573
657
  class AWQMoEMethod(FusedMoEMethodBase):
574
658
 
575
659
  def __init__(self, quant_config: AWQMarlinConfig):
@@ -672,7 +756,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
672
756
  set_weight_attrs(w2_qzeros, extra_weight_attrs)
673
757
 
674
758
  device = layer.w13_qweight.device
675
- layer.workspace = marlin_make_workspace(device, 4)
759
+ if not _is_npu:
760
+ layer.workspace = marlin_make_workspace(device, 4)
676
761
 
677
762
  def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
678
763
  num_experts = layer.w13_qweight.shape[0]
@@ -780,3 +865,95 @@ class AWQMoEMethod(FusedMoEMethodBase):
780
865
  num_bits=self.quant_config.weight_bits,
781
866
  ).to(orig_dtype)
782
867
  return StandardCombineInput(hidden_states=output)
868
+
869
+
870
class AWQMoEAscendMethod(AWQMoEMethod):
    """AWQ fused-MoE method that dispatches to ``npu_fused_experts``."""

    def __init__(self, quant_config: AWQConfig):
        # Only the config is stored; the parent initializer is not invoked.
        self.quant_config = quant_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Rewrite the expert weights and zero points stored on ``layer``.

        For both ``w13`` and ``w2``:

        * the packed weight keeps its shape, but inside every element the
          4-bit field at position ``nibble_order[i]`` is moved to position
          ``i`` and the top bit of every field is then flipped;
        * the packed zero points are unpacked to one value per 4-bit field,
          mapped to ``-(z - 8)``, reshaped so the first two dimensions of
          the original tensor are kept, and cast to the matching scales
          dtype.

        The four results are re-registered on ``layer`` as non-trainable
        parameters under their original names.
        """
        # Source position (in 4-bit units) of the field placed in each slot.
        nibble_order = [0, 4, 1, 5, 2, 6, 3, 7]

        w13_packed = torch.zeros_like(layer.w13_qweight.data)
        w2_packed = torch.zeros_like(layer.w2_qweight.data)
        w13_zero_cols = []
        w2_zero_cols = []

        for slot in range(self.quant_config.pack_factor):
            src_shift = nibble_order[slot] * 4
            dst_shift = 4 * slot
            dst_mask = 0xF << dst_shift
            scale_up = 2**dst_shift

            w13_zero_cols.append(
                (layer.w13_qzeros.data.reshape(-1, 1) >> src_shift) & 0xF
            )
            w2_zero_cols.append(
                (layer.w2_qzeros.data.reshape(-1, 1) >> src_shift) & 0xF
            )
            w13_packed.bitwise_or_(
                ((layer.w13_qweight.data >> src_shift) * scale_up) & dst_mask
            )
            w2_packed.bitwise_or_(
                ((layer.w2_qweight.data >> src_shift) * scale_up) & dst_mask
            )

        # Flip bit 3 of every 4-bit field.
        w13_packed.bitwise_xor_(0x88888888)
        w2_packed.bitwise_xor_(0x88888888)

        w13_zeros = torch.cat(w13_zero_cols, dim=-1).reshape(
            layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1
        )
        w13_zeros = -(w13_zeros - 8)
        w13_zeros = w13_zeros.to(layer.w13_scales.data.dtype)

        w2_zeros = torch.cat(w2_zero_cols, dim=-1).reshape(
            layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1
        )
        w2_zeros = -(w2_zeros - 8)
        w2_zeros = w2_zeros.to(layer.w2_scales.data.dtype)

        replacements = (
            ("w13_qzeros", w13_zeros),
            ("w13_qweight", w13_packed),
            ("w2_qzeros", w2_zeros),
            ("w2_qweight", w2_packed),
        )
        for param_name, tensor in replacements:
            layer.register_parameter(
                param_name, torch.nn.Parameter(tensor, requires_grad=False)
            )

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        """Remember the runner config; ``apply`` reads its activation."""
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> torch.Tensor:
        """Run the quantized experts on the dispatched tokens.

        Args:
            layer: Module holding the rewritten ``w13``/``w2`` weights,
                scales and zero points.
            dispatch_output: Provides ``hidden_states`` and ``topk_output``.

        Returns:
            A ``StandardCombineInput`` wrapping the expert output.
        """
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        hidden_states = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output

        # Expert ids go in as int32 and routing weights in the activation dtype.
        topk_ids = topk_ids.to(torch.int32)
        topk_weights = topk_weights.to(hidden_states.dtype)
        experts_per_token = topk_ids.shape[1]

        # NOTE(review): npu_fused_experts must already be in module scope;
        # its import is not visible in this part of the file.
        output = npu_fused_experts(
            hidden_states=hidden_states,
            w13=layer.w13_qweight,
            w13_scale=layer.w13_scales,
            w13_offset=layer.w13_qzeros,
            w2=layer.w2_qweight,
            w2_scale=layer.w2_scales,
            w2_offset=layer.w2_qzeros,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            top_k=experts_per_token,
            use_wna16=True,
        )
        return StandardCombineInput(hidden_states=output)
@@ -337,3 +337,32 @@ def awq_gemm_triton(
337
337
  result = result.sum(0)
338
338
 
339
339
  return result
340
+
341
+
342
def awq_dequantize_decomposition(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
) -> torch.Tensor:
    """Dequantize packed AWQ weights using only basic tensor operations.

    Every element of ``qweight`` and ``zeros`` is split into eight 4-bit
    fields, read in the order ``0, 4, 1, 5, 2, 6, 3, 7``, so each packed
    column expands to eight unpacked columns. The unpacked weight rows are
    then grouped so that each row of zero points and scales applies to one
    group, and the result is ``(weight - zero) * scale``.

    Args:
        qweight: Packed weights; unpacked row-wise to ``(rows, 8 * cols)``.
        scales: Per-group scales; also determines the output dtype.
        zeros: Packed zero points, one row per group.

    Returns:
        The dequantized weights with ``qweight.shape[0]`` rows, in the dtype
        of ``scales``.
    """
    # Position (in 4-bit units) of the field feeding each unpacked column.
    nibble_order = (0, 4, 1, 5, 2, 6, 3, 7)

    def _unpack(packed: torch.Tensor) -> torch.Tensor:
        # One column per 4-bit field, then fold back to the original rows.
        flat = packed.reshape(-1, 1)
        fields = [(flat >> (pos * 4)) & 0xF for pos in nibble_order]
        unpacked = torch.cat(fields, dim=-1).reshape(packed.shape[0], -1)
        return unpacked.to(scales.dtype)

    zero_points = _unpack(zeros)
    weights = _unpack(qweight)

    num_groups, num_cols = zero_points.shape
    # (groups, rows_per_group, cols): broadcast each group's zero and scale.
    grouped = weights.reshape(num_groups, -1, num_cols)
    dequantized = (grouped - zero_points.unsqueeze(1)) * scales.unsqueeze(1)
    return dequantized.reshape(weights.shape[0], -1)
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
 
4
4
  import inspect
5
5
  from abc import ABC, abstractmethod
6
- from dataclasses import dataclass
7
6
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
8
7
 
9
8
  import torch
@@ -162,6 +161,26 @@ class QuantizationConfig(ABC):
162
161
  """
163
162
  return None
164
163
 
164
@classmethod
def _modelopt_override_quantization_method(
    cls, hf_quant_config, user_quant
) -> Optional[str]:
    """Shared ModelOpt quantization method override logic.

    When the user asked for the generic ``"modelopt"`` method, pick the
    concrete method from the checkpoint's ``quant_algo`` entry.

    Args:
        hf_quant_config: Quantization config mapping of the checkpoint, or
            ``None`` when the checkpoint has none.
        user_quant: Quantization method requested by the user.

    Returns:
        ``"modelopt_fp8"`` or ``"modelopt_fp4"`` when ``user_quant`` is
        ``"modelopt"`` and ``quant_algo`` names an FP8 / FP4 algorithm;
        otherwise ``None`` (no override).
    """
    # Only the generic "modelopt" request is ever overridden, so bail out
    # before touching the config for anything else.
    if hf_quant_config is None or user_quant != "modelopt":
        return None

    # The key may be missing, null, or not a string; treat all of those as
    # "no algorithm" instead of failing on ``.upper()``.
    raw_algo = hf_quant_config.get("quant_algo")
    quant_algo = raw_algo.upper() if isinstance(raw_algo, str) else ""

    if "FP8" in quant_algo:
        return "modelopt_fp8"
    # "FP4" also matches "NVFP4", so a single check covers both spellings.
    if "FP4" in quant_algo:
        return "modelopt_fp4"

    return None
183
+
165
184
  @staticmethod
166
185
  def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
167
186
  """Get a value from the model's quantization config."""
@@ -0,0 +1,7 @@
1
class scalar_types:
    """Namespace of string identifiers for quantized weight scalar types."""

    uint4b8 = "uint4b8"
    uint8b128 = "uint8b128"


# Weight bit-width -> scalar type identifier for the WNA16 schemes.
WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128,
}
# Bit-widths accepted by the WNA16 schemes, in map order.
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP)
@@ -19,36 +19,32 @@ from compressed_tensors.quantization import (
19
19
  )
20
20
  from pydantic import BaseModel
21
21
 
22
+ from sglang.srt.environ import envs
22
23
  from sglang.srt.layers.quantization.base_config import (
23
24
  LinearMethodBase,
24
25
  QuantizationConfig,
25
26
  QuantizeMethodBase,
26
27
  )
28
+ from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
27
29
  from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
28
30
  CompressedTensorsMoEMethod,
29
31
  )
30
32
  from sglang.srt.layers.quantization.compressed_tensors.schemes import (
33
+ WNA16_SUPPORTED_BITS,
31
34
  CompressedTensorsScheme,
32
35
  CompressedTensorsW8A8Fp8,
36
+ CompressedTensorsW8A8Int8,
33
37
  CompressedTensorsW8A16Fp8,
38
+ CompressedTensorsWNA16,
34
39
  )
35
40
  from sglang.srt.layers.quantization.compressed_tensors.utils import (
36
41
  find_matched_target,
37
42
  is_activation_quantization_format,
38
43
  should_ignore_layer,
39
44
  )
45
+ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
40
46
  from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
41
47
 
42
- try:
43
- from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
44
- WNA16_SUPPORTED_BITS,
45
- CompressedTensorsWNA16,
46
- )
47
-
48
- VLLM_AVAILABLE = True
49
- except ImportError:
50
- VLLM_AVAILABLE = False
51
-
52
48
  logger = logging.getLogger(__name__)
53
49
 
54
50
  __all__ = ["CompressedTensorsLinearMethod"]
@@ -75,6 +71,7 @@ class DeviceCapability(NamedTuple):
75
71
 
76
72
 
77
73
  class CompressedTensorsConfig(QuantizationConfig):
74
+ DeepSeekFP8Config = None
78
75
 
79
76
  def __init__(
80
77
  self,
@@ -85,7 +82,7 @@ class CompressedTensorsConfig(QuantizationConfig):
85
82
  sparsity_ignore_list: List[str],
86
83
  kv_cache_scheme: Optional[Dict[str, Any]] = None,
87
84
  config: Optional[Dict[str, Any]] = None,
88
- packed_modules_mapping: Dict[str, List[str]] = {},
85
+ packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
89
86
  ):
90
87
  super().__init__()
91
88
  self.ignore = ignore
@@ -96,7 +93,7 @@ class CompressedTensorsConfig(QuantizationConfig):
96
93
  self.sparsity_scheme_map = sparsity_scheme_map
97
94
  self.sparsity_ignore_list = sparsity_ignore_list
98
95
  self.config = config
99
- self.packed_modules_mapping = packed_modules_mapping
96
+ self.packed_modules_mapping = packed_modules_mapping or {}
100
97
 
101
98
  def get_linear_method(self) -> CompressedTensorsLinearMethod:
102
99
  return CompressedTensorsLinearMethod(self)
@@ -128,6 +125,10 @@ class CompressedTensorsConfig(QuantizationConfig):
128
125
  ):
129
126
  return UnquantizedLinearMethod()
130
127
  if isinstance(layer, LinearBase):
128
+ if CompressedTensorsConfig.DeepSeekFP8Config is not None:
129
+ return Fp8LinearMethod(CompressedTensorsConfig.DeepSeekFP8Config)
130
+ if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
131
+ return UnquantizedLinearMethod()
131
132
  scheme = self.get_scheme(layer=layer, layer_name=prefix)
132
133
  if scheme is None:
133
134
  return UnquantizedLinearMethod()
@@ -136,7 +137,8 @@ class CompressedTensorsConfig(QuantizationConfig):
136
137
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
137
138
 
138
139
  if isinstance(layer, FusedMoE):
139
- return CompressedTensorsMoEMethod.get_moe_method(self)
140
+ # Ktransformers use CompressedTensorsWNA16AMXMOEMethod if AMX weights are provided
141
+ return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
140
142
  return None
141
143
 
142
144
  @classmethod
@@ -363,19 +365,6 @@ class CompressedTensorsConfig(QuantizationConfig):
363
365
 
364
366
  # Detect If Mixed Precision
365
367
  if self._is_wNa16_group_channel(weight_quant, input_quant):
366
- if not VLLM_AVAILABLE:
367
- raise ImportError(
368
- "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
369
- )
370
- if (
371
- self.quant_format == CompressionFormat.marlin_24.value
372
- and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
373
- ):
374
- return CompressedTensorsW4A16Sparse24(
375
- strategy=weight_quant.strategy,
376
- num_bits=weight_quant.num_bits,
377
- group_size=weight_quant.group_size,
378
- )
379
368
  if (
380
369
  self.quant_format == CompressionFormat.pack_quantized.value
381
370
  and weight_quant.num_bits in WNA16_SUPPORTED_BITS
@@ -386,6 +375,10 @@ class CompressedTensorsConfig(QuantizationConfig):
386
375
  group_size=weight_quant.group_size,
387
376
  actorder=weight_quant.actorder,
388
377
  )
378
+ else:
379
+ raise ImportError(
380
+ "Other method (CompressedTensorsW4A16Sparse24) is not supported now"
381
+ )
389
382
 
390
383
  if is_activation_quantization_format(self.quant_format):
391
384
  if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -409,10 +402,6 @@ class CompressedTensorsConfig(QuantizationConfig):
409
402
 
410
403
  # note: input_quant can be None
411
404
  if self._is_fp8_w8a16(weight_quant, input_quant):
412
- if not VLLM_AVAILABLE:
413
- raise ImportError(
414
- "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
415
- )
416
405
  is_static_input_scheme = input_quant and not input_quant.dynamic
417
406
  return CompressedTensorsW8A16Fp8(
418
407
  strategy=weight_quant.strategy,
@@ -453,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig):
453
442
 
454
443
  # Find the "target" in the compressed-tensors config
455
444
  # that our layer conforms to.
456
- # TODO (@robertgshaw): add compressed-tensors as dep
445
+ # TODO : add compressed-tensors as dep
457
446
  # so we do not have to re-write these functions
458
447
  # need to make accelerate optional in ct to do this
459
448
 
@@ -491,24 +480,7 @@ class CompressedTensorsConfig(QuantizationConfig):
491
480
  input_quant=input_quant,
492
481
  sparsity_scheme=sparsity_scheme,
493
482
  ):
494
- if not VLLM_AVAILABLE:
495
- raise ImportError(
496
- "vllm is not installed, to use CompressedTensors24, please install vllm"
497
- )
498
- # Have a valid sparsity scheme
499
- # Validate layer is supported by Cutlass 2:4 Kernel
500
- model_compression_config = (
501
- None
502
- if sparsity_scheme is None or sparsity_scheme.format == "dense"
503
- else self.config
504
- )
505
-
506
- scheme = CompressedTensors24(
507
- quantized=weight_quant is not None or input_quant is not None,
508
- weight_quant=weight_quant,
509
- input_quant=input_quant,
510
- model_compression_config=model_compression_config,
511
- )
483
+ raise ImportError("CompressedTensors24 is not supported now")
512
484
  elif weight_quant is None:
513
485
  logger.warning_once(
514
486
  "Acceleration for non-quantized schemes is "