sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/awq.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import logging
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 
@@ -31,6 +31,7 @@ from sglang.srt.layers.quantization.marlin_utils import (
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
+from sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
@@ -39,10 +40,16 @@ if TYPE_CHECKING:
         CombineInput,
     )
 
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_xpu = is_xpu()
+_is_npu = is_npu()
+
+if _is_npu:
+    import torch_npu
+
 if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
@@ -58,8 +65,12 @@ elif _is_hip:
     )
 
     warnings.warn(f"HIP does not support fused_marlin_moe currently.")
+elif _is_xpu:
+    from sgl_kernel import awq_dequantize
+
+    warnings.warn(f"XPU does not support fused_marlin_moe currently.")
 else:
-    warnings.warn(f"Only CUDA and HIP support AWQ currently.")
+    warnings.warn(f"Only CUDA, HIP and XPU support AWQ currently.")
 
 logger = logging.getLogger(__name__)
 
@@ -112,12 +123,17 @@ class AWQConfig(QuantizationConfig):
         return "awq"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
-        return [torch.half]
+        return [torch.float16] if not _is_npu else [torch.float16, torch.bfloat16]
 
     @classmethod
     def get_min_capability(cls) -> int:
         # The AWQ kernel only supports Turing or newer GPUs.
-        return 75
+        if _is_npu:
+            raise NotImplementedError(
+                'NPU hardware does not support "get_min_capability" feature.'
+            )
+        else:
+            return 75
 
     @staticmethod
     def get_config_filenames() -> List[str]:
@@ -141,6 +157,16 @@ class AWQConfig(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[LinearMethodBase]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if _is_npu:
+            if isinstance(layer, LinearBase):
+                if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                    return UnquantizedLinearMethod()
+                return AWQLinearAscendMethod(self)
+            elif isinstance(layer, FusedMoE):
+                return AWQMoEAscendMethod(self)
+            return None
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
@@ -570,6 +596,64 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         )
 
 
+class AWQLinearAscendMethod(AWQLinearMethod):
+    """Linear method for AWQ on Ascend.
+
+    Args:
+        quant_config: The AWQ quantization config.
+    """
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+        qweight_tmp = torch.zeros_like(layer.qweight.data)
+        qzeros_tmp = layer.qzeros.data
+        qzeros_list = []
+        shifts = [0, 4, 1, 5, 2, 6, 3, 7]
+
+        for i in range(0, self.quant_config.pack_factor):
+            shift_num = shifts[i] * 4
+            qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
+            qweight_tmp.bitwise_or_(
+                ((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
+            )
+
+        qweight_tmp.bitwise_xor_(0x88888888)
+
+        qzeros_tmp = torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1)
+        qzeros_tmp = -(qzeros_tmp - 8)
+        qzeros_tmp = qzeros_tmp.to(layer.scales.data.dtype)
+
+        layer.qzeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False)
+        layer.qweight = torch.nn.Parameter(qweight_tmp, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
+        pack_factor = self.quant_config.pack_factor
+        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        if bias is not None and bias.dtype == torch.bfloat16:
+            bias = bias.float()
+
+        out = torch_npu.npu_weight_quant_batchmatmul(
+            reshaped_x,
+            qweight,
+            antiquant_scale=scales,
+            antiquant_offset=qzeros,
+            antiquant_group_size=self.quant_config.group_size,
+            bias=bias,
+        )
+
+        return out.reshape(out_shape)
+
+
 class AWQMoEMethod(FusedMoEMethodBase):
 
     def __init__(self, quant_config: AWQMarlinConfig):
@@ -672,7 +756,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_qzeros, extra_weight_attrs)
 
         device = layer.w13_qweight.device
-        layer.workspace = marlin_make_workspace(device, 4)
+        if not _is_npu:
+            layer.workspace = marlin_make_workspace(device, 4)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         num_experts = layer.w13_qweight.shape[0]
@@ -755,12 +840,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
             self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
-        # The input must currently be float16
         x = dispatch_output.hidden_states
         topk_output = dispatch_output.topk_output
-
         orig_dtype = x.dtype
-        x = x.half()
 
         topk_weights, topk_ids, router_logits = topk_output
 
@@ -780,3 +862,95 @@ class AWQMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
         ).to(orig_dtype)
         return StandardCombineInput(hidden_states=output)
+
+
+class AWQMoEAscendMethod(AWQMoEMethod):
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data)
+        w2_qweight_tmp = torch.zeros_like(layer.w2_qweight.data)
+        w13_qzeros_list = []
+        w2_qzeros_list = []
+        shifts = [0, 4, 1, 5, 2, 6, 3, 7]
+        for i in range(0, self.quant_config.pack_factor):
+            shift_num = shifts[i] * 4
+            w13_qzeros_list.append(
+                (layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
+            )
+            w2_qzeros_list.append(
+                (layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
+            )
+            w13_qweight_tmp.bitwise_or_(
+                ((layer.w13_qweight.data >> shift_num) * (2 ** (4 * i)))
+                & (0xF << (4 * i))
+            )
+            w2_qweight_tmp.bitwise_or_(
+                ((layer.w2_qweight.data >> shift_num) * (2 ** (4 * i)))
+                & (0xF << (4 * i))
+            )
+
+        w13_qweight_tmp.bitwise_xor_(0x88888888)
+        w2_qweight_tmp.bitwise_xor_(0x88888888)
+
+        w13_qzeros_tmp = torch.cat(w13_qzeros_list, dim=-1).reshape(
+            layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1
+        )
+        w13_qzeros_tmp = -(w13_qzeros_tmp - 8)
+        w13_qzeros_tmp = w13_qzeros_tmp.to(layer.w13_scales.data.dtype)
+        w2_qzeros_tmp = torch.cat(w2_qzeros_list, dim=-1).reshape(
+            layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1
+        )
+        w2_qzeros_tmp = -(w2_qzeros_tmp - 8)
+        w2_qzeros_tmp = w2_qzeros_tmp.to(layer.w2_scales.data.dtype)
+
+        layer.register_parameter(
+            "w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False)
+        )
+        layer.register_parameter(
+            "w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False)
+        )
+        layer.register_parameter(
+            "w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False)
+        )
+        layer.register_parameter(
+            "w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)
+        )
+
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        topk_weights, topk_ids, _ = topk_output
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = topk_weights.to(x.dtype)
+        output = npu_fused_experts(
+            hidden_states=x,
+            w13=layer.w13_qweight,
+            w13_scale=layer.w13_scales,
+            w13_offset=layer.w13_qzeros,
+            w2=layer.w2_qweight,
+            w2_scale=layer.w2_scales,
+            w2_offset=layer.w2_qzeros,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=topk_ids.shape[1],
+            use_wna16=True,
+        )
+        return StandardCombineInput(hidden_states=output)
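
Both Ascend repacking routines above, and awq_dequantize_decomposition in the next file, walk each packed int32 word with the same table, shifts = [0, 4, 1, 5, 2, 6, 3, 7]. A minimal sketch (plain Python, made-up values) of the assumption that table encodes: AWQ stores logical 4-bit element i at nibble position shifts[i] of the word, so reading the nibbles back through the same table restores the elements in logical order.

# Eight 4-bit weights in logical order (demo values only).
elems = [0, 1, 2, 3, 4, 5, 6, 7]
shifts = [0, 4, 1, 5, 2, 6, 3, 7]

# Pack: logical element i is stored at nibble position shifts[i] of one int32 word.
packed = 0
for i, v in enumerate(elems):
    packed |= v << (shifts[i] * 4)

# Unpack the way the new code does: shift by shifts[i] * 4 and mask the low nibble.
unpacked = [(packed >> (shifts[i] * 4)) & 0xF for i in range(8)]
assert unpacked == elems  # elements come back in logical order

The subsequent bitwise_xor_(0x88888888) flips the high bit of every nibble, turning each unsigned weight w in [0, 15] into the 4-bit two's-complement encoding of w - 8; together with the zero-point rewrite -(qzeros - 8), this appears intended to keep (w - 8) + (8 - z) equal to the original w - z that dequantization needs.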
sglang/srt/layers/quantization/awq_triton.py

@@ -337,3 +337,32 @@ def awq_gemm_triton(
     result = result.sum(0)
 
     return result
+
+
+def awq_dequantize_decomposition(
+    qweight: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+) -> torch.Tensor:
+    qweight_tmp = qweight
+    qzeros_tmp = zeros
+    qweight_list = []
+    qzeros_list = []
+    shifts = [0, 4, 1, 5, 2, 6, 3, 7]
+    for i in range(0, 8):
+        shift_num = shifts[i] * 4
+        qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
+        qweight_list.append((qweight_tmp.reshape(-1, 1) >> shift_num) & 0xF)
+    qzeros_tmp = (
+        torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1).to(scales.dtype)
+    )
+    qweight_tmp = (
+        torch.cat(qweight_list, dim=-1)
+        .reshape(qweight_tmp.shape[0], -1)
+        .to(scales.dtype)
+    )
+    res = (
+        qweight_tmp.reshape(qzeros_tmp.shape[0], -1, qzeros_tmp.shape[1])
+        - qzeros_tmp.unsqueeze(1)
+    ) * scales.unsqueeze(1)
+    return res.reshape(qweight_tmp.shape[0], -1)
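
For orientation, a shape-level sketch of calling the new helper, assuming it is imported from the awq_triton module above. The sizes and tensors are made up; the usual AWQ layout packs eight 4-bit values per int32 along the output dimension, with one scale/zero row per group of input channels.

import torch

# Hypothetical sizes: in_features K, out_features N, AWQ group_size 128.
K, N, group_size = 256, 512, 128
qweight = torch.randint(0, 2**31 - 1, (K, N // 8), dtype=torch.int32)
zeros = torch.randint(0, 2**31 - 1, (K // group_size, N // 8), dtype=torch.int32)
scales = torch.rand(K // group_size, N, dtype=torch.float16)

# Returns the dequantized weight in scales.dtype.
w = awq_dequantize_decomposition(qweight, scales, zeros)
assert w.shape == (K, N) and w.dtype == scales.dtype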
sglang/srt/layers/quantization/base_config.py

@@ -3,7 +3,6 @@ from __future__ import annotations
 
 import inspect
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import torch
@@ -162,6 +161,33 @@ class QuantizationConfig(ABC):
         """
         return None
 
+    @classmethod
+    def _modelopt_override_quantization_method(
+        cls, hf_quant_config, user_quant
+    ) -> Optional[str]:
+        """Shared ModelOpt quantization method override logic."""
+        if hf_quant_config is None:
+            return None
+
+        # Check if this is a ModelOpt config
+        quant_algo = hf_quant_config.get("quant_algo", "").upper()
+
+        # If user specified generic "modelopt", auto-detect the specific method
+        if user_quant == "modelopt":
+            if "FP8" in quant_algo:
+                return "modelopt_fp8"
+            elif "NVFP4" in quant_algo or "FP4" in quant_algo:
+                return "modelopt_fp4"
+
+        # The hf_quant_config may be a parsed quant config, so we need to check the
+        # quant_method.
+        if hf_quant_config.get("quant_method", "") == "modelopt_fp8":
+            return "modelopt_fp8"
+        elif hf_quant_config.get("quant_method", "") == "modelopt_fp4":
+            return "modelopt_fp4"
+
+        return None
+
     @staticmethod
     def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
         """Get a value from the model's quantization config."""
sglang/srt/layers/quantization/compressed_tensors/__init__.py (new file)

@@ -0,0 +1,7 @@
+class scalar_types:
+    uint4b8 = "uint4b8"
+    uint8b128 = "uint8b128"
+
+
+WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128}
+WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
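
These constants take over the WNA16 lookups previously imported from vllm (see the next file). A quick sanity check of what they resolve to:

from sglang.srt.layers.quantization.compressed_tensors import (
    WNA16_SUPPORTED_BITS,
    WNA16_SUPPORTED_TYPES_MAP,
    scalar_types,
)

assert WNA16_SUPPORTED_BITS == [4, 8]
assert WNA16_SUPPORTED_TYPES_MAP[4] == scalar_types.uint4b8    # "uint4b8"
assert WNA16_SUPPORTED_TYPES_MAP[8] == scalar_types.uint8b128  # "uint8b128"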
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -19,37 +19,32 @@ from compressed_tensors.quantization import (
 )
 from pydantic import BaseModel
 
+from sglang.srt.environ import envs
 from sglang.srt.layers.quantization.base_config import (
     LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
     CompressedTensorsMoEMethod,
 )
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8,
+    CompressedTensorsWNA16,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
     find_matched_target,
     is_activation_quantization_format,
     should_ignore_layer,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 
-try:
-    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
-        WNA16_SUPPORTED_BITS,
-        CompressedTensorsWNA16,
-    )
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
-
 logger = logging.getLogger(__name__)
 
 __all__ = ["CompressedTensorsLinearMethod"]
@@ -76,6 +71,7 @@ class DeviceCapability(NamedTuple):
 
 
 class CompressedTensorsConfig(QuantizationConfig):
+    DeepSeekFP8Config = None
 
     def __init__(
         self,
@@ -86,7 +82,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
-        packed_modules_mapping: Dict[str, List[str]] = {},
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ):
         super().__init__()
         self.ignore = ignore
@@ -97,7 +93,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.sparsity_scheme_map = sparsity_scheme_map
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
-        self.packed_modules_mapping = packed_modules_mapping
+        self.packed_modules_mapping = packed_modules_mapping or {}
 
     def get_linear_method(self) -> CompressedTensorsLinearMethod:
         return CompressedTensorsLinearMethod(self)
@@ -129,6 +125,10 @@ class CompressedTensorsConfig(QuantizationConfig):
         ):
             return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
+            if CompressedTensorsConfig.DeepSeekFP8Config is not None:
+                return Fp8LinearMethod(CompressedTensorsConfig.DeepSeekFP8Config)
+            if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
+                return UnquantizedLinearMethod()
             scheme = self.get_scheme(layer=layer, layer_name=prefix)
             if scheme is None:
                 return UnquantizedLinearMethod()
@@ -137,7 +137,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, FusedMoE):
-            return CompressedTensorsMoEMethod.get_moe_method(self)
+            # Ktransformers use CompressedTensorsWNA16AMXMOEMethod if AMX weights are provided
+            return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
         return None
 
     @classmethod
@@ -364,19 +365,6 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         # Detect If Mixed Precision
        if self._is_wNa16_group_channel(weight_quant, input_quant):
-            if not VLLM_AVAILABLE:
-                raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
-                )
-            if (
-                self.quant_format == CompressionFormat.marlin_24.value
-                and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
-            ):
-                return CompressedTensorsW4A16Sparse24(
-                    strategy=weight_quant.strategy,
-                    num_bits=weight_quant.num_bits,
-                    group_size=weight_quant.group_size,
-                )
             if (
                 self.quant_format == CompressionFormat.pack_quantized.value
                 and weight_quant.num_bits in WNA16_SUPPORTED_BITS
@@ -387,6 +375,10 @@ class CompressedTensorsConfig(QuantizationConfig):
                     group_size=weight_quant.group_size,
                     actorder=weight_quant.actorder,
                 )
+            else:
+                raise ImportError(
+                    "Other method (CompressedTensorsW4A16Sparse24) is not supported now"
+                )
 
         if is_activation_quantization_format(self.quant_format):
             if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -410,10 +402,6 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         # note: input_quant can be None
         if self._is_fp8_w8a16(weight_quant, input_quant):
-            if not VLLM_AVAILABLE:
-                raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
-                )
             is_static_input_scheme = input_quant and not input_quant.dynamic
             return CompressedTensorsW8A16Fp8(
                 strategy=weight_quant.strategy,
@@ -454,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         # Find the "target" in the compressed-tensors config
         # that our layer conforms to.
-        # TODO (@robertgshaw): add compressed-tensors as dep
+        # TODO : add compressed-tensors as dep
         # so we do not have to re-write these functions
         # need to make accelerate optional in ct to do this
 
@@ -492,24 +480,7 @@ class CompressedTensorsConfig(QuantizationConfig):
                 input_quant=input_quant,
                 sparsity_scheme=sparsity_scheme,
             ):
-                if not VLLM_AVAILABLE:
-                    raise ImportError(
-                        "vllm is not installed, to use CompressedTensors24, please install vllm"
-                    )
-                # Have a valid sparsity scheme
-                # Validate layer is supported by Cutlass 2:4 Kernel
-                model_compression_config = (
-                    None
-                    if sparsity_scheme is None or sparsity_scheme.format == "dense"
-                    else self.config
-                )
-
-                scheme = CompressedTensors24(
-                    quantized=weight_quant is not None or input_quant is not None,
-                    weight_quant=weight_quant,
-                    input_quant=input_quant,
-                    model_compression_config=model_compression_config,
-                )
+                raise ImportError("CompressedTensors24 is not supported now")
             elif weight_quant is None:
                 logger.warning_once(
                     "Acceleration for non-quantized schemes is "