sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -4,18 +4,40 @@ from __future__ import annotations
 
 import enum
 import logging
+import re
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING
+
+try:
+    from sgl_kernel import fused_marlin_moe
+
+    FUSED_MARLIN_MOE_AVAILABLE = True
+except ImportError:
+    FUSED_MARLIN_MOE_AVAILABLE = False
+
+try:
+    from kt_kernel import AMXMoEWrapper
+
+    KTRANSFORMERS_AVAILABLE = True
+except ImportError:
+    KTRANSFORMERS_AVAILABLE = False
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.environ import envs
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    WNA16_SUPPORTED_BITS,
+)
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack
+from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
     per_tensor_dequantize,
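Note (not part of the diff): the two new try/except blocks above replace the old vllm probe. Optional kernels (sgl_kernel's fused_marlin_moe and kt_kernel's AMXMoEWrapper) are probed once at import time and recorded in module-level availability flags; failures are then raised only on the code path that actually needs the kernel. A minimal sketch of the pattern, where require_ktransformers() is a hypothetical helper standing in for the checks done in the new MoE method constructors:

try:
    from kt_kernel import AMXMoEWrapper  # optional CPU (AMX) MoE backend

    KTRANSFORMERS_AVAILABLE = True
except ImportError:
    KTRANSFORMERS_AVAILABLE = False


def require_ktransformers():
    # Called from the AMX MoE method's __init__ rather than at import time,
    # so installations without kt_kernel still import this module cleanly.
    if not KTRANSFORMERS_AVAILABLE:
        raise ImportError("kt_kernel is not installed, please install kt_kernel.")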
@@ -23,10 +45,9 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.utils import (
     get_bool_env_var,
-    is_cpu,
+    get_compiler_backend,
     is_cuda,
     is_hip,
-    is_npu,
     set_weight_attrs,
 )
 
@@ -41,6 +62,8 @@ if TYPE_CHECKING:
     )
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
@@ -48,16 +71,25 @@ if _use_aiter:
 
     from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
 
-try:
-    import vllm
 
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
+if _is_cuda:
+    from sgl_kernel import fused_marlin_moe
 
 logger = logging.getLogger(__name__)
 
 
+def _mask_topk_ids_cpu_experts(topk_ids: torch.Tensor, num_gpu_experts: int):
+    """Mask topk_ids >= num_gpu_experts by setting them to -1."""
+    topk_ids[topk_ids >= num_gpu_experts] = -1
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def mask_cpu_expert_ids(topk_ids: torch.Tensor, num_gpu_experts: int):
+    """mask CPU expert IDs."""
+    _mask_topk_ids_cpu_experts(topk_ids, num_gpu_experts)
+    return topk_ids
+
+
 class GPTQMarlinState(Enum):
     REPACK = enum.auto()
     READY = enum.auto()
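Note (not part of the diff): the new mask_cpu_expert_ids helper, compiled with torch.compile, is how the hybrid path further down keeps CPU-resident experts out of the GPU Marlin kernel. A minimal illustration of the masking semantics on a toy routing tensor (the values are made up):

import torch

topk_ids = torch.tensor([[0, 5, 12], [3, 9, 14]])
num_gpu_experts = 10  # experts 10+ live on the CPU in this example

# Same in-place update the helper performs; -1 slots are skipped by the GPU kernel.
topk_ids[topk_ids >= num_gpu_experts] = -1
print(topk_ids)  # tensor([[ 0,  5, -1], [ 3,  9, -1]])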
@@ -67,6 +99,7 @@ __all__ = [
     "CompressedTensorsMoEMethod",
     "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsWNA16MoEMethod",
+    "CompressedTensorsWNA16AMXEPMoEMethod",  # for Ktransformers
 ]
 
 
@@ -79,17 +112,27 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
         quant_config: CompressedTensorsConfig,
+        layer: torch.nn.Module,
+        prefix: str,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
+
+        if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
+            match = re.search(r"(\d+)\.mlp", prefix)
+            if not match:
+                raise ValueError(
+                    f"Unable to extract layer number from prefix '{prefix}'. "
+                    f"Expected format: '<layer_number>.mlp'"
+                )
+            layer_number = int(match.group(1))
+            return CompressedTensorsWNA16AMXEPMoEMethod(quant_config, layer_number)
+
         weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
         input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
-
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            if not VLLM_AVAILABLE:
-                raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
-                )
+
+            logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
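Note (not part of the diff): get_moe_method() now routes to the KTransformers AMX+EP method whenever SGLANG_KT_MOE_AMX_WEIGHT_PATH is set, deriving the layer index from the module prefix. A small illustration of that regex; the prefix string below is hypothetical:

import re

prefix = "model.layers.3.mlp.experts"
match = re.search(r"(\d+)\.mlp", prefix)
layer_number = int(match.group(1))  # 3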
@@ -208,7 +251,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
-    def process_weights_after_loading(self, layer: FusedMoE) -> None:
+    def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -356,7 +399,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
 
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
-    def __init__(self, quant_config: CompressedTensorsConfig):
+    def __init__(self, quant_config: CompressedTensorsConfig, num_gpu_experts=-1):
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -378,6 +421,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
                 "is supported for the following bits: ",
                 f"{WNA16_SUPPORTED_BITS}",
             )
+        self.num_gpu_experts = num_gpu_experts
 
     def create_weights(
         self,
@@ -388,10 +432,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-
-        assert (
-            params_dtype == torch.float16
-        ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
+        if self.num_gpu_experts != -1:
+            num_experts = self.num_gpu_experts
 
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
@@ -530,44 +572,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             getattr(layer, name).copy_(new_t)
             del new_t
 
-        def get_scale_perms(num_bits: int):
-            scale_perm: List[int] = []
-            for i in range(8):
-                scale_perm.extend([i + 8 * j for j in range(8)])
-            scale_perm_single: List[int] = []
-            for i in range(4):
-                scale_perm_single.extend(
-                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
-                )
-            return scale_perm, scale_perm_single
-
-        def marlin_permute_scales(
-            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
-        ):
-            scale_perm, scale_perm_single = get_scale_perms(num_bits)
-            if group_size < size_k and group_size != -1:
-                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
-            else:
-                s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
-            s = s.reshape((-1, size_n)).contiguous()
-            return s
-
-        def marlin_moe_permute_scales(
-            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
-        ):
-            num_experts = s.shape[0]
-            output = torch.empty(
-                (num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype
-            )
-            for e in range(num_experts):
-                output[e] = marlin_permute_scales(
-                    s[e], size_k, size_n, group_size, num_bits
-                )
-            return output
-
-        size_k2 = layer.w2_weight_packed.shape[2]
-        size_k13 = layer.w13_weight_packed.shape[2]
-
         num_experts = layer.w13_weight_g_idx.shape[0]
         device = layer.w13_weight_g_idx.device
 
@@ -614,42 +618,39 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
-        from vllm import _custom_ops as vllm_ops
-
-        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
+        marlin_w13_qweight = gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
             layer.w13_weight_packed.shape[1] * self.packed_factor,
             layer.w13_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w13_weight_packed", marlin_w13_qweight)
-        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
+        replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
+        marlin_w2_qweight = gptq_marlin_moe_repack(
             layer.w2_weight_packed,
             layer.w2_g_idx_sort_indices,
             layer.w2_weight_packed.shape[1] * self.packed_factor,
             layer.w2_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w2_weight_packed", marlin_w2_qweight)
+        replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
         # Repack scales
         marlin_w13_scales = marlin_moe_permute_scales(
             layer.w13_weight_scale,
-            size_k13,
+            layer.w13_weight_packed.shape[2],
             layer.w13_weight_scale.shape[2],
             self.group_size,
-            self.num_bits,
         )
-        replace_tensor("w13_weight_scale", marlin_w13_scales)
+        replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
+
         marlin_w2_scales = marlin_moe_permute_scales(
             layer.w2_weight_scale,
             layer.w2_weight_scale.shape[1]
             * (self.group_size if self.group_size != -1 else self.packed_factor),
-            size_k2,
+            layer.w2_weight_scale.shape[2],
             self.group_size,
-            self.num_bits,
         )
-        replace_tensor("w2_weight_scale", marlin_w2_scales)
+        replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
 
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
@@ -673,7 +674,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
 
         topk_weights, topk_ids, router_logits = topk_output
 
-        output = torch.ops.vllm.fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
@@ -690,3 +691,353 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             is_k_full=self.is_k_full,
         )
         return StandardCombineInput(hidden_states=output)
+
+
+class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
+    """AMX MoE method using AMXMoEWrapper for CPU inference."""
+
+    def __init__(
+        self,
+        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        layer_idx,
+        num_gpu_experts,
+        cpuinfer,
+        threadpool_count,
+        amx_weight_path,
+        chunked_prefill_size,
+    ):
+
+        if not KTRANSFORMERS_AVAILABLE:
+            raise ImportError(
+                "kt_kernel is not installed, to use CompressedTensorsWNA16AMXEPMoEMethod, please install kt_kernel."
+            )
+
+        if not FUSED_MARLIN_MOE_AVAILABLE:
+            raise ImportError("fused_marlin_moe is not available")
+
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.layer_idx = layer_idx
+        self.num_gpu_experts = num_gpu_experts
+        self.amx_weight_path = amx_weight_path
+        self.chunked_prefill_size = chunked_prefill_size
+        self.cpuinfer = cpuinfer
+        self.threadpool_count = threadpool_count
+        self.amx_wrapper = None
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        self.experts_num = num_experts
+        self.num_experts_per_tok = extra_weight_attrs.pop("top_k")
+        self.hidden_size = hidden_size
+        self.moe_intermediate_size = extra_weight_attrs.pop("intermediate_size_full")
+
+        if self.tp_rank != 0:
+            return
+        self.amx_wrapper = AMXMoEWrapper(
+            layer_idx=self.layer_idx,
+            num_experts=num_experts,
+            num_experts_per_tok=self.num_experts_per_tok,
+            hidden_size=hidden_size,
+            moe_intermediate_size=self.moe_intermediate_size,
+            num_gpu_experts=self.num_gpu_experts,
+            cpuinfer_threads=self.cpuinfer,
+            threadpool_count=self.threadpool_count,
+            amx_weight_path=self.amx_weight_path,
+            chunked_prefill_size=self.chunked_prefill_size,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.tp_rank != 0:
+            return
+
+        if self.amx_wrapper is None:
+            raise RuntimeError(
+                "AMXMoEWrapper not initialized. Call create_weights first."
+            )
+
+        torch.cuda.synchronize()
+        # Load weights using wrapper
+        from sglang.srt.eplb.expert_location_dispatch import (
+            get_global_expert_location_metadata,
+        )
+
+        physical_to_logical_map_cpu = (
+            get_global_expert_location_metadata()
+            .physical_to_logical_map_cpu[self.layer_idx]
+            .contiguous()
+        )
+        self.amx_wrapper.load_weights(physical_to_logical_map_cpu)
+
+    def submit(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> None:
+        """Submit AMX inference task asynchronously."""
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        topk_weights, topk_ids, _ = topk_output
+
+        if self.tp_rank != 0 or self.amx_wrapper is None:
+            return None
+
+        # Submit forward task using wrapper
+        self.amx_wrapper.submit_forward(
+            x, topk_ids, topk_weights, torch.cuda.current_stream(x.device).cuda_stream
+        )
+        return None
+
+    def sync(self, x):
+        """Synchronize and retrieve AMX inference results."""
+        if self.tp_rank != 0 or self.amx_wrapper is None:
+            return torch.zeros_like(x)
+
+        # Sync forward task using wrapper
+        return self.amx_wrapper.sync_forward(
+            x, torch.cuda.current_stream(x.device).cuda_stream
+        )
+
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        """Execute AMX MoE forward pass synchronously."""
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        topk_weights, topk_ids, _ = topk_output
+
+        if self.tp_rank != 0 or self.amx_wrapper is None:
+            return StandardCombineInput(hidden_states=torch.zeros_like(x))
+
+        # Execute forward using wrapper (submit + sync)
+        output = self.amx_wrapper.forward(
+            x, topk_ids, topk_weights, torch.cuda.current_stream(x.device).cuda_stream
+        )
+        return StandardCombineInput(hidden_states=output)
+
+
+def override_config(
+    cls,
+    num_gpu_experts,
+    cpuinfer,
+    threadpool_count,
+    amx_weight_path,
+    amx_method,
+    chunked_prefill_size,
+):
+    """Override MOE configuration via environment variables."""
+    # Set environment variables using envs utility class
+    if num_gpu_experts is not None:
+        envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.set(num_gpu_experts)
+    if cpuinfer is not None:
+        envs.SGLANG_KT_MOE_CPUINFER.set(cpuinfer)
+    if threadpool_count is not None:
+        envs.SGLANG_KT_THREADPOOL_COUNT.set(threadpool_count)
+    if amx_weight_path is not None:
+        envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.set(amx_weight_path)
+    if amx_method is not None:
+        envs.SGLANG_KT_AMX_METHOD.set(amx_method)
+    if chunked_prefill_size is not None:
+        envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.set(chunked_prefill_size)
+
+
+class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+        self,
+        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        layer_idx,
+    ):
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        if (
+            not envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.is_set()
+            or not envs.SGLANG_KT_MOE_CPUINFER.is_set()
+            or not envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set()
+        ):
+            raise RuntimeError(
+                "the following arguments are required: --kt-amx-weight-path, --kt-cpuinfer, --kt-num-gpu-experts"
+            )
+        self.num_gpu_experts = envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.value
+        cpuinfer = envs.SGLANG_KT_MOE_CPUINFER.value
+        threadpool_count = envs.SGLANG_KT_THREADPOOL_COUNT.value
+        amx_weight_path = envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.value
+        chunked_prefill_size = envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.value
+
+        self.AMX_method = CompressedTensorsWNA16AMXMoEMethod(
+            quant_config,
+            layer_idx,
+            self.num_gpu_experts,
+            cpuinfer,
+            threadpool_count,
+            amx_weight_path,
+            chunked_prefill_size,
+        )
+        self.marlin_method = CompressedTensorsWNA16MoEMethod(
+            quant_config, self.num_gpu_experts
+        )
+        self.layer_id = layer_idx
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        self.global_num_experts = num_experts
+        self.AMX_method.create_weights(
+            layer,
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            params_dtype,
+            **extra_weight_attrs,
+        )
+        self.marlin_method.create_weights(
+            layer,
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            params_dtype,
+            **extra_weight_attrs,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        self.AMX_method.process_weights_after_loading(layer)
+        self.marlin_method.process_weights_after_loading(layer)
+
+    def submit(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        """Submit hybrid GPU+CPU MoE task (AMX submission + GPU execution)."""
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        topk_weights, topk_ids, router_logits = topk_output
+
+        # Submit AMX task if on rank 0
+        if self.tp_rank == 0:
+            self.AMX_method.submit(layer, dispatch_output)
+
+        # Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU
+        topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts)
+
+        # Execute GPU (Marlin) experts
+        output = fused_marlin_moe(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            layer.w13_weight_scale,
+            layer.w2_weight_scale,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_weight_g_idx,
+            g_idx2=layer.w2_weight_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.marlin_method.num_bits,
+            is_k_full=self.marlin_method.is_k_full,
+            global_num_experts=self.global_num_experts,
+            expert_map=torch.empty(1, device=x.device),
+        )
+        return StandardCombineInput(hidden_states=output)
+
+    def sync(self, x):
+        """Synchronize and retrieve AMX results."""
+        if self.tp_rank != 0:
+            return torch.zeros_like(x)
+        return self.AMX_method.sync(x)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        """Execute hybrid GPU+CPU MoE forward pass with parallelism."""
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        assert (
+            self.moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        topk_weights, topk_ids, router_logits = topk_output
+
+        # Step 1: Submit AMX task (non-blocking) if on rank 0
+        # This starts CPU computation in parallel
+        if self.tp_rank == 0:
+            self.AMX_method.submit(layer, dispatch_output)
+
+        # Step 2: Execute GPU (Marlin) experts in parallel with CPU
+
+        # Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU
+        topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts)
+
+        # While GPU computes, CPU is also computing
+        output = fused_marlin_moe(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            layer.w13_weight_scale,
+            layer.w2_weight_scale,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_weight_g_idx,
+            g_idx2=layer.w2_weight_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.marlin_method.num_bits,
+            is_k_full=self.marlin_method.is_k_full,
+            global_num_experts=self.global_num_experts,
+            expert_map=torch.empty(1, device=x.device),
+        )
+
+        # Step 3: Sync AMX results and combine with GPU results
+        if self.tp_rank == 0:
+            amx_output = self.AMX_method.sync(x)
+            output += amx_output
+
+        return StandardCombineInput(hidden_states=output)
+
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.AMX_method.create_moe_runner(layer, moe_runner_config)
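Note (not part of the diff): CompressedTensorsWNA16AMXEPMoEMethod.apply() above overlaps CPU (AMX) and GPU (Marlin) expert computation: submit the CPU-side experts, run the GPU-side experts while they execute, then sync and sum the two partial outputs. A self-contained sketch of that submit/compute/sync shape, where cpu_expert_forward and gpu_expert_forward are hypothetical callables standing in for AMXMoEWrapper and fused_marlin_moe, and all inputs are torch tensors:

from concurrent.futures import ThreadPoolExecutor


def hybrid_moe_forward(x, topk_ids, topk_weights, num_gpu_experts,
                       cpu_expert_forward, gpu_expert_forward):
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Step 1: kick off CPU-side experts asynchronously ("submit").
        cpu_future = pool.submit(cpu_expert_forward, x, topk_ids, topk_weights)

        # Step 2: mask CPU expert IDs so the GPU kernel skips them,
        # then run the GPU-side experts while the CPU works.
        gpu_ids = topk_ids.clone()
        gpu_ids[gpu_ids >= num_gpu_experts] = -1
        output = gpu_expert_forward(x, gpu_ids, topk_weights)

        # Step 3: wait for the CPU result ("sync") and combine.
        output += cpu_future.result()
    return output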
sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py

@@ -4,10 +4,13 @@ from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
+from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
 
 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
     "CompressedTensorsW8A16Fp8",
     "CompressedTensorsW8A8Int8",
+    "CompressedTensorsWNA16",
+    "WNA16_SUPPORTED_BITS",
 ]
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py

@@ -14,25 +14,12 @@ from sglang.srt.layers.parameter import (
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
+from sglang.srt.layers.quantization.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
 from sglang.srt.layers.quantization.utils import convert_to_channelwise
 
-try:
-    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-        apply_fp8_marlin_linear,
-        prepare_fp8_layer_for_marlin,
-    )
-
-    MARLIN_FP8_AVAILABLE = True
-except ImportError:
-    MARLIN_FP8_AVAILABLE = False
-
-    def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vllm is not installed")
-
-    def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vllm is not installed")
-
-
 __all__ = ["CompressedTensorsW8A16Fp8"]
 
 SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
@@ -43,11 +30,6 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
 
-        if not MARLIN_FP8_AVAILABLE:
-            raise ImportError(
-                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
-            )
-
     @classmethod
     def get_min_capability(cls) -> int:
         # ampere and up